diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index 036da71a5b5..934a84996b8 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs/neutrinonotify" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -63,6 +64,10 @@ type Config struct { // state. ChanStateDB *channeldb.ChannelStateDB + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + // BlockCache is the main cache for storing block information. BlockCache *blockcache.BlockCache diff --git a/channeldb/channel.go b/channeldb/channel.go index cb2490fcd7e..40bab3d700d 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -251,6 +251,10 @@ type chanAuxData struct { // tapscriptRoot is the optional Tapscript root the channel funding // output commits to. tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte] + + // customBlob is an optional TLV encoded blob of data representing + // custom channel funding information. + customBlob tlv.OptionalRecordT[tlv.TlvType7, tlv.Blob] } // encode serializes the chanAuxData to the given io.Writer. @@ -269,6 +273,9 @@ func (c *chanAuxData) encode(w io.Writer) error { tlvRecords = append(tlvRecords, root.Record()) }, ) + c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType7, tlv.Blob]) { + tlvRecords = append(tlvRecords, blob.Record()) + }) // Create the tlv stream. tlvStream, err := tlv.NewStream(tlvRecords...) @@ -283,6 +290,7 @@ func (c *chanAuxData) encode(w io.Writer) error { func (c *chanAuxData) decode(r io.Reader) error { memo := c.memo.Zero() tapscriptRoot := c.tapscriptRoot.Zero() + blob := c.customBlob.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( @@ -292,6 +300,7 @@ func (c *chanAuxData) decode(r io.Reader) error { c.realScid.Record(), memo.Record(), tapscriptRoot.Record(), + blob.Record(), ) if err != nil { return err @@ -308,6 +317,9 @@ func (c *chanAuxData) decode(r io.Reader) error { if _, ok := tlvs[tapscriptRoot.TlvType()]; ok { c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot) } + if _, ok := tlvs[c.customBlob.TlvType()]; ok { + c.customBlob = tlv.SomeRecordT(blob) + } return nil } @@ -325,6 +337,9 @@ func (c *chanAuxData) toOpenChan(o *OpenChannel) { c.tapscriptRoot.WhenSomeV(func(h [32]byte) { o.TapscriptRoot = fn.Some[chainhash.Hash](h) }) + c.customBlob.WhenSomeV(func(blob tlv.Blob) { + o.CustomBlob = fn.Some(blob) + }) } // newChanAuxDataFromChan creates a new chanAuxData from the given channel. @@ -354,6 +369,11 @@ func newChanAuxDataFromChan(openChan *OpenChannel) *chanAuxData { tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h), ) }) + openChan.CustomBlob.WhenSome(func(blob tlv.Blob) { + c.customBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType7](blob), + ) + }) return c } @@ -607,6 +627,74 @@ type ChannelConfig struct { HtlcBasePoint keychain.KeyDescriptor } +// commitAuxData stores all the optional data that may be stored as a TLV stream +// at the _end_ of the normal serialized commit on disk. +type commitAuxData struct { + // customBlob is a custom blob that may store extra data for custom + // channels. + customBlob tlv.OptionalRecordT[tlv.TlvType1, tlv.Blob] +} + +// encode encodes the aux data into the passed io.Writer. +func (c *commitAuxData) encode(w io.Writer) error { + var tlvRecords []tlv.Record + c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType1, tlv.Blob]) { + tlvRecords = append(tlvRecords, blob.Record()) + }) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// decode attempts to ecode the aux data from the passed io.Reader. +func (c *commitAuxData) decode(r io.Reader) error { + blob := c.customBlob.Zero() + + tlvStream, err := tlv.NewStream( + blob.Record(), + ) + if err != nil { + return err + } + + tlvs, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if _, ok := tlvs[c.customBlob.TlvType()]; ok { + c.customBlob = tlv.SomeRecordT(blob) + } + + return nil +} + +// toChanCommit extracts the optional data stored in the commitAuxData struct +// and stores it in the ChannelCommitment. +func (c *commitAuxData) toChanCommit(commit *ChannelCommitment) { + c.customBlob.WhenSomeV(func(blob tlv.Blob) { + commit.CustomBlob = fn.Some(blob) + }) +} + +// newCommitAuxData creates an aux data struct from the normal chan commitment. +func newCommitAuxData(commit *ChannelCommitment) commitAuxData { + var c commitAuxData + + commit.CustomBlob.WhenSome(func(blob tlv.Blob) { + c.customBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](blob), + ) + }) + + return c +} + // ChannelCommitment is a snapshot of the commitment state at a particular // point in the commitment chain. With each state transition, a snapshot of the // current state along with all non-settled HTLCs are recorded. These snapshots @@ -673,6 +761,11 @@ type ChannelCommitment struct { // able by us. CommitTx *wire.MsgTx + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This may track some custom + // specific state for this given commitment. + CustomBlob fn.Option[tlv.Blob] + // CommitSig is one half of the signature required to fully complete // the script for the commitment transaction above. This is the // signature signed by the remote party for our version of the @@ -682,9 +775,6 @@ type ChannelCommitment struct { // Htlcs is the set of HTLC's that are pending at this particular // commitment height. Htlcs []HTLC - - // TODO(roasbeef): pending commit pointer? - // * lets just walk through } // ChannelStatus is a bit vector used to indicate whether an OpenChannel is in @@ -982,6 +1072,12 @@ type OpenChannel struct { // funding output. TapscriptRoot fn.Option[chainhash.Hash] + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob fn.Option[tlv.Blob] + // TODO(roasbeef): eww Db *ChannelStateDB @@ -2793,6 +2889,16 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl } } + // We'll also encode the commit aux data stream here. We do this here + // rather than above (at the call to serializeChanCommit), to ensure + // backwards compat for reads to existing non-custom channels. + // + // TODO(roasbeef): migrate it after all? + auxData := newCommitAuxData(&diff.Commitment) + if err := auxData.encode(w); err != nil { + return fmt.Errorf("unable to write aux data: %w", err) + } + return nil } @@ -2853,6 +2959,17 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { } } + // As a final step, we'll read out any aux commit data that we have at + // the end of this byte stream. We do this here to ensure backward + // compatibility, as otherwise we risk erroneously reading into the + // wrong field. + var auxData commitAuxData + if err := auxData.decode(r); err != nil { + return nil, fmt.Errorf("unable to decode aux data: %w", err) + } + + auxData.toChanCommit(&d.Commitment) + return &d, nil } @@ -3831,6 +3948,13 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot { }, } + localCommit.CustomBlob.WhenSome(func(blob tlv.Blob) { + blobCopy := make([]byte, len(blob)) + copy(blobCopy, blob) + + snapshot.ChannelCommitment.CustomBlob = fn.Some(blobCopy) + }) + // Copy over the current set of HTLCs to ensure the caller can't mutate // our internal state. snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs)) @@ -4222,6 +4346,12 @@ func putChanCommitment(chanBucket kvdb.RwBucket, c *ChannelCommitment, return err } + // Before we write to disk, we'll also write our aux data as well. + auxData := newCommitAuxData(c) + if err := auxData.encode(&b); err != nil { + return fmt.Errorf("unable to write aux data: %w", err) + } + return chanBucket.Put(commitKey, b.Bytes()) } @@ -4367,7 +4497,9 @@ func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { return c, nil } -func fetchChanCommitment(chanBucket kvdb.RBucket, local bool) (ChannelCommitment, error) { +func fetchChanCommitment(chanBucket kvdb.RBucket, + local bool) (ChannelCommitment, error) { + var commitKey []byte if local { commitKey = append(chanCommitmentKey, byte(0x00)) @@ -4381,7 +4513,23 @@ func fetchChanCommitment(chanBucket kvdb.RBucket, local bool) (ChannelCommitment } r := bytes.NewReader(commitBytes) - return deserializeChanCommit(r) + chanCommit, err := deserializeChanCommit(r) + if err != nil { + return ChannelCommitment{}, fmt.Errorf("unable to decode "+ + "chan commit: %w", err) + } + + // We'll also check to see if we have any aux data stored as the end of + // the stream. + var auxData commitAuxData + if err := auxData.decode(r); err != nil { + return ChannelCommitment{}, fmt.Errorf("unable to decode "+ + "chan aux data: %w", err) + } + + auxData.toChanCommit(&chanCommit) + + return chanCommit, nil } func fetchChanCommitments(chanBucket kvdb.RBucket, channel *OpenChannel) error { diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 2389015cfa3..8eca3e7dd4c 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -337,6 +337,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { FeePerKw: btcutil.Amount(5000), CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), + CustomBlob: fn.Some([]byte{1, 2, 3}), }, RemoteCommitment: ChannelCommitment{ CommitHeight: 0, @@ -346,6 +347,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { FeePerKw: btcutil.Amount(5000), CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), + CustomBlob: fn.Some([]byte{4, 5, 6}), }, NumConfsRequired: 4, RemoteCurrentRevocation: privKey.PubKey(), @@ -360,6 +362,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { InitialRemoteBalance: lnwire.MilliSatoshi(3000), Memo: []byte("test"), TapscriptRoot: fn.Some(tapscriptRoot), + CustomBlob: fn.Some([]byte{1, 2, 3}), } } @@ -567,24 +570,32 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment, r *RevocationLog) { + t.Helper() + // Check the common fields. require.EqualValues( - t, r.CommitTxHash, c.CommitTx.TxHash(), "CommitTx mismatch", + t, r.CommitTxHash.Val, c.CommitTx.TxHash(), "CommitTx mismatch", ) // Now check the common fields from the HTLCs. require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch") for i, rHtlc := range r.HTLCEntries { cHtlc := c.Htlcs[i] - require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch") - require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(), - "Amt mismatch") - require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout, - "RefundTimeout mismatch") - require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex, - "OutputIndex mismatch") - require.Equal(t, rHtlc.Incoming, cHtlc.Incoming, - "Incoming mismatch") + require.Equal(t, rHtlc.RHash.Val[:], cHtlc.RHash[:], "RHash") + require.Equal( + t, rHtlc.Amt.Val.Int(), cHtlc.Amt.ToSatoshis(), "Amt", + ) + require.Equal( + t, rHtlc.RefundTimeout.Val, cHtlc.RefundTimeout, + "RefundTimeout", + ) + require.EqualValues( + t, rHtlc.OutputIndex.Val, cHtlc.OutputIndex, + "OutputIndex", + ) + require.Equal( + t, rHtlc.Incoming.Val, cHtlc.Incoming, "Incoming", + ) } } @@ -649,6 +660,7 @@ func TestChannelStateTransition(t *testing.T) { CommitTx: newTx, CommitSig: newSig, Htlcs: htlcs, + CustomBlob: fn.Some([]byte{4, 5, 6}), } // First update the local node's broadcastable state and also add a @@ -686,9 +698,14 @@ func TestChannelStateTransition(t *testing.T) { // have been updated. updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub) require.NoError(t, err, "unable to fetch updated channel") - assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) + + assertCommitmentEqual( + t, &commitment, &updatedChannel[0].LocalCommitment, + ) + numDiskUpdates, err := updatedChannel[0].CommitmentHeight() require.NoError(t, err, "unable to read commitment height from disk") + if numDiskUpdates != uint64(commitment.CommitHeight) { t.Fatalf("num disk updates doesn't match: %v vs %v", numDiskUpdates, commitment.CommitHeight) @@ -791,10 +808,10 @@ func TestChannelStateTransition(t *testing.T) { // Check the output indexes are saved as expected. require.EqualValues( - t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex, + t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex.Val, ) require.EqualValues( - t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex, + t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex.Val, ) // The two deltas (the original vs the on-disk version) should @@ -836,10 +853,10 @@ func TestChannelStateTransition(t *testing.T) { // Check the output indexes are saved as expected. require.EqualValues( - t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex, + t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex.Val, ) require.EqualValues( - t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex, + t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex.Val, ) assertRevocationLogEntryEqual(t, &oldRemoteCommit, prevCommit) diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index f062ac0860e..4334ec64dd9 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -7,6 +7,7 @@ import ( "math" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -16,16 +17,15 @@ import ( const ( // OutputIndexEmpty is used when the output index doesn't exist. OutputIndexEmpty = math.MaxUint16 +) - // A set of tlv type definitions used to serialize the body of - // revocation logs to the database. - // - // NOTE: A migration should be added whenever this list changes. - revLogOurOutputIndexType tlv.Type = 0 - revLogTheirOutputIndexType tlv.Type = 1 - revLogCommitTxHashType tlv.Type = 2 - revLogOurBalanceType tlv.Type = 3 - revLogTheirBalanceType tlv.Type = 4 +type ( + // BigSizeAmount is a type alias for a TLV record of a btcutil.Amount. + BigSizeAmount = tlv.BigSizeT[btcutil.Amount] + + // BigSizeMilliSatoshi is a type alias for a TLV record of a + // lnwire.MilliSatoshi. + BigSizeMilliSatoshi = tlv.BigSizeT[lnwire.MilliSatoshi] ) var ( @@ -54,6 +54,74 @@ var ( ErrOutputIndexTooBig = errors.New("output index is over uint16") ) +// SparsePayHash is a type alias for a 32 byte array, which when serialized is +// able to save some space by not including an empty payment hash on disk. +type SparsePayHash [32]byte + +// NewSparsePayHash creates a new SparsePayHash from a 32 byte array. +func NewSparsePayHash(rHash [32]byte) SparsePayHash { + return SparsePayHash(rHash) +} + +// Record returns a tlv record for the SparsePayHash. +func (s *SparsePayHash) Record() tlv.Record { + // We use a zero for the type here, as this'll be used along with the + // RecordT type. + return tlv.MakeDynamicRecord( + 0, s, s.hashLen, + sparseHashEncoder, sparseHashDecoder, + ) +} + +// hashLen is used by MakeDynamicRecord to return the size of the RHash. +// +// NOTE: for zero hash, we return a length 0. +func (s *SparsePayHash) hashLen() uint64 { + if bytes.Equal(s[:], lntypes.ZeroHash[:]) { + return 0 + } + + return 32 +} + +// sparseHashEncoder is the customized encoder which skips encoding the empty +// hash. +func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the value is an empty hash, we will skip encoding it. + if bytes.Equal(v[:], lntypes.ZeroHash[:]) { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.EBytes32(w, vArray, buf) +} + +// sparseHashDecoder is the customized decoder which skips decoding the empty +// hash. +func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the length is zero, we will skip encoding the empty hash. + if l == 0 { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.DBytes32(r, vArray, buf, 32) +} + // HTLCEntry specifies the minimal info needed to be stored on disk for ALL the // historical HTLCs, which is useful for constructing RevocationLog when a // breach is detected. @@ -72,116 +140,84 @@ var ( // made into tlv records without further conversion. type HTLCEntry struct { // RHash is the payment hash of the HTLC. - RHash [32]byte + RHash tlv.RecordT[tlv.TlvType0, SparsePayHash] // RefundTimeout is the absolute timeout on the HTLC that the sender // must wait before reclaiming the funds in limbo. - RefundTimeout uint32 + RefundTimeout tlv.RecordT[tlv.TlvType1, uint32] // OutputIndex is the output index for this particular HTLC output // within the commitment transaction. // // NOTE: we use uint16 instead of int32 here to save us 2 bytes, which // gives us a max number of HTLCs of 65K. - OutputIndex uint16 + OutputIndex tlv.RecordT[tlv.TlvType2, uint16] // Incoming denotes whether we're the receiver or the sender of this // HTLC. - // - // NOTE: this field is the memory representation of the field - // incomingUint. - Incoming bool + Incoming tlv.RecordT[tlv.TlvType3, bool] // Amt is the amount of satoshis this HTLC escrows. - // - // NOTE: this field is the memory representation of the field amtUint. - Amt btcutil.Amount + Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] - // amtTlv is the uint64 format of Amt. This field is created so we can - // easily make it into a tlv record and save it to disk. - // - // NOTE: we keep this field for accounting purpose only. If the disk - // space becomes an issue, we could delete this field to save us extra - // 8 bytes. - amtTlv uint64 - - // incomingTlv is the uint8 format of Incoming. This field is created - // so we can easily make it into a tlv record and save it to disk. - incomingTlv uint8 -} + // CustomBlob is an optional blob that can be used to store information + // specific to revocation handling for a custom channel type. + CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] -// RHashLen is used by MakeDynamicRecord to return the size of the RHash. -// -// NOTE: for zero hash, we return a length 0. -func (h *HTLCEntry) RHashLen() uint64 { - if h.RHash == lntypes.ZeroHash { - return 0 - } - return 32 + // HtlcIndex is the index of the HTLC in the channel. + HtlcIndex tlv.RecordT[tlv.TlvType6, uint16] } -// RHashEncoder is the customized encoder which skips encoding the empty hash. -func RHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { - v, ok := val.(*[32]byte) - if !ok { - return tlv.NewTypeForEncodingErr(val, "RHash") - } - - // If the value is an empty hash, we will skip encoding it. - if *v == lntypes.ZeroHash { - return nil +// toTlvStream converts an HTLCEntry record into a tlv representation. +func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { + records := []tlv.Record{ + h.RHash.Record(), + h.RefundTimeout.Record(), + h.OutputIndex.Record(), + h.Incoming.Record(), + h.Amt.Record(), + h.HtlcIndex.Record(), } - return tlv.EBytes32(w, v, buf) -} - -// RHashDecoder is the customized decoder which skips decoding the empty hash. -func RHashDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { - v, ok := val.(*[32]byte) - if !ok { - return tlv.NewTypeForEncodingErr(val, "RHash") - } + h.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { + records = append(records, r.Record()) + }) - // If the length is zero, we will skip encoding the empty hash. - if l == 0 { - return nil - } + tlv.SortRecords(records) - return tlv.DBytes32(r, v, buf, 32) + return tlv.NewStream(records...) } -// toTlvStream converts an HTLCEntry record into a tlv representation. -func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { - const ( - // A set of tlv type definitions used to serialize htlc entries - // to the database. We define it here instead of the head of - // the file to avoid naming conflicts. - // - // NOTE: A migration should be added whenever this list - // changes. - rHashType tlv.Type = 0 - refundTimeoutType tlv.Type = 1 - outputIndexType tlv.Type = 2 - incomingType tlv.Type = 3 - amtType tlv.Type = 4 - ) - - return tlv.NewStream( - tlv.MakeDynamicRecord( - rHashType, &h.RHash, h.RHashLen, - RHashEncoder, RHashDecoder, +// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC. +func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { + h := &HTLCEntry{ + RHash: tlv.NewRecordT[tlv.TlvType0]( + NewSparsePayHash(htlc.RHash), ), - tlv.MakePrimitiveRecord( - refundTimeoutType, &h.RefundTimeout, + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1]( + htlc.RefundTimeout, ), - tlv.MakePrimitiveRecord( - outputIndexType, &h.OutputIndex, + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlc.OutputIndex), ), - tlv.MakePrimitiveRecord(incomingType, &h.incomingTlv), - // We will save 3 bytes if the amount is less or equal to - // 4,294,967,295 msat, or roughly 0.043 bitcoin. - tlv.MakeBigSizeRecord(amtType, &h.amtTlv), - ) + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), + ), + HtlcIndex: tlv.NewPrimitiveRecord[tlv.TlvType6]( + uint16(htlc.HtlcIndex), + ), + } + + if len(htlc.ExtraData) != 0 { + h.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob]( + htlc.ExtraData, + ), + ) + } + + return h } // RevocationLog stores the info needed to construct a breach retribution. Its @@ -191,15 +227,15 @@ func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { type RevocationLog struct { // OurOutputIndex specifies our output index in this commitment. In a // remote commitment transaction, this is the to remote output index. - OurOutputIndex uint16 + OurOutputIndex tlv.RecordT[tlv.TlvType0, uint16] // TheirOutputIndex specifies their output index in this commitment. In // a remote commitment transaction, this is the to local output index. - TheirOutputIndex uint16 + TheirOutputIndex tlv.RecordT[tlv.TlvType1, uint16] // CommitTxHash is the hash of the latest version of the commitment // state, broadcast able by us. - CommitTxHash [32]byte + CommitTxHash tlv.RecordT[tlv.TlvType2, [32]byte] // HTLCEntries is the set of HTLCEntry's that are pending at this // particular commitment height. @@ -209,21 +245,65 @@ type RevocationLog struct { // directly spendable by us. In other words, it is the value of the // to_remote output on the remote parties' commitment transaction. // - // NOTE: this is a pointer so that it is clear if the value is zero or + // NOTE: this is an option so that it is clear if the value is zero or // nil. Since migration 30 of the channeldb initially did not include // this field, it could be the case that the field is not present for // all revocation logs. - OurBalance *lnwire.MilliSatoshi + OurBalance tlv.OptionalRecordT[tlv.TlvType3, BigSizeMilliSatoshi] // TheirBalance is the current available balance within the channel // directly spendable by the remote node. In other words, it is the // value of the to_local output on the remote parties' commitment. // - // NOTE: this is a pointer so that it is clear if the value is zero or + // NOTE: this is an option so that it is clear if the value is zero or // nil. Since migration 30 of the channeldb initially did not include // this field, it could be the case that the field is not present for // all revocation logs. - TheirBalance *lnwire.MilliSatoshi + TheirBalance tlv.OptionalRecordT[tlv.TlvType4, BigSizeMilliSatoshi] + + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] +} + +// NewRevocationLog creates a new RevocationLog from the given parameters. +func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, + commitHash [32]byte, ourBalance, + theirBalance fn.Option[lnwire.MilliSatoshi], htlcs []*HTLCEntry, + customBlob fn.Option[tlv.Blob]) RevocationLog { + + rl := RevocationLog{ + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + ourOutputIndex, + ), + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + theirOutputIndex, + ), + CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2](commitHash), + HTLCEntries: htlcs, + } + + ourBalance.WhenSome(func(balance lnwire.MilliSatoshi) { + rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( + tlv.NewBigSizeT(balance), + )) + }) + + theirBalance.WhenSome(func(balance lnwire.MilliSatoshi) { + rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(balance), + )) + }) + + customBlob.WhenSome(func(blob tlv.Blob) { + rl.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + }) + + return rl } // putRevocationLog uses the fields `CommitTx` and `Htlcs` from a @@ -242,15 +322,32 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, } rl := &RevocationLog{ - OurOutputIndex: uint16(ourOutputIndex), - TheirOutputIndex: uint16(theirOutputIndex), - CommitTxHash: commit.CommitTx.TxHash(), - HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint16(ourOutputIndex), + ), + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint16(theirOutputIndex), + ), + CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2, [32]byte]( + commit.CommitTx.TxHash(), + ), + HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), } + commit.CustomBlob.WhenSome(func(blob tlv.Blob) { + rl.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + }) + if !noAmtData { - rl.OurBalance = &commit.LocalBalance - rl.TheirBalance = &commit.RemoteBalance + rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( + tlv.NewBigSizeT(commit.LocalBalance), + )) + + rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(commit.RemoteBalance), + )) } for _, htlc := range commit.Htlcs { @@ -265,13 +362,7 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, return ErrOutputIndexTooBig } - entry := &HTLCEntry{ - RHash: htlc.RHash, - RefundTimeout: htlc.RefundTimeout, - Incoming: htlc.Incoming, - OutputIndex: uint16(htlc.OutputIndex), - Amt: htlc.Amt.ToSatoshis(), - } + entry := NewHTLCEntryFromHTLC(htlc) rl.HTLCEntries = append(rl.HTLCEntries, entry) } @@ -306,31 +397,27 @@ func fetchRevocationLog(log kvdb.RBucket, func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { // Add the tlv records for all non-optional fields. records := []tlv.Record{ - tlv.MakePrimitiveRecord( - revLogOurOutputIndexType, &rl.OurOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogTheirOutputIndexType, &rl.TheirOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogCommitTxHashType, &rl.CommitTxHash, - ), + rl.OurOutputIndex.Record(), + rl.TheirOutputIndex.Record(), + rl.CommitTxHash.Record(), } // Now we add any optional fields that are non-nil. - if rl.OurBalance != nil { - lb := uint64(*rl.OurBalance) - records = append(records, tlv.MakeBigSizeRecord( - revLogOurBalanceType, &lb, - )) - } + rl.OurBalance.WhenSome( + func(r tlv.RecordT[tlv.TlvType3, BigSizeMilliSatoshi]) { + records = append(records, r.Record()) + }, + ) - if rl.TheirBalance != nil { - rb := uint64(*rl.TheirBalance) - records = append(records, tlv.MakeBigSizeRecord( - revLogTheirBalanceType, &rb, - )) - } + rl.TheirBalance.WhenSome( + func(r tlv.RecordT[tlv.TlvType4, BigSizeMilliSatoshi]) { + records = append(records, r.Record()) + }, + ) + + rl.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { + records = append(records, r.Record()) + }) // Create the tlv stream. tlvStream, err := tlv.NewStream(records...) @@ -351,14 +438,6 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { // format. func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { for _, htlc := range htlcs { - // Patch the incomingTlv field. - if htlc.Incoming { - htlc.incomingTlv = 1 - } - - // Patch the amtTlv field. - htlc.amtTlv = uint64(htlc.Amt) - // Create the tlv stream. tlvStream, err := htlc.toTlvStream() if err != nil { @@ -376,27 +455,20 @@ func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { // deserializeRevocationLog deserializes a RevocationLog based on tlv format. func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { - var ( - rl RevocationLog - ourBalance uint64 - theirBalance uint64 - ) + var rl RevocationLog + + ourBalance := rl.OurBalance.Zero() + theirBalance := rl.TheirBalance.Zero() + customBlob := rl.CustomBlob.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( - tlv.MakePrimitiveRecord( - revLogOurOutputIndexType, &rl.OurOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogTheirOutputIndexType, &rl.TheirOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogCommitTxHashType, &rl.CommitTxHash, - ), - tlv.MakeBigSizeRecord(revLogOurBalanceType, &ourBalance), - tlv.MakeBigSizeRecord( - revLogTheirBalanceType, &theirBalance, - ), + rl.OurOutputIndex.Record(), + rl.TheirOutputIndex.Record(), + rl.CommitTxHash.Record(), + ourBalance.Record(), + theirBalance.Record(), + customBlob.Record(), ) if err != nil { return rl, err @@ -408,14 +480,16 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { return rl, err } - if t, ok := parsedTypes[revLogOurBalanceType]; ok && t == nil { - lb := lnwire.MilliSatoshi(ourBalance) - rl.OurBalance = &lb + if t, ok := parsedTypes[ourBalance.TlvType()]; ok && t == nil { + rl.OurBalance = tlv.SomeRecordT(ourBalance) + } + + if t, ok := parsedTypes[theirBalance.TlvType()]; ok && t == nil { + rl.TheirBalance = tlv.SomeRecordT(theirBalance) } - if t, ok := parsedTypes[revLogTheirBalanceType]; ok && t == nil { - rb := lnwire.MilliSatoshi(theirBalance) - rl.TheirBalance = &rb + if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { + rl.CustomBlob = tlv.SomeRecordT(customBlob) } // Read the HTLC entries. @@ -432,14 +506,27 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { for { var htlc HTLCEntry + customBlob := htlc.CustomBlob.Zero() + // Create the tlv stream. - tlvStream, err := htlc.toTlvStream() + records := []tlv.Record{ + htlc.RHash.Record(), + htlc.RefundTimeout.Record(), + htlc.OutputIndex.Record(), + htlc.Incoming.Record(), + htlc.Amt.Record(), + customBlob.Record(), + htlc.HtlcIndex.Record(), + } + + tlvStream, err := tlv.NewStream(records...) if err != nil { return nil, err } // Read the HTLC entry. - if _, err := readTlvStream(r, tlvStream); err != nil { + parsedTypes, err := readTlvStream(r, tlvStream) + if err != nil { // We've reached the end when hitting an EOF. if err == io.ErrUnexpectedEOF { break @@ -447,14 +534,10 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { return nil, err } - // Patch the Incoming field. - if htlc.incomingTlv == 1 { - htlc.Incoming = true + if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { + htlc.CustomBlob = tlv.SomeRecordT(customBlob) } - // Patch the Amt field. - htlc.Amt = btcutil.Amount(htlc.amtTlv) - // Append the entry. htlcs = append(htlcs, &htlc) } @@ -469,6 +552,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error { if err := s.Encode(&b); err != nil { return err } + // Write the stream's length as a varint. err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{}) if err != nil { diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index fc5303a48dc..d21c41fbf93 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lnwire" @@ -33,17 +34,29 @@ var ( 0xff, // value = 255 } + blobBytes = tlv.Blob{ + 0x01, 0x02, 0x03, 0x04, + } + testHTLCEntry = HTLCEntry{ - RefundTimeout: 740_000, - OutputIndex: 10, - Incoming: true, - Amt: 1000_000, - amtTlv: 1000_000, - incomingTlv: 1, + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( + 740_000, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16]( + 10, + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(1_000_000)), + ), + CustomBlob: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5](blobBytes), + ), + HtlcIndex: tlv.NewPrimitiveRecord[tlv.TlvType6, uint16](3), } testHTLCEntryBytes = []byte{ - // Body length 23. - 0x16, + // Body length 32. + 0x20, // Rhash tlv. 0x0, 0x0, // RefundTimeout tlv. @@ -54,6 +67,46 @@ var ( 0x3, 0x1, 0x1, // Amt tlv. 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, + // Custom blob tlv. + 0x5, 0x4, 0x1, 0x2, 0x3, 0x4, + // HLTC index tlv. + 0x6, 0x2, 0x0, 0x03, + } + + testHTLCEntryHash = HTLCEntry{ + RHash: tlv.NewPrimitiveRecord[tlv.TlvType0](NewSparsePayHash( + [32]byte{0x33, 0x44, 0x55}, + )), + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( + 740_000, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16]( + 10, + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(1_000_000)), + ), + } + testHTLCEntryHashBytes = []byte{ + // Body length 58. + 0x3a, + // Rhash tlv. + 0x0, 0x20, + 0x33, 0x44, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // RefundTimeout tlv. + 0x1, 0x4, 0x0, 0xb, 0x4a, 0xa0, + // OutputIndex tlv. + 0x2, 0x2, 0x0, 0xa, + // Incoming tlv. + 0x3, 0x1, 0x1, + // Amt tlv. + 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, + // HLTC index tlv. + 0x6, 0x2, 0x0, 0x00, } localBalance = lnwire.MilliSatoshi(9000) @@ -68,24 +121,26 @@ var ( CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), Htlcs: []HTLC{{ - RefundTimeout: testHTLCEntry.RefundTimeout, - OutputIndex: int32(testHTLCEntry.OutputIndex), - Incoming: testHTLCEntry.Incoming, + RefundTimeout: testHTLCEntry.RefundTimeout.Val, + OutputIndex: int32(testHTLCEntry.OutputIndex.Val), + HtlcIndex: uint64(testHTLCEntry.HtlcIndex.Val), + Incoming: testHTLCEntry.Incoming.Val, Amt: lnwire.NewMSatFromSatoshis( - testHTLCEntry.Amt, + testHTLCEntry.Amt.Val.Int(), ), + ExtraData: blobBytes, }}, + CustomBlob: fn.Some(blobBytes), } - testRevocationLogNoAmts = RevocationLog{ - OurOutputIndex: 0, - TheirOutputIndex: 1, - CommitTxHash: testChannelCommit.CommitTx.TxHash(), - HTLCEntries: []*HTLCEntry{&testHTLCEntry}, - } + testRevocationLogNoAmts = NewRevocationLog( + 0, 1, testChannelCommit.CommitTx.TxHash(), + fn.None[lnwire.MilliSatoshi](), fn.None[lnwire.MilliSatoshi](), + []*HTLCEntry{&testHTLCEntry}, fn.Some(blobBytes), + ) testRevocationLogNoAmtsBytes = []byte{ - // Body length 42. - 0x2a, + // Body length 48. + 0x30, // OurOutputIndex tlv. 0x0, 0x2, 0x0, 0x0, // TheirOutputIndex tlv. @@ -96,19 +151,18 @@ var ( 0x6e, 0x60, 0x29, 0x23, 0x1d, 0x5e, 0xc5, 0xe6, 0xbd, 0xf7, 0xd3, 0x9b, 0x16, 0x7d, 0x0, 0xff, 0xc8, 0x22, 0x51, 0xb1, 0x5b, 0xa0, 0xbf, 0xd, + // Custom blob tlv. + 0x5, 0x4, 0x1, 0x2, 0x3, 0x4, } - testRevocationLogWithAmts = RevocationLog{ - OurOutputIndex: 0, - TheirOutputIndex: 1, - CommitTxHash: testChannelCommit.CommitTx.TxHash(), - HTLCEntries: []*HTLCEntry{&testHTLCEntry}, - OurBalance: &localBalance, - TheirBalance: &remoteBalance, - } + testRevocationLogWithAmts = NewRevocationLog( + 0, 1, testChannelCommit.CommitTx.TxHash(), + fn.Some(localBalance), fn.Some(remoteBalance), + []*HTLCEntry{&testHTLCEntry}, fn.Some(blobBytes), + ) testRevocationLogWithAmtsBytes = []byte{ - // Body length 52. - 0x34, + // Body length 58. + 0x3a, // OurOutputIndex tlv. 0x0, 0x2, 0x0, 0x0, // TheirOutputIndex tlv. @@ -123,6 +177,8 @@ var ( 0x3, 0x3, 0xfd, 0x23, 0x28, // Remote Balance. 0x4, 0x3, 0xfd, 0x0b, 0xb8, + // Custom blob tlv. + 0x5, 0x4, 0x1, 0x2, 0x3, 0x4, } ) @@ -193,11 +249,6 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) { // Copy the testHTLCEntry. entry := testHTLCEntry - // Set the internal fields to empty values so we can test the bytes are - // padded. - entry.incomingTlv = 0 - entry.amtTlv = 0 - // Write the tlv stream. buf := bytes.NewBuffer([]byte{}) err := serializeHTLCEntries(buf, []*HTLCEntry{&entry}) @@ -207,6 +258,21 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) { require.Equal(t, testHTLCEntryBytes, buf.Bytes()) } +func TestSerializeHTLCEntriesWithRHash(t *testing.T) { + t.Parallel() + + // Copy the testHTLCEntry. + entry := testHTLCEntryHash + + // Write the tlv stream. + buf := bytes.NewBuffer([]byte{}) + err := serializeHTLCEntries(buf, []*HTLCEntry{&entry}) + require.NoError(t, err) + + // Check the bytes are read as expected. + require.Equal(t, testHTLCEntryHashBytes, buf.Bytes()) +} + func TestSerializeHTLCEntries(t *testing.T) { t.Parallel() @@ -215,7 +281,7 @@ func TestSerializeHTLCEntries(t *testing.T) { // Create a fake rHash. rHashBytes := bytes.Repeat([]byte{10}, 32) - copy(entry.RHash[:], rHashBytes) + copy(entry.RHash.Val[:], rHashBytes) // Construct the serialized bytes. // @@ -224,7 +290,7 @@ func TestSerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - expectedBytes := []byte{0x36, 0x0, 0x20} + expectedBytes := []byte{0x40, 0x0, 0x20} expectedBytes = append(expectedBytes, rHashBytes...) // Append the rest. @@ -269,7 +335,7 @@ func TestSerializeAndDeserializeRevLog(t *testing.T) { t, &test.revLog, test.revLogBytes, ) - testDerializeRevocationLog( + testDeserializeRevocationLog( t, &test.revLog, test.revLogBytes, ) }) @@ -293,7 +359,7 @@ func testSerializeRevocationLog(t *testing.T, rl *RevocationLog, require.Equal(t, revLogBytes, buf.Bytes()[:bodyIndex]) } -func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog, +func testDeserializeRevocationLog(t *testing.T, revLog *RevocationLog, revLogBytes []byte) { // Construct the full bytes. @@ -309,7 +375,7 @@ func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog, require.Equal(t, *revLog, rl) } -func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) { +func TestDeserializeHTLCEntriesEmptyRHash(t *testing.T) { t.Parallel() // Read the tlv stream. @@ -322,7 +388,7 @@ func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) { require.Equal(t, &testHTLCEntry, htlcs[0]) } -func TestDerializeHTLCEntries(t *testing.T) { +func TestDeserializeHTLCEntries(t *testing.T) { t.Parallel() // Copy the testHTLCEntry. @@ -330,7 +396,7 @@ func TestDerializeHTLCEntries(t *testing.T) { // Create a fake rHash. rHashBytes := bytes.Repeat([]byte{10}, 32) - copy(entry.RHash[:], rHashBytes) + copy(entry.RHash.Val[:], rHashBytes) // Construct the serialized bytes. // @@ -339,7 +405,7 @@ func TestDerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - testBytes := append([]byte{0x36, 0x0, 0x20}, rHashBytes...) + testBytes := append([]byte{0x40, 0x0, 0x20}, rHashBytes...) // Append the rest. testBytes = append(testBytes, partialBytes...) @@ -398,11 +464,11 @@ func TestDeleteLogBucket(t *testing.T) { err = kvdb.Update(backend, func(tx kvdb.RwTx) error { // Create the buckets. - chanBucket, _, err := createTestRevocatoinLogBuckets(tx) + chanBucket, _, err := createTestRevocationLogBuckets(tx) require.NoError(t, err) // Create the buckets again should give us an error. - _, _, err = createTestRevocatoinLogBuckets(tx) + _, _, err = createTestRevocationLogBuckets(tx) require.ErrorIs(t, err, kvdb.ErrBucketExists) // Delete both buckets. @@ -410,7 +476,7 @@ func TestDeleteLogBucket(t *testing.T) { require.NoError(t, err) // Create the buckets again should give us NO error. - _, _, err = createTestRevocatoinLogBuckets(tx) + _, _, err = createTestRevocationLogBuckets(tx) return err }, func() {}) require.NoError(t, err) @@ -516,7 +582,7 @@ func TestPutRevocationLog(t *testing.T) { // Construct the testing db transaction. dbTx := func(tx kvdb.RwTx) (RevocationLog, error) { // Create the buckets. - _, bucket, err := createTestRevocatoinLogBuckets(tx) + _, bucket, err := createTestRevocationLogBuckets(tx) require.NoError(t, err) // Save the log. @@ -686,7 +752,7 @@ func TestFetchRevocationLogCompatible(t *testing.T) { } } -func createTestRevocatoinLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket, +func createTestRevocationLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket, kvdb.RwBucket, error) { chanBucket, err := tx.CreateTopLevelBucket(openChannelBucket) diff --git a/config_builder.go b/config_builder.go index bbf291506b5..7af32738243 100644 --- a/config_builder.go +++ b/config_builder.go @@ -34,6 +34,7 @@ import ( "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -104,7 +105,7 @@ type DatabaseBuilder interface { type WalletConfigBuilder interface { // BuildWalletConfig is responsible for creating or unlocking and then // fully initializing a wallet. - BuildWalletConfig(context.Context, *DatabaseInstances, + BuildWalletConfig(context.Context, *DatabaseInstances, *AuxComponents, *rpcperms.InterceptorChain, []*ListenerWithSignal) (*chainreg.PartialChainControl, *btcwallet.Config, func(), error) @@ -145,6 +146,17 @@ type ImplementationCfg struct { // ChainControlBuilder is a type that can provide a custom wallet // implementation. ChainControlBuilder + // AuxComponents is a set of auxiliary components that can be used by + // lnd for certain custom channel types. + AuxComponents +} + +// AuxComponents is a set of auxiliary components that can be used by lnd for +// certain custom channel types. +type AuxComponents struct { + // AuxLeafStore is an optional data source that can be used by custom + // channels to fetch+store various data. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // DefaultWalletImpl is the default implementation of our normal, btcwallet @@ -229,7 +241,8 @@ func (d *DefaultWalletImpl) Permissions() map[string][]bakery.Op { // // NOTE: This is part of the WalletConfigBuilder interface. func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, - dbs *DatabaseInstances, interceptorChain *rpcperms.InterceptorChain, + dbs *DatabaseInstances, aux *AuxComponents, + interceptorChain *rpcperms.InterceptorChain, grpcListeners []*ListenerWithSignal) (*chainreg.PartialChainControl, *btcwallet.Config, func(), error) { @@ -549,6 +562,7 @@ func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, HeightHintDB: dbs.HeightHintDB, ChanStateDB: dbs.ChanStateDB.ChannelStateDB(), NeutrinoCS: neutrinoCS, + AuxLeafStore: aux.AuxLeafStore, ActiveNetParams: d.cfg.ActiveNetParams, FeeURL: d.cfg.FeeURL, Dialer: func(addr string) (net.Conn, error) { @@ -607,8 +621,9 @@ func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, // proxyBlockEpoch proxies a block epoch subsections to the underlying neutrino // rebroadcaster client. -func proxyBlockEpoch(notifier chainntnfs.ChainNotifier, -) func() (*blockntfns.Subscription, error) { +func proxyBlockEpoch( + notifier chainntnfs.ChainNotifier) func() (*blockntfns.Subscription, + error) { return func() (*blockntfns.Subscription, error) { blockEpoch, err := notifier.RegisterBlockEpochNtfn( @@ -699,6 +714,7 @@ func (d *DefaultWalletImpl) BuildChainControl( ChainIO: walletController, NetParams: *walletConfig.NetParams, CoinSelectionStrategy: walletConfig.CoinSelectionStrategy, + AuxLeafStore: partialChainControl.Cfg.AuxLeafStore, } // The broadcast is already always active for neutrino nodes, so we @@ -878,6 +894,10 @@ type DatabaseInstances struct { // for native SQL queries for tables that already support it. This may // be nil if the use-native-sql flag was not set. NativeSQLStore *sqldb.BaseDB + + // AuxLeafStore is an optional data source that can be used by custom + // channels to fetch+store various data. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // DefaultDatabaseBuilder is a type that builds the default database backends diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index 2fe4644db96..babb427ea29 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -22,6 +22,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntest/channels" @@ -1590,6 +1591,7 @@ func testBreachSpends(t *testing.T, test breachTest) { // Notify the breach arbiter about the breach. retribution, err := lnwallet.NewBreachRetribution( alice.State(), height, 1, forceCloseTx, + fn.None[lnwallet.AuxLeafStore](), ) require.NoError(t, err, "unable to create breach retribution") @@ -1799,6 +1801,7 @@ func TestBreachDelayedJusticeConfirmation(t *testing.T) { // Notify the breach arbiter about the breach. retribution, err := lnwallet.NewBreachRetribution( alice.State(), height, uint32(blockHeight), forceCloseTx, + fn.None[lnwallet.AuxLeafStore](), ) require.NoError(t, err, "unable to create breach retribution") diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index fbddd81f0ee..3245162cb59 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -217,6 +217,10 @@ type ChainArbitratorConfig struct { // meanwhile, turn `PaymentCircuit` into an interface or bring it to a // lower package. QueryIncomingCircuit func(circuit models.CircuitKey) *models.CircuitKey + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // ChainArbitrator is a sub-system that oversees the on-chain resolution of all @@ -299,8 +303,13 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions, return nil, err } + var chanOpts []lnwallet.ChannelOpt + a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + chanMachine, err := lnwallet.NewLightningChannel( - a.c.cfg.Signer, channel, nil, + a.c.cfg.Signer, channel, nil, chanOpts..., ) if err != nil { return nil, err @@ -344,10 +353,15 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) return nil, err } + var chanOpts []lnwallet.ChannelOpt + a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // Finally, we'll force close the channel completing // the force close workflow. chanMachine, err := lnwallet.NewLightningChannel( - a.c.cfg.Signer, channel, nil, + a.c.cfg.Signer, channel, nil, chanOpts..., ) if err != nil { return nil, err diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 5372d8c0dde..89280d1bd1c 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -188,6 +188,9 @@ type chainWatcherConfig struct { // obfuscater. This is used by the chain watcher to identify which // state was broadcast and confirmed on-chain. extractStateNumHint func(*wire.MsgTx, [lnwallet.StateHintSize]byte) uint64 + + // auxLeafStore can be used to fetch information for custom channels. + auxLeafStore fn.Option[lnwallet.AuxLeafStore] } // chainWatcher is a system that's assigned to every active channel. The duty @@ -421,15 +424,30 @@ func (c *chainWatcher) handleUnknownLocalState( &c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg, ) + auxLeaves, err := lnwallet.AuxLeavesFromCommit( + c.cfg.chanState, c.cfg.chanState.LocalCommitment, + c.cfg.auxLeafStore, *commitKeyRing, + ) + if err != nil { + return false, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the keys derived, we'll construct the remote script that'll be // present if they have a non-dust balance on the commitment. var leaseExpiry uint32 if c.cfg.chanState.ChanType.HasLeaseExpiration() { leaseExpiry = c.cfg.chanState.ThawHeight } + + remoteAuxLeaf := fn.MapOption( + func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + }, + )(auxLeaves) remoteScript, _, err := lnwallet.CommitScriptToRemote( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, commitKeyRing.ToRemoteKey, leaseExpiry, + fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return false, err @@ -438,10 +456,16 @@ func (c *chainWatcher) handleUnknownLocalState( // Next, we'll derive our script that includes the revocation base for // the remote party allowing them to claim this output before the CSV // delay if we breach. + localAuxLeaf := fn.MapOption( + func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + }, + )(auxLeaves) localScript, err := lnwallet.CommitScriptToSelf( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, commitKeyRing.ToLocalKey, commitKeyRing.RevocationKey, uint32(c.cfg.chanState.LocalChanCfg.CsvDelay), leaseExpiry, + fn.FlattenOption(localAuxLeaf), ) if err != nil { return false, err @@ -861,7 +885,7 @@ func (c *chainWatcher) handlePossibleBreach(commitSpend *chainntnfs.SpendDetail, spendHeight := uint32(commitSpend.SpendingHeight) retribution, err := lnwallet.NewBreachRetribution( c.cfg.chanState, broadcastStateNum, spendHeight, - commitSpend.SpendingTx, + commitSpend.SpendingTx, c.cfg.auxLeafStore, ) switch { @@ -1072,8 +1096,8 @@ func (c *chainWatcher) dispatchLocalForceClose( "detected", c.cfg.chanState.FundingOutpoint) forceClose, err := lnwallet.NewLocalForceCloseSummary( - c.cfg.chanState, c.cfg.signer, - commitSpend.SpendingTx, stateNum, + c.cfg.chanState, c.cfg.signer, commitSpend.SpendingTx, stateNum, + c.cfg.auxLeafStore, ) if err != nil { return err @@ -1166,7 +1190,7 @@ func (c *chainWatcher) dispatchRemoteForceClose( // channel on-chain. uniClose, err := lnwallet.NewUnilateralCloseSummary( c.cfg.chanState, c.cfg.signer, commitSpend, - remoteCommit, commitPoint, + remoteCommit, commitPoint, c.cfg.auxLeafStore, ) if err != nil { return err diff --git a/funding/manager.go b/funding/manager.go index 227bb0e6632..4a9c0122d61 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -538,6 +538,10 @@ type Config struct { // AliasManager is an implementation of the aliasHandler interface that // abstracts away the handling of many alias functions. AliasManager aliasHandler + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // Manager acts as an orchestrator/bridge between the wallet's @@ -1056,9 +1060,14 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, } } + var chanOpts []lnwallet.ChannelOpt + f.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // We create the state-machine object which wraps the database state. lnChannel, err := lnwallet.NewLightningChannel( - nil, channel, nil, + nil, channel, nil, chanOpts..., ) if err != nil { log.Errorf("Unable to create LightningChannel(%v): %v", diff --git a/go.mod b/go.mod index b8a658edd76..75e43e5731e 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,7 @@ require ( github.com/lightningnetwork/lnd/queue v1.1.1 github.com/lightningnetwork/lnd/sqldb v1.0.1 github.com/lightningnetwork/lnd/ticker v1.1.1 - github.com/lightningnetwork/lnd/tlv v1.2.3 + github.com/lightningnetwork/lnd/tlv v1.2.5 github.com/lightningnetwork/lnd/tor v1.1.2 github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 github.com/miekg/dns v1.1.43 diff --git a/go.sum b/go.sum index c2f3fdb804b..c3f8b2e8265 100644 --- a/go.sum +++ b/go.sum @@ -457,8 +457,8 @@ github.com/lightningnetwork/lnd/sqldb v1.0.1 h1:lpNoJ6qRh3D02oeIUsKQLZUzjcgZ9ppM github.com/lightningnetwork/lnd/sqldb v1.0.1/go.mod h1:nSovU1U+gTPDWhfwmXu/kW8l8EJpwbvZQ05ijnkQzkA= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= -github.com/lightningnetwork/lnd/tlv v1.2.3 h1:If5ibokA/UoCBGuCKaY6Vn2SJU0l9uAbehCnhTZjEP8= -github.com/lightningnetwork/lnd/tlv v1.2.3/go.mod h1:zDkmqxOczP6LaLTvSFDQ1SJUfHcQRCMKFj93dn3eMB8= +github.com/lightningnetwork/lnd/tlv v1.2.5 h1:/VsoWw628t78OiDN90pHDbqwOcuZ9JMicxXZVQjBwX0= +github.com/lightningnetwork/lnd/tlv v1.2.5/go.mod h1:/CmY4VbItpOldksocmGT4lxiJqRP9oLxwSZOda2kzNQ= github.com/lightningnetwork/lnd/tor v1.1.2 h1:3zv9z/EivNFaMF89v3ciBjCS7kvCj4ZFG7XvD2Qq0/k= github.com/lightningnetwork/lnd/tor v1.1.2/go.mod h1:j7T9uJ2NLMaHwE7GiBGnpYLn4f7NRoTM6qj+ul6/ycA= github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 h1:sjOGyegMIhvgfq5oaue6Td+hxZuf3tDC8lAPrFldqFw= diff --git a/input/script_utils.go b/input/script_utils.go index 9e639bca849..5ad0a0e90d9 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -673,6 +673,13 @@ type HtlcScriptTree struct { // TimeoutTapLeaf is the tapleaf for the timeout path. TimeoutTapLeaf txscript.TapLeaf + // AuxLeaf is an auxiliary leaf that can be used to extend the base + // HTLC script tree with new spend paths, or just as extra commitment + // space. When present, this leaf will always be in the left-most or + // right-most area of the tapscript tree. + AuxLeaf AuxTapLeaf + + // htlcType is the type of HTLC script this is. htlcType htlcType } @@ -748,8 +755,8 @@ var _ TapscriptDescriptor = (*HtlcScriptTree)(nil) // senderHtlcTapScriptTree builds the tapscript tree which is used to anchor // the HTLC key for HTLCs on the sender's commitment. func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, - revokeKey *btcec.PublicKey, payHash []byte, - hType htlcType) (*HtlcScriptTree, error) { + revokeKey *btcec.PublicKey, payHash []byte, hType htlcType, + auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { // First, we'll obtain the tap leaves for both the success and timeout // path. @@ -766,11 +773,14 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, return nil, err } + tapLeaves := []txscript.TapLeaf{successTapLeaf, timeoutTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With the two leaves obtained, we'll now make the tapscript tree, // then obtain the root from that - tapscriptTree := txscript.AssembleTaprootScriptTree( - successTapLeaf, timeoutTapLeaf, - ) + tapscriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapscriptTree.RootNode.TapHash() @@ -789,6 +799,7 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, }, SuccessTapLeaf: successTapLeaf, TimeoutTapLeaf: timeoutTapLeaf, + AuxLeaf: auxLeaf, htlcType: hType, }, nil } @@ -822,8 +833,8 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // The top level keyspend key is the revocation key, which allows a defender to // unilaterally spend the created output. func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, - revokeKey *btcec.PublicKey, payHash []byte, - localCommit bool) (*HtlcScriptTree, error) { + revokeKey *btcec.PublicKey, payHash []byte, localCommit bool, + auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { var hType htlcType if localCommit { @@ -836,8 +847,8 @@ func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, // tree that includes the top level output script, as well as the two // tap leaf paths. return senderHtlcTapScriptTree( - senderHtlcKey, receiverHtlcKey, revokeKey, payHash, - hType, + senderHtlcKey, receiverHtlcKey, revokeKey, payHash, hType, + auxLeaf, ) } @@ -1307,8 +1318,8 @@ func ReceiverHtlcTapLeafSuccess(receiverHtlcKey *btcec.PublicKey, // receiverHtlcTapScriptTree builds the tapscript tree which is used to anchor // the HTLC key for HTLCs on the receiver's commitment. func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, - revokeKey *btcec.PublicKey, payHash []byte, - cltvExpiry uint32, hType htlcType) (*HtlcScriptTree, error) { + revokeKey *btcec.PublicKey, payHash []byte, cltvExpiry uint32, + hType htlcType, auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { // First, we'll obtain the tap leaves for both the success and timeout // path. @@ -1325,11 +1336,14 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, return nil, err } + tapLeaves := []txscript.TapLeaf{timeoutTapLeaf, successTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With the two leaves obtained, we'll now make the tapscript tree, // then obtain the root from that - tapscriptTree := txscript.AssembleTaprootScriptTree( - timeoutTapLeaf, successTapLeaf, - ) + tapscriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapscriptTree.RootNode.TapHash() @@ -1348,6 +1362,7 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, }, SuccessTapLeaf: successTapLeaf, TimeoutTapLeaf: timeoutTapLeaf, + AuxLeaf: auxLeaf, htlcType: hType, }, nil } @@ -1382,7 +1397,8 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // the tap leaf are returned. func ReceiverHTLCScriptTaproot(cltvExpiry uint32, senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey, - payHash []byte, ourCommit bool) (*HtlcScriptTree, error) { + payHash []byte, ourCommit bool, auxLeaf AuxTapLeaf) (*HtlcScriptTree, + error) { var hType htlcType if ourCommit { @@ -1396,7 +1412,7 @@ func ReceiverHTLCScriptTaproot(cltvExpiry uint32, // tap leaf paths. return receiverHtlcTapScriptTree( senderHtlcKey, receiverHtlcKey, revocationKey, payHash, - cltvExpiry, hType, + cltvExpiry, hType, auxLeaf, ) } @@ -1625,9 +1641,9 @@ func TaprootSecondLevelTapLeaf(delayKey *btcec.PublicKey, } // SecondLevelHtlcTapscriptTree construct the indexed tapscript tree needed to -// generate the taptweak to create the final output and also control block. -func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, - csvDelay uint32) (*txscript.IndexedTapScriptTree, error) { +// generate the tap tweak to create the final output and also control block. +func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, csvDelay uint32, + auxLeaf AuxTapLeaf) (*txscript.IndexedTapScriptTree, error) { // First grab the second level leaf script we need to create the top // level output. @@ -1636,9 +1652,14 @@ func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, return nil, err } + tapLeaves := []txscript.TapLeaf{secondLevelTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // Now that we have the sole second level script, we can create the // tapscript tree that commits to both the leaves. - return txscript.AssembleTaprootScriptTree(secondLevelTapLeaf), nil + return txscript.AssembleTaprootScriptTree(tapLeaves...), nil } // TaprootSecondLevelHtlcScript is the uniform script that's used as the output @@ -1658,12 +1679,12 @@ func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, // // The keyspend path require knowledge of the top level revocation private key. func TaprootSecondLevelHtlcScript(revokeKey, delayKey *btcec.PublicKey, - csvDelay uint32) (*btcec.PublicKey, error) { + csvDelay uint32, auxLeaf AuxTapLeaf) (*btcec.PublicKey, error) { // First, we'll make the tapscript tree that commits to the redemption // path. tapScriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey, csvDelay, + delayKey, csvDelay, auxLeaf, ) if err != nil { return nil, err @@ -1688,17 +1709,21 @@ type SecondLevelScriptTree struct { // SuccessTapLeaf is the tapleaf for the redemption path. SuccessTapLeaf txscript.TapLeaf + + // AuxLeaf is an optional leaf that can be used to extend the script + // tree. + AuxLeaf AuxTapLeaf } // TaprootSecondLevelScriptTree constructs the tapscript tree used to spend the // second level HTLC output. func TaprootSecondLevelScriptTree(revokeKey, delayKey *btcec.PublicKey, - csvDelay uint32) (*SecondLevelScriptTree, error) { + csvDelay uint32, auxLeaf AuxTapLeaf) (*SecondLevelScriptTree, error) { // First, we'll make the tapscript tree that commits to the redemption // path. tapScriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey, csvDelay, + delayKey, csvDelay, auxLeaf, ) if err != nil { return nil, err @@ -1719,6 +1744,7 @@ func TaprootSecondLevelScriptTree(revokeKey, delayKey *btcec.PublicKey, InternalKey: revokeKey, }, SuccessTapLeaf: tapScriptTree.LeafMerkleProofs[0].TapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2095,6 +2121,12 @@ type CommitScriptTree struct { // RevocationLeaf is the leaf used to spend the output with the // revocation key signature. RevocationLeaf txscript.TapLeaf + + // AuxLeaf is an auxiliary leaf that can be used to extend the base + // commitment script tree with new spend paths, or just as extra + // commitment space. When present, this leaf will always be in the + // left-most or right-most area of the tapscript tree. + AuxLeaf AuxTapLeaf } // A compile time check to ensure CommitScriptTree implements the @@ -2154,8 +2186,9 @@ func (c *CommitScriptTree) CtrlBlockForPath(path ScriptPath, // NewLocalCommitScriptTree returns a new CommitScript tree that can be used to // create and spend the commitment output for the local party. -func NewLocalCommitScriptTree(csvTimeout uint32, - selfKey, revokeKey *btcec.PublicKey) (*CommitScriptTree, error) { +func NewLocalCommitScriptTree(csvTimeout uint32, selfKey, + revokeKey *btcec.PublicKey, auxLeaf AuxTapLeaf) (*CommitScriptTree, + error) { // First, we'll need to construct the tapLeaf that'll be our delay CSV // clause. @@ -2175,9 +2208,13 @@ func NewLocalCommitScriptTree(csvTimeout uint32, // the two leaves, and then obtain a root from that. delayTapLeaf := txscript.NewBaseTapLeaf(delayScript) revokeTapLeaf := txscript.NewBaseTapLeaf(revokeScript) - tapScriptTree := txscript.AssembleTaprootScriptTree( - delayTapLeaf, revokeTapLeaf, - ) + + tapLeaves := []txscript.TapLeaf{delayTapLeaf, revokeTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + + tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapScriptTree.RootNode.TapHash() // Now that we have our root, we can arrive at the final output script @@ -2195,6 +2232,7 @@ func NewLocalCommitScriptTree(csvTimeout uint32, }, SettleLeaf: delayTapLeaf, RevocationLeaf: revokeTapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2264,7 +2302,7 @@ func TaprootCommitScriptToSelf(csvTimeout uint32, selfKey, revokeKey *btcec.PublicKey) (*btcec.PublicKey, error) { commitScriptTree, err := NewLocalCommitScriptTree( - csvTimeout, selfKey, revokeKey, + csvTimeout, selfKey, revokeKey, NoneTapLeaf(), ) if err != nil { return nil, err @@ -2593,7 +2631,7 @@ func CommitScriptToRemoteConfirmed(key *btcec.PublicKey) ([]byte, error) { // NewRemoteCommitScriptTree constructs a new script tree for the remote party // to sweep their funds after a hard coded 1 block delay. func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, -) (*CommitScriptTree, error) { + auxLeaf AuxTapLeaf) (*CommitScriptTree, error) { // First, construct the remote party's tapscript they'll use to sweep // their outputs. @@ -2609,10 +2647,16 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, return nil, err } + tapLeaf := txscript.NewBaseTapLeaf(remoteScript) + + tapLeaves := []txscript.TapLeaf{tapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With this script constructed, we'll map that into a tapLeaf, then // make a new tapscript root from that. - tapLeaf := txscript.NewBaseTapLeaf(remoteScript) - tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaf) + tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapScriptTree.RootNode.TapHash() // Now that we have our root, we can arrive at the final output script @@ -2629,6 +2673,7 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, InternalKey: &TaprootNUMSKey, }, SettleLeaf: tapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2645,9 +2690,9 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, // OP_CHECKSIG // 1 OP_CHECKSEQUENCEVERIFY OP_DROP func TaprootCommitScriptToRemote(remoteKey *btcec.PublicKey, -) (*btcec.PublicKey, error) { + auxLeaf AuxTapLeaf) (*btcec.PublicKey, error) { - commitScriptTree, err := NewRemoteCommitScriptTree(remoteKey) + commitScriptTree, err := NewRemoteCommitScriptTree(remoteKey, auxLeaf) if err != nil { return nil, err } diff --git a/input/size_test.go b/input/size_test.go index 1f447cf89fa..4cc1680ba48 100644 --- a/input/size_test.go +++ b/input/size_test.go @@ -851,7 +851,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} commitScriptTree, err := input.NewLocalCommitScriptTree( testCSVDelay, testKey.PubKey(), - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -885,7 +885,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} commitScriptTree, err := input.NewLocalCommitScriptTree( testCSVDelay, testKey.PubKey(), - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -919,7 +919,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} //nolint:lll commitScriptTree, err := input.NewRemoteCommitScriptTree( - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -986,6 +986,7 @@ var witnessSizeTests = []witnessSizeTest{ scriptTree, err := input.SecondLevelHtlcTapscriptTree( testKey.PubKey(), testCSVDelay, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1025,6 +1026,7 @@ var witnessSizeTests = []witnessSizeTest{ scriptTree, err := input.SecondLevelHtlcTapscriptTree( testKey.PubKey(), testCSVDelay, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1073,6 +1075,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], false, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1114,7 +1117,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], false, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1156,7 +1159,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], false, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1203,6 +1206,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], false, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1263,6 +1267,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], false, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1308,7 +1313,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], false, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1381,7 +1386,7 @@ func genTimeoutTx(t *testing.T, // Create the unsigned timeout tx. timeoutTx, err := lnwallet.CreateHtlcTimeoutTx( chanType, false, testOutPoint, testAmt, testCLTVExpiry, - testCSVDelay, 0, testPubkey, testPubkey, + testCSVDelay, 0, testPubkey, testPubkey, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1394,6 +1399,7 @@ func genTimeoutTx(t *testing.T, if chanType.IsTaproot() { tapscriptTree, err = input.SenderHTLCScriptTaproot( testPubkey, testPubkey, testPubkey, testHash160, false, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1449,7 +1455,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { // Create the unsigned success tx. successTx, err := lnwallet.CreateHtlcSuccessTx( chanType, false, testOutPoint, testAmt, testCSVDelay, 0, - testPubkey, testPubkey, + testPubkey, testPubkey, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1462,7 +1468,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { if chanType.IsTaproot() { tapscriptTree, err = input.ReceiverHTLCScriptTaproot( testCLTVExpiry, testPubkey, testPubkey, testPubkey, - testHash160, false, + testHash160, false, input.NoneTapLeaf(), ) require.NoError(t, err) diff --git a/input/taproot.go b/input/taproot.go index 34cdb974d5c..6edaf5de621 100644 --- a/input/taproot.go +++ b/input/taproot.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/waddrmgr" + "github.com/lightningnetwork/lnd/fn" ) const ( @@ -21,6 +22,33 @@ const ( PubKeyFormatCompressedOdd byte = 0x03 ) +// AuxTapLeaf is a type alias for an optional tapscript leaf that may be added +// to the tapscript tree of HTLC and commitment outputs. +type AuxTapLeaf = fn.Option[txscript.TapLeaf] + +// NoneTapLeaf returns an empty optional tapscript leaf. +func NoneTapLeaf() AuxTapLeaf { + return fn.None[txscript.TapLeaf]() +} + +// HtlcIndex represents the monotonically increasing counter that is used to +// identify HTLCs created a peer. +type HtlcIndex = uint64 + +// HtlcAuxLeaf is a type that represents an auxiliary leaf for an HTLC output. +// An HTLC may have up to two aux leaves: one for the output on the commitment +// transaction, and one for the second level HTLC. +type HtlcAuxLeaf struct { + AuxTapLeaf + + // SecondLevelLeaf is the auxiliary leaf for the second level HTLC + // success or timeout transaction. + SecondLevelLeaf AuxTapLeaf +} + +// AuxTapLeaves is a type alias for a slice of optional tapscript leaves. +type AuxTapLeaves = map[HtlcIndex]HtlcAuxLeaf + // NewTxSigHashesV0Only returns a new txscript.TxSigHashes instance that will // only calculate the sighash midstate values for segwit v0 inputs and can // therefore never be used for transactions that want to spend segwit v1 diff --git a/input/taproot_test.go b/input/taproot_test.go index 801b0fef4d5..ad4f16c4be6 100644 --- a/input/taproot_test.go +++ b/input/taproot_test.go @@ -1,13 +1,16 @@ package input import ( + "bytes" "crypto/rand" + "fmt" "testing" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" @@ -31,7 +34,9 @@ type testSenderHtlcScriptTree struct { htlcAmt int64 } -func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { +func newTestSenderHtlcScriptTree(t *testing.T, + auxLeaf AuxTapLeaf) *testSenderHtlcScriptTree { + var preImage lntypes.Preimage _, err := rand.Read(preImage[:]) require.NoError(t, err) @@ -48,7 +53,7 @@ func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], false, auxLeaf, ) require.NoError(t, err) @@ -207,13 +212,9 @@ func htlcSenderTimeoutWitnessGen(sigHash txscript.SigHashType, } } -// TestTaprootSenderHtlcSpend tests that all the positive and negative paths -// for the sender HTLC tapscript tree work as expected. -func TestTaprootSenderHtlcSpend(t *testing.T) { - t.Parallel() - +func testTaprootSenderHtlcSpend(t *testing.T, auxLeaf AuxTapLeaf) { // First, create a new test script tree. - htlcScriptTree := newTestSenderHtlcScriptTree(t) + htlcScriptTree := newTestSenderHtlcScriptTree(t, auxLeaf) spendTx := wire.NewMsgTx(2) spendTx.AddTxIn(&wire.TxIn{}) @@ -432,6 +433,28 @@ func TestTaprootSenderHtlcSpend(t *testing.T) { } } +// TestTaprootSenderHtlcSpend tests that all the positive and negative paths +// for the sender HTLC tapscript tree work as expected. +func TestTaprootSenderHtlcSpend(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootSenderHtlcSpend(t, auxLeaf) + }) + } +} + type testReceiverHtlcScriptTree struct { preImage lntypes.Preimage @@ -452,7 +475,9 @@ type testReceiverHtlcScriptTree struct { lockTime int32 } -func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { +func newTestReceiverHtlcScriptTree(t *testing.T, + auxLeaf AuxTapLeaf) *testReceiverHtlcScriptTree { + var preImage lntypes.Preimage _, err := rand.Read(preImage[:]) require.NoError(t, err) @@ -471,7 +496,7 @@ func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := ReceiverHTLCScriptTaproot( cltvExpiry, senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], false, auxLeaf, ) require.NoError(t, err) @@ -629,15 +654,11 @@ func htlcReceiverSuccessWitnessGen(sigHash txscript.SigHashType, } } -// TestTaprootReceiverHtlcSpend tests that all possible paths for redeeming an -// accepted HTLC (on the commitment transaction) of the receiver work properly. -func TestTaprootReceiverHtlcSpend(t *testing.T) { - t.Parallel() - +func testTaprootReceiverHtlcSpend(t *testing.T, auxLeaf AuxTapLeaf) { // We'll start by creating the HTLC script tree (contains all 3 valid // spend paths), and also a mock spend transaction that we'll be // signing below. - htlcScriptTree := newTestReceiverHtlcScriptTree(t) + htlcScriptTree := newTestReceiverHtlcScriptTree(t, auxLeaf) // TODO(roasbeef): issue with revoke key??? ctrl block even/odd @@ -891,6 +912,28 @@ func TestTaprootReceiverHtlcSpend(t *testing.T) { } } +// TestTaprootReceiverHtlcSpend tests that all possible paths for redeeming an +// accepted HTLC (on the commitment transaction) of the receiver work properly. +func TestTaprootReceiverHtlcSpend(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootReceiverHtlcSpend(t, auxLeaf) + }) + } +} + type testCommitScriptTree struct { csvDelay uint32 @@ -905,7 +948,9 @@ type testCommitScriptTree struct { *CommitScriptTree } -func newTestCommitScriptTree(local bool) (*testCommitScriptTree, error) { +func newTestCommitScriptTree(local bool, + auxLeaf AuxTapLeaf) (*testCommitScriptTree, error) { + selfKey, err := btcec.NewPrivateKey() if err != nil { return nil, err @@ -925,10 +970,11 @@ func newTestCommitScriptTree(local bool) (*testCommitScriptTree, error) { if local { commitScriptTree, err = NewLocalCommitScriptTree( csvDelay, selfKey.PubKey(), revokeKey.PubKey(), + auxLeaf, ) } else { commitScriptTree, err = NewRemoteCommitScriptTree( - selfKey.PubKey(), + selfKey.PubKey(), auxLeaf, ) } if err != nil { @@ -1020,12 +1066,8 @@ func localCommitRevokeWitGen(sigHash txscript.SigHashType, } } -// TestTaprootCommitScriptToSelf tests that the taproot script for redeeming -// one's output after a force close behaves as expected. -func TestTaprootCommitScriptToSelf(t *testing.T) { - t.Parallel() - - commitScriptTree, err := newTestCommitScriptTree(true) +func testTaprootCommitScriptToSelf(t *testing.T, auxLeaf AuxTapLeaf) { + commitScriptTree, err := newTestCommitScriptTree(true, auxLeaf) require.NoError(t, err) spendTx := wire.NewMsgTx(2) @@ -1187,6 +1229,28 @@ func TestTaprootCommitScriptToSelf(t *testing.T) { } } +// TestTaprootCommitScriptToSelf tests that the taproot script for redeeming +// one's output after a force close behaves as expected. +func TestTaprootCommitScriptToSelf(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootCommitScriptToSelf(t, auxLeaf) + }) + } +} + func remoteCommitSweepWitGen(sigHash txscript.SigHashType, commitScriptTree *testCommitScriptTree) witnessGen { @@ -1220,12 +1284,8 @@ func remoteCommitSweepWitGen(sigHash txscript.SigHashType, } } -// TestTaprootCommitScriptRemote tests that the remote party can properly sweep -// their output after force close. -func TestTaprootCommitScriptRemote(t *testing.T) { - t.Parallel() - - commitScriptTree, err := newTestCommitScriptTree(false) +func testTaprootCommitScriptRemote(t *testing.T, auxLeaf AuxTapLeaf) { + commitScriptTree, err := newTestCommitScriptTree(false, auxLeaf) require.NoError(t, err) spendTx := wire.NewMsgTx(2) @@ -1364,6 +1424,28 @@ func TestTaprootCommitScriptRemote(t *testing.T) { } } +// TestTaprootCommitScriptRemote tests that the remote party can properly sweep +// their output after force close. +func TestTaprootCommitScriptRemote(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootCommitScriptRemote(t, auxLeaf) + }) + } +} + type testAnchorScriptTree struct { sweepKey *btcec.PrivateKey @@ -1599,25 +1681,21 @@ type testSecondLevelHtlcTree struct { tapScriptRoot []byte } -func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { +func newTestSecondLevelHtlcTree(t *testing.T, + auxLeaf AuxTapLeaf) *testSecondLevelHtlcTree { + delayKey, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } + require.NoError(t, err) revokeKey, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } + require.NoError(t, err) const csvDelay = 6 scriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey.PubKey(), csvDelay, + delayKey.PubKey(), csvDelay, auxLeaf, ) - if err != nil { - return nil, err - } + require.NoError(t, err) tapScriptRoot := scriptTree.RootNode.TapHash() @@ -1626,9 +1704,7 @@ func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { ) pkScript, err := PayToTaprootScript(htlcKey) - if err != nil { - return nil, err - } + require.NoError(t, err) const amt = 100 @@ -1643,7 +1719,7 @@ func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { amt: amt, scriptTree: scriptTree, tapScriptRoot: tapScriptRoot[:], - }, nil + } } func secondLevelHtlcSuccessWitGen(sigHash txscript.SigHashType, @@ -1713,13 +1789,8 @@ func secondLevelHtlcRevokeWitnessgen(sigHash txscript.SigHashType, } } -// TestTaprootSecondLevelHtlcScript tests that a channel peer can properly -// spend the second level HTLC script to resolve HTLCs. -func TestTaprootSecondLevelHtlcScript(t *testing.T) { - t.Parallel() - - htlcScriptTree, err := newTestSecondLevelHtlcTree() - require.NoError(t, err) +func testTaprootSecondLevelHtlcScript(t *testing.T, auxLeaf AuxTapLeaf) { + htlcScriptTree := newTestSecondLevelHtlcTree(t, auxLeaf) spendTx := wire.NewMsgTx(2) spendTx.AddTxIn(&wire.TxIn{}) @@ -1879,3 +1950,25 @@ func TestTaprootSecondLevelHtlcScript(t *testing.T) { }) } } + +// TestTaprootSecondLevelHtlcScript tests that a channel peer can properly +// spend the second level HTLC script to resolve HTLCs. +func TestTaprootSecondLevelHtlcScript(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootSecondLevelHtlcScript(t, auxLeaf) + }) + } +} diff --git a/lnd.go b/lnd.go index f4ef03a078b..9f628ff8125 100644 --- a/lnd.go +++ b/lnd.go @@ -437,7 +437,8 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, defer cleanUp() partialChainControl, walletConfig, cleanUp, err := implCfg.BuildWalletConfig( - ctx, dbs, interceptorChain, grpcListeners, + ctx, dbs, &implCfg.AuxComponents, interceptorChain, + grpcListeners, ) if err != nil { return mkErr("error creating wallet config: %v", err) @@ -580,7 +581,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, server, err := newServer( cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, activeChainControl.Cfg.WalletUnlockParams.ChansToRestore, - multiAcceptor, torController, tlsManager, + multiAcceptor, torController, tlsManager, implCfg, ) if err != nil { return mkErr("unable to create server: %v", err) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 67bf6506c26..b048a37b43a 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -32,6 +32,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -334,6 +335,10 @@ type PaymentDescriptor struct { // NOTE: Populated only on add payment descriptor entry types. OnionBlob []byte + // CustomRecrods also stores the set of optional custom records that + // may have been attached to a sent HTLC. + CustomRecords fn.Option[tlv.Blob] + // ShaOnionBlob is a sha of the onion blob. // // NOTE: Populated only in payment descriptor with MalformedFail type. @@ -546,6 +551,10 @@ type commitment struct { // on this commitment transaction. incomingHTLCs []PaymentDescriptor + // customBlob stores opaque bytes that may be used by custom channels + // to store extra data for a given commitment state. + customBlob fn.Option[tlv.Blob] + // [outgoing|incoming]HTLCIndex is an index that maps an output index // on the commitment transaction to the payment descriptor that // represents the HTLC output. @@ -726,6 +735,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { CommitTx: c.txn, CommitSig: c.sig, Htlcs: make([]channeldb.HTLC, 0, numHtlcs), + CustomBlob: c.customBlob, } for _, htlc := range c.outgoingHTLCs { @@ -746,6 +756,12 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { } copy(h.OnionBlob[:], htlc.OnionBlob) + // If the HTLC had custom records, then we'll copy that over so + // we restore with the same information. + htlc.CustomRecords.WhenSome(func(b tlv.Blob) { + copy(h.ExtraData, b) + }) + if ourCommit && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -770,6 +786,13 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { BlindingPoint: htlc.BlindingPoint, } copy(h.OnionBlob[:], htlc.OnionBlob) + + // If the HTLC had custom records, then we'll copy that over so + // we restore with the same information. + htlc.CustomRecords.WhenSome(func(b tlv.Blob) { + copy(h.ExtraData, b) + }) + if ourCommit && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -787,8 +810,8 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { // restart a channel session. func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, commitHeight uint64, htlc *channeldb.HTLC, localCommitKeys, - remoteCommitKeys *CommitmentKeyRing, isLocal bool) (PaymentDescriptor, - error) { + remoteCommitKeys *CommitmentKeyRing, isLocal bool, + auxLeaf input.AuxTapLeaf) (PaymentDescriptor, error) { // The proper pkScripts for this PaymentDescriptor must be // generated so we can easily locate them within the commitment @@ -812,7 +835,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, if !isDustLocal && localCommitKeys != nil { scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, true, htlc.RefundTimeout, - htlc.RHash, localCommitKeys, + htlc.RHash, localCommitKeys, auxLeaf, ) if err != nil { return pd, err @@ -827,7 +850,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, if !isDustRemote && remoteCommitKeys != nil { scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, false, htlc.RefundTimeout, - htlc.RHash, remoteCommitKeys, + htlc.RHash, remoteCommitKeys, auxLeaf, ) if err != nil { return pd, err @@ -852,7 +875,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // With the scripts reconstructed (depending on if this is our commit // vs theirs or a pending commit for the remote party), we can now // re-create the original payment descriptor. - return PaymentDescriptor{ + pd = PaymentDescriptor{ RHash: htlc.RHash, Timeout: htlc.RefundTimeout, Amount: htlc.Amt, @@ -867,7 +890,15 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, theirPkScript: theirP2WSH, theirWitnessScript: theirWitnessScript, BlindingPoint: htlc.BlindingPoint, - }, nil + } + + // Ensure that we'll restore any custom records which were stored as + // extra data on disk. + if len(htlc.ExtraData) != 0 { + pd.CustomRecords = fn.Some[tlv.Blob](htlc.ExtraData) + } + + return pd, nil } // extractPayDescs will convert all HTLC's present within a disk commit state @@ -876,7 +907,8 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // for each side. func (lc *LightningChannel) extractPayDescs(commitHeight uint64, feeRate chainfee.SatPerKWeight, htlcs []channeldb.HTLC, localCommitKeys, - remoteCommitKeys *CommitmentKeyRing, isLocal bool) ([]PaymentDescriptor, + remoteCommitKeys *CommitmentKeyRing, isLocal bool, + auxLeaves fn.Option[CommitAuxLeaves]) ([]PaymentDescriptor, []PaymentDescriptor, error) { var ( @@ -894,10 +926,21 @@ func (lc *LightningChannel) extractPayDescs(commitHeight uint64, htlc := htlc + auxLeaf := fn.MapOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + if htlc.Incoming { + leaves = l.IncomingHtlcLeaves + } + + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(auxLeaves) + payDesc, err := lc.diskHtlcToPayDesc( feeRate, commitHeight, &htlc, localCommitKeys, remoteCommitKeys, - isLocal, + isLocal, fn.FlattenOption(auxLeaf), ) if err != nil { return incomingHtlcs, outgoingHtlcs, err @@ -941,14 +984,28 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, ) } + auxLeaves, err := AuxLeavesFromCommit( + lc.channelState, *diskCommit, lc.leafStore, + func() CommitmentKeyRing { + if isLocal { + return *localCommitKeys + } + + return *remoteCommitKeys + }(), + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the key rings re-created, we'll now convert all the on-disk // HTLC"s into PaymentDescriptor's so we can re-insert them into our // update log. incomingHtlcs, outgoingHtlcs, err := lc.extractPayDescs( diskCommit.CommitHeight, chainfee.SatPerKWeight(diskCommit.FeePerKw), - diskCommit.Htlcs, localCommitKeys, remoteCommitKeys, - isLocal, + diskCommit.Htlcs, localCommitKeys, remoteCommitKeys, isLocal, + auxLeaves, ) if err != nil { return nil, err @@ -971,6 +1028,7 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, feePerKw: chainfee.SatPerKWeight(diskCommit.FeePerKw), incomingHTLCs: incomingHtlcs, outgoingHTLCs: outgoingHtlcs, + customBlob: diskCommit.CustomBlob, } if isLocal { commit.dustLimit = lc.channelState.LocalChanCfg.DustLimit @@ -1262,6 +1320,10 @@ type LightningChannel struct { // machine. Signer input.Signer + // leafStore is used to retrieve extra tapscript leaves for special + // custom channel types. + leafStore fn.Option[AuxLeafStore] + // signDesc is the primary sign descriptor that is capable of signing // the commitment transaction that spends the multi-sig output. signDesc *input.SignDescriptor @@ -1337,6 +1399,8 @@ type channelOpts struct { localNonce *musig2.Nonces remoteNonce *musig2.Nonces + leafStore fn.Option[AuxLeafStore] + skipNonceInit bool } @@ -1367,6 +1431,13 @@ func WithSkipNonceInit() ChannelOpt { } } +// WithLeafStore is used to specify a custom leaf store for the channel. +func WithLeafStore(store AuxLeafStore) ChannelOpt { + return func(o *channelOpts) { + o.leafStore = fn.Some[AuxLeafStore](store) + } +} + // defaultChannelOpts returns the set of default options for a new channel. func defaultChannelOpts() *channelOpts { return &channelOpts{} @@ -1408,13 +1479,16 @@ func NewLightningChannel(signer input.Signer, } lc := &LightningChannel{ - Signer: signer, - sigPool: sigPool, - currentHeight: localCommit.CommitHeight, - remoteCommitChain: newCommitmentChain(), - localCommitChain: newCommitmentChain(), - channelState: state, - commitBuilder: NewCommitmentBuilder(state), + Signer: signer, + leafStore: opts.leafStore, + sigPool: sigPool, + currentHeight: localCommit.CommitHeight, + remoteCommitChain: newCommitmentChain(), + localCommitChain: newCommitmentChain(), + channelState: state, + commitBuilder: NewCommitmentBuilder( + state, opts.leafStore, + ), localUpdateLog: localUpdateLog, remoteUpdateLog: remoteUpdateLog, Capacity: state.Capacity, @@ -1535,7 +1609,8 @@ func (lc *LightningChannel) ResetState() { func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, remoteUpdateLog *updateLog, commitHeight uint64, feeRate chainfee.SatPerKWeight, remoteCommitKeys *CommitmentKeyRing, - remoteDustLimit btcutil.Amount) (*PaymentDescriptor, error) { + remoteDustLimit btcutil.Amount, + auxLeaves fn.Option[CommitAuxLeaves]) (*PaymentDescriptor, error) { // Depending on the type of update message we'll map that to a distinct // PaymentDescriptor instance. @@ -1571,10 +1646,17 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, wireMsg.Amount.ToSatoshis(), remoteDustLimit, ) if !isDustRemote { + auxLeaf := fn.MapOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + return leaves[pd.HtlcIndex].AuxTapLeaf + }, + )(auxLeaves) + scriptInfo, err := genHtlcScript( lc.channelState.ChanType, false, false, wireMsg.Expiry, wireMsg.PaymentHash, - remoteCommitKeys, + remoteCommitKeys, fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -2242,6 +2324,14 @@ func (lc *LightningChannel) restorePendingLocalUpdates( pendingCommit := pendingRemoteCommitDiff.Commitment pendingHeight := pendingCommit.CommitHeight + auxLeaves, err := AuxLeavesFromCommit( + lc.channelState, pendingCommit, lc.leafStore, + *pendingRemoteKeys, + ) + if err != nil { + return fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // If we did have a dangling commit, then we'll examine which updates // we included in that state and re-insert them into our update log. for _, logUpdate := range pendingRemoteCommitDiff.LogUpdates { @@ -2251,7 +2341,7 @@ func (lc *LightningChannel) restorePendingLocalUpdates( &logUpdate, lc.remoteUpdateLog, pendingHeight, chainfee.SatPerKWeight(pendingCommit.FeePerKw), pendingRemoteKeys, - lc.channelState.RemoteChanCfg.DustLimit, + lc.channelState.RemoteChanCfg.DustLimit, auxLeaves, ) if err != nil { return err @@ -2414,7 +2504,8 @@ type BreachRetribution struct { // required to construct the BreachRetribution. If the revocation log is missing // the required fields then ErrRevLogDataMissing will be returned. func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, - breachHeight uint32, spendTx *wire.MsgTx) (*BreachRetribution, error) { + breachHeight uint32, spendTx *wire.MsgTx, + leafStore fn.Option[AuxLeafStore]) (*BreachRetribution, error) { // Query the on-disk revocation log for the snapshot which was recorded // at this particular state num. Based on whether a legacy revocation @@ -2457,21 +2548,35 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, leaseExpiry = chanState.ThawHeight } + auxLeaves, err := auxLeavesFromRevocation( + chanState, revokedLog, leafStore, *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // Since it is the remote breach we are reconstructing, the output // going to us will be a to-remote script with our local params. + localAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + })(auxLeaves) isRemoteInitiator := !chanState.IsInitiator ourScript, ourDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, fn.FlattenOption(localAuxLeaf), ) if err != nil { return nil, err } + remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + })(auxLeaves) theirDelay := uint32(chanState.RemoteChanCfg.CsvDelay) theirScript, err := CommitScriptToSelf( chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey, keyRing.RevocationKey, theirDelay, leaseExpiry, + fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return nil, err @@ -2489,7 +2594,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, if revokedLog != nil { br, ourAmt, theirAmt, err = createBreachRetribution( revokedLog, spendTx, chanState, keyRing, - commitmentSecret, leaseExpiry, + commitmentSecret, leaseExpiry, auxLeaves, ) if err != nil { return nil, err @@ -2623,7 +2728,8 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, func createHtlcRetribution(chanState *channeldb.OpenChannel, keyRing *CommitmentKeyRing, commitHash chainhash.Hash, commitmentSecret *btcec.PrivateKey, leaseExpiry uint32, - htlc *channeldb.HTLCEntry) (HtlcRetribution, error) { + htlc *channeldb.HTLCEntry, + auxLeaves fn.Option[CommitAuxLeaves]) (HtlcRetribution, error) { var emptyRetribution HtlcRetribution @@ -2633,10 +2739,21 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // We'll generate the original second level witness script now, as // we'll need it if we're revoking an HTLC output on the remote // commitment transaction, and *they* go to the second level. + secondLevelAuxLeaf := fn.MapOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + idx := input.HtlcIndex(htlc.HtlcIndex.Val) + + if htlc.Incoming.Val { + return l.IncomingHtlcLeaves[idx].SecondLevelLeaf + } + + return l.OutgoingHtlcLeaves[idx].SecondLevelLeaf + }, + )(auxLeaves) secondLevelScript, err := SecondLevelHtlcScript( chanState.ChanType, isRemoteInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, theirDelay, - leaseExpiry, + leaseExpiry, fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return emptyRetribution, err @@ -2647,9 +2764,19 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // HTLC script. Otherwise, is this was an outgoing HTLC that we sent, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. + htlcLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + idx := input.HtlcIndex(htlc.HtlcIndex.Val) + + if htlc.Incoming.Val { + return l.IncomingHtlcLeaves[idx].AuxTapLeaf + } + + return l.OutgoingHtlcLeaves[idx].AuxTapLeaf + })(auxLeaves) scriptInfo, err := genHtlcScript( - chanState.ChanType, htlc.Incoming, false, - htlc.RefundTimeout, htlc.RHash, keyRing, + chanState.ChanType, htlc.Incoming.Val, false, + htlc.RefundTimeout.Val, htlc.RHash.Val, keyRing, + fn.FlattenOption(htlcLeaf), ) if err != nil { return emptyRetribution, err @@ -2662,7 +2789,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, WitnessScript: scriptInfo.WitnessScriptToSign(), Output: &wire.TxOut{ PkScript: scriptInfo.PkScript(), - Value: int64(htlc.Amt), + Value: int64(htlc.Amt.Val.Int()), }, HashType: sweepSigHash(chanState.ChanType), } @@ -2695,10 +2822,10 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, SignDesc: signDesc, OutPoint: wire.OutPoint{ Hash: commitHash, - Index: uint32(htlc.OutputIndex), + Index: uint32(htlc.OutputIndex.Val), }, SecondLevelWitnessScript: secondLevelWitnessScript, - IsIncoming: htlc.Incoming, + IsIncoming: htlc.Incoming.Val, SecondLevelTapTweak: secondLevelTapTweak, }, nil } @@ -2713,7 +2840,9 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, func createBreachRetribution(revokedLog *channeldb.RevocationLog, spendTx *wire.MsgTx, chanState *channeldb.OpenChannel, keyRing *CommitmentKeyRing, commitmentSecret *btcec.PrivateKey, - leaseExpiry uint32) (*BreachRetribution, int64, int64, error) { + leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*BreachRetribution, int64, int64, + error) { commitHash := revokedLog.CommitTxHash @@ -2721,8 +2850,8 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, htlcRetributions := make([]HtlcRetribution, len(revokedLog.HTLCEntries)) for i, htlc := range revokedLog.HTLCEntries { hr, err := createHtlcRetribution( - chanState, keyRing, commitHash, - commitmentSecret, leaseExpiry, htlc, + chanState, keyRing, commitHash.Val, + commitmentSecret, leaseExpiry, htlc, auxLeaves, ) if err != nil { return nil, 0, 0, err @@ -2734,10 +2863,10 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // Construct the our outpoint. ourOutpoint := wire.OutPoint{ - Hash: commitHash, + Hash: commitHash.Val, } - if revokedLog.OurOutputIndex != channeldb.OutputIndexEmpty { - ourOutpoint.Index = uint32(revokedLog.OurOutputIndex) + if revokedLog.OurOutputIndex.Val != channeldb.OutputIndexEmpty { + ourOutpoint.Index = uint32(revokedLog.OurOutputIndex.Val) // If the spend transaction is provided, then we use it to get // the value of our output. @@ -2760,26 +2889,29 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // contains our output amount. Due to a previous // migration, this field may be empty in which case an // error will be returned. - if revokedLog.OurBalance == nil { - return nil, 0, 0, ErrRevLogDataMissing + b, err := revokedLog.OurBalance.ValOpt().UnwrapOrErr( + ErrRevLogDataMissing, + ) + if err != nil { + return nil, 0, 0, err } - ourAmt = int64(revokedLog.OurBalance.ToSatoshis()) + ourAmt = int64(b.Int().ToSatoshis()) } } // Construct the their outpoint. theirOutpoint := wire.OutPoint{ - Hash: commitHash, + Hash: commitHash.Val, } - if revokedLog.TheirOutputIndex != channeldb.OutputIndexEmpty { - theirOutpoint.Index = uint32(revokedLog.TheirOutputIndex) + if revokedLog.TheirOutputIndex.Val != channeldb.OutputIndexEmpty { + theirOutpoint.Index = uint32(revokedLog.TheirOutputIndex.Val) // If the spend transaction is provided, then we use it to get // the value of the remote parties' output. if spendTx != nil { // Sanity check that TheirOutputIndex is within range. - if int(revokedLog.TheirOutputIndex) >= + if int(revokedLog.TheirOutputIndex.Val) >= len(spendTx.TxOut) { return nil, 0, 0, fmt.Errorf("%w: theirs=%v, "+ @@ -2797,16 +2929,19 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // contains remote parties' output amount. Due to a // previous migration, this field may be empty in which // case an error will be returned. - if revokedLog.TheirBalance == nil { - return nil, 0, 0, ErrRevLogDataMissing + b, err := revokedLog.TheirBalance.ValOpt().UnwrapOrErr( + ErrRevLogDataMissing, + ) + if err != nil { + return nil, 0, 0, err } - theirAmt = int64(revokedLog.TheirBalance.ToSatoshis()) + theirAmt = int64(b.Int().ToSatoshis()) } } return &BreachRetribution{ - BreachTxHash: commitHash, + BreachTxHash: commitHash.Val, ChainHash: chanState.ChainHash, LocalOutpoint: ourOutpoint, RemoteOutpoint: theirOutpoint, @@ -2860,16 +2995,11 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, continue } - entry := &channeldb.HTLCEntry{ - RHash: htlc.RHash, - RefundTimeout: htlc.RefundTimeout, - OutputIndex: uint16(htlc.OutputIndex), - Incoming: htlc.Incoming, - Amt: htlc.Amt.ToSatoshis(), - } + entry := channeldb.NewHTLCEntryFromHTLC(htlc) hr, err := createHtlcRetribution( chanState, keyRing, commitHash, commitmentSecret, leaseExpiry, entry, + fn.None[CommitAuxLeaves](), ) if err != nil { return nil, 0, 0, err @@ -2934,18 +3064,29 @@ func HtlcIsDust(chanType channeldb.ChannelType, return (htlcAmt - htlcFee) < dustLimit } -// htlcView represents the "active" HTLCs at a particular point within the +// HtlcView represents the "active" HTLCs at a particular point within the // history of the HTLC update log. -type htlcView struct { - ourUpdates []*PaymentDescriptor - theirUpdates []*PaymentDescriptor - feePerKw chainfee.SatPerKWeight +type HtlcView struct { + // NextHeight is the height of the commitment transaction that will be + // created using this view. + NextHeight uint64 + + // OurUpdates are our outgoing HTLCs. + OurUpdates []*PaymentDescriptor + + // TheirUpdates are their incoming HTLCs. + TheirUpdates []*PaymentDescriptor + + // FeePerKw is the fee rate in sat/kw of the commitment transaction. + FeePerKw chainfee.SatPerKWeight } // fetchHTLCView returns all the candidate HTLC updates which should be // considered for inclusion within a commitment based on the passed HTLC log // indexes. -func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *htlcView { +func (lc *LightningChannel) fetchHTLCView(theirLogIndex, + ourLogIndex uint64) *HtlcView { + var ourHTLCs []*PaymentDescriptor for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() { htlc := e.Value.(*PaymentDescriptor) @@ -2970,9 +3111,9 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *ht } } - return &htlcView{ - ourUpdates: ourHTLCs, - theirUpdates: theirHTLCs, + return &HtlcView{ + OurUpdates: ourHTLCs, + TheirUpdates: theirHTLCs, } } @@ -3007,12 +3148,15 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, if err != nil { return nil, err } - feePerKw := filteredHTLCView.feePerKw + feePerKw := filteredHTLCView.FeePerKw + + htlcView.NextHeight = nextHeight + filteredHTLCView.NextHeight = nextHeight // Actually generate unsigned commitment transaction for this view. commitTx, err := lc.commitBuilder.createUnsignedCommitmentTx( ourBalance, theirBalance, !remoteChain, feePerKw, nextHeight, - filteredHTLCView, keyRing, + htlcView, filteredHTLCView, keyRing, commitChain.tip(), ) if err != nil { return nil, err @@ -3047,6 +3191,16 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, effFeeRate, spew.Sdump(commitTx)) } + // Given the custom blob of the past state, and this new HTLC view, + // we'll generate a new blob for the latest commitment. + newCommitBlob, err := updateAuxBlob( + lc.channelState, commitChain.tip().customBlob, htlcView, + !remoteChain, ourBalance, theirBalance, lc.leafStore, *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the commitment view created, store the resulting balances and // transaction with the other parameters for this height. c := &commitment{ @@ -3062,17 +3216,22 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, feePerKw: feePerKw, dustLimit: dustLimit, isOurs: !remoteChain, + customBlob: newCommitBlob, } // In order to ensure _none_ of the HTLC's associated with this new // commitment are mutated, we'll manually copy over each HTLC to its // respective slice. - c.outgoingHTLCs = make([]PaymentDescriptor, len(filteredHTLCView.ourUpdates)) - for i, htlc := range filteredHTLCView.ourUpdates { + c.outgoingHTLCs = make( + []PaymentDescriptor, len(filteredHTLCView.OurUpdates), + ) + for i, htlc := range filteredHTLCView.OurUpdates { c.outgoingHTLCs[i] = *htlc } - c.incomingHTLCs = make([]PaymentDescriptor, len(filteredHTLCView.theirUpdates)) - for i, htlc := range filteredHTLCView.theirUpdates { + c.incomingHTLCs = make( + []PaymentDescriptor, len(filteredHTLCView.TheirUpdates), + ) + for i, htlc := range filteredHTLCView.TheirUpdates { c.incomingHTLCs[i] = *htlc } @@ -3107,15 +3266,16 @@ func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn { // once for each height, and only in concert with signing a new commitment. // TODO(halseth): return htlcs to mutate instead of mutating inside // method. -func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, +func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - remoteChain, mutateState bool) (*htlcView, error) { + remoteChain, mutateState bool) (*HtlcView, error) { // We initialize the view's fee rate to the fee rate of the unfiltered // view. If any fee updates are found when evaluating the view, it will // be updated. - newView := &htlcView{ - feePerKw: view.feePerKw, + newView := &HtlcView{ + FeePerKw: view.FeePerKw, + NextHeight: nextHeight, } // We use two maps, one for the local log and one for the remote log to @@ -3128,7 +3288,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // First we run through non-add entries in both logs, populating the // skip sets and mutating the current chain state (crediting balances, // etc) to reflect the settle/timeout entry encountered. - for _, entry := range view.ourUpdates { + for _, entry := range view.OurUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -3148,6 +3308,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // number of satoshis we've received within the channel. if mutateState && entry.EntryType == Settle && !remoteChain && entry.removeCommitHeightLocal == 0 { + lc.channelState.TotalMSatReceived += entry.Amount } @@ -3157,10 +3318,13 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } skipThem[addEntry.HtlcIndex] = struct{}{} - processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, remoteChain, true, mutateState) + + processRemoveEntry( + entry, ourBalance, theirBalance, nextHeight, + remoteChain, true, mutateState, + ) } - for _, entry := range view.theirUpdates { + for _, entry := range view.TheirUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -3190,32 +3354,41 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } skipUs[addEntry.HtlcIndex] = struct{}{} - processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, remoteChain, false, mutateState) + + processRemoveEntry( + entry, ourBalance, theirBalance, nextHeight, + remoteChain, false, mutateState, + ) } // Next we take a second pass through all the log entries, skipping any // settled HTLCs, and debiting the chain state balance due to any newly // added HTLCs. - for _, entry := range view.ourUpdates { + for _, entry := range view.OurUpdates { isAdd := entry.EntryType == Add if _, ok := skipUs[entry.HtlcIndex]; !isAdd || ok { continue } - processAddEntry(entry, ourBalance, theirBalance, nextHeight, - remoteChain, false, mutateState) - newView.ourUpdates = append(newView.ourUpdates, entry) + processAddEntry( + entry, ourBalance, theirBalance, nextHeight, + remoteChain, false, mutateState, + ) + + newView.OurUpdates = append(newView.OurUpdates, entry) } - for _, entry := range view.theirUpdates { + for _, entry := range view.TheirUpdates { isAdd := entry.EntryType == Add if _, ok := skipThem[entry.HtlcIndex]; !isAdd || ok { continue } - processAddEntry(entry, ourBalance, theirBalance, nextHeight, - remoteChain, true, mutateState) - newView.theirUpdates = append(newView.theirUpdates, entry) + processAddEntry( + entry, ourBalance, theirBalance, nextHeight, + remoteChain, true, mutateState, + ) + + newView.TheirUpdates = append(newView.TheirUpdates, entry) } return newView, nil @@ -3305,6 +3478,9 @@ func processAddEntry(htlc *PaymentDescriptor, ourBalance, theirBalance *lnwire.M *ourBalance -= htlc.Amount } + // TODO(roasbef): also have it modify balances here + // * obtain for HTLC as well? + if mutateState { *addHeight = nextHeight } @@ -3363,7 +3539,7 @@ func processRemoveEntry(htlc *PaymentDescriptor, ourBalance, // processFeeUpdate processes a log update that updates the current commitment // fee. func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, - remoteChain bool, mutateState bool, view *htlcView) { + remoteChain bool, mutateState bool, view *HtlcView) { // Fee updates are applied for all commitments after they are // sent/received, so we consider them being added and removed at the @@ -3384,7 +3560,7 @@ func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, // If the update wasn't already locked in, update the current fee rate // to reflect this update. - view.feePerKw = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) + view.FeePerKw = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) if mutateState { *addHeight = nextHeight @@ -3399,9 +3575,16 @@ func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, // signature can be submitted to the sigPool to generate all the signatures // asynchronously and in parallel. func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, - chanType channeldb.ChannelType, isRemoteInitiator bool, - leaseExpiry uint32, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, - remoteCommitView *commitment) ([]SignJob, chan struct{}, error) { + chanState *channeldb.OpenChannel, leaseExpiry uint32, + remoteCommitView *commitment, + leafStore fn.Option[AuxLeafStore]) ([]SignJob, chan struct{}, error) { + + var ( + isRemoteInitiator = !chanState.IsInitiator + localChanCfg = chanState.LocalChanCfg + remoteChanCfg = chanState.RemoteChanCfg + chanType = chanState.ChanType + ) txHash := remoteCommitView.txn.TxHash() dustLimit := remoteChanCfg.DustLimit @@ -3418,6 +3601,15 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, var err error cancelChan := make(chan struct{}) + auxLeaves, err := AuxLeavesFromCommit( + chanState, *remoteCommitView.toDiskCommit(false), leafStore, + *keyRing, + ) + if err != nil { + return nil, nil, fmt.Errorf("unable to fetch aux leaves: "+ + "%w", err) + } + // For each outgoing and incoming HTLC, if the HTLC isn't considered a // dust output after taking into account second-level HTLC fees, then a // sigJob will be generated and appended to the current batch. @@ -3444,6 +3636,14 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.MapOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.IncomingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) + // With the fee calculate, we can properly create the HTLC // timeout transaction using the HTLC amount minus the fee. op := wire.OutPoint{ @@ -3454,11 +3654,14 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, chanType, isRemoteInitiator, op, outputAmt, htlc.Timeout, uint32(remoteChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + fn.FlattenOption(auxLeaf), ) if err != nil { return nil, nil, err } + // TODO(roasbeef): hook up signer interface here + // Construct a full hash cache as we may be signing a segwit v1 // sighash. txOut := remoteCommitView.txn.TxOut[htlc.remoteOutputIndex] @@ -3485,7 +3688,8 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // If this is a taproot channel, then we'll need to set the // method type to ensure we generate a valid signature. if chanType.IsTaproot() { - sigJob.SignDesc.SignMethod = input.TaprootScriptSpendSignMethod //nolint:lll + //nolint:lll + sigJob.SignDesc.SignMethod = input.TaprootScriptSpendSignMethod } sigBatch = append(sigBatch, sigJob) @@ -3511,6 +3715,14 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.MapOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.OutgoingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) + // With the proper output amount calculated, we can now // generate the success transaction using the remote party's // CSV delay. @@ -3522,6 +3734,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, chanType, isRemoteInitiator, op, outputAmt, uint32(remoteChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + fn.FlattenOption(auxLeaf), ) if err != nil { return nil, nil, err @@ -3567,9 +3780,9 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // validate this new state. This function is called right before sending the // new commitment to the remote party. The commit diff returned contains all // information necessary for retransmission. -func (lc *LightningChannel) createCommitDiff( - newCommit *commitment, commitSig lnwire.Sig, - htlcSigs []lnwire.Sig) (*channeldb.CommitDiff, error) { +func (lc *LightningChannel) createCommitDiff(newCommit *commitment, + commitSig lnwire.Sig, htlcSigs []lnwire.Sig) (*channeldb.CommitDiff, + error) { // First, we need to convert the funding outpoint into the ID that's // used on the wire to identify this channel. We'll use this shortly @@ -3959,10 +4172,10 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // appropriate update log, in order to validate the sanity of the // commitment resulting from _actually adding_ this HTLC to the state. if predictOurAdd != nil { - view.ourUpdates = append(view.ourUpdates, predictOurAdd) + view.OurUpdates = append(view.OurUpdates, predictOurAdd) } if predictTheirAdd != nil { - view.theirUpdates = append(view.theirUpdates, predictTheirAdd) + view.TheirUpdates = append(view.TheirUpdates, predictTheirAdd) } ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView( @@ -3972,7 +4185,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, return err } - feePerKw := filteredView.feePerKw + feePerKw := filteredView.FeePerKw // Ensure that the fee being applied is enough to be relayed across the // network in a reasonable time frame. @@ -4116,7 +4329,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // First check that the remote updates won't violate it's channel // constraints. err = validateUpdates( - filteredView.theirUpdates, &lc.channelState.RemoteChanCfg, + filteredView.TheirUpdates, &lc.channelState.RemoteChanCfg, ) if err != nil { return err @@ -4125,7 +4338,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // Secondly check that our updates won't violate our channel // constraints. err = validateUpdates( - filteredView.ourUpdates, &lc.channelState.LocalChanCfg, + filteredView.OurUpdates, &lc.channelState.LocalChanCfg, ) if err != nil { return err @@ -4268,9 +4481,8 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { leaseExpiry = lc.channelState.ThawHeight } sigBatch, cancelChan, err := genRemoteHtlcSigJobs( - keyRing, lc.channelState.ChanType, !lc.channelState.IsInitiator, - leaseExpiry, &lc.channelState.LocalChanCfg, - &lc.channelState.RemoteChanCfg, newCommitView, + keyRing, lc.channelState, leaseExpiry, newCommitView, + lc.leafStore, ) if err != nil { return nil, err @@ -4727,7 +4939,7 @@ func (lc *LightningChannel) ProcessChanSyncMsg( return updates, openedCircuits, closedCircuits, nil } -// computeView takes the given htlcView, and calculates the balances, filtered +// computeView takes the given HtlcView, and calculates the balances, filtered // view (settling unsettled HTLCs), commitment weight and feePerKw, after // applying the HTLCs to the latest commitment. The returned balances are the // balances *before* subtracting the commitment fee from the initiator's @@ -4735,9 +4947,9 @@ func (lc *LightningChannel) ProcessChanSyncMsg( // // If the updateState boolean is set true, the add and remove heights of the // HTLCs will be set to the next commitment height. -func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, +func (lc *LightningChannel) computeView(view *HtlcView, remoteChain bool, updateState bool) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, int64, - *htlcView, error) { + *HtlcView, error) { commitChain := lc.localCommitChain dustLimit := lc.channelState.LocalChanCfg.DustLimit @@ -4768,7 +4980,10 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, // Initiate feePerKw to the last committed fee for this chain as we'll // need this to determine which HTLCs are dust, and also the final fee // rate. - view.feePerKw = commitChain.tip().feePerKw + view.FeePerKw = commitChain.tip().feePerKw + + // TODO(roasbeef): also need to pass blob here as well for final + // balances? // We evaluate the view at this stage, meaning settled and failed HTLCs // will remove their corresponding added HTLCs. The resulting filtered @@ -4776,12 +4991,14 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, // channel constraints to the final commitment state. If any fee // updates are found in the logs, the commitment fee rate should be // changed, so we'll also set the feePerKw to this new value. - filteredHTLCView, err := lc.evaluateHTLCView(view, &ourBalance, - &theirBalance, nextHeight, remoteChain, updateState) + filteredHTLCView, err := lc.evaluateHTLCView( + view, &ourBalance, &theirBalance, nextHeight, remoteChain, + updateState, + ) if err != nil { return 0, 0, 0, nil, err } - feePerKw := filteredHTLCView.feePerKw + feePerKw := filteredHTLCView.FeePerKw // We need to first check ourBalance and theirBalance to be negative // because MilliSathoshi is a unsigned type and can underflow in @@ -4799,7 +5016,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, // Now go through all HTLCs at this stage, to calculate the total // weight, needed to calculate the transaction fee. var totalHtlcWeight int64 - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( lc.channelState.ChanType, false, !remoteChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -4810,7 +5027,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, totalHtlcWeight += input.HTLCWeight } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( lc.channelState.ChanType, true, !remoteChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -4831,10 +5048,18 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, // meant to verify all the signatures for HTLC's attached to a newly created // commitment state. The jobs generated are fully populated, and can be sent // directly into the pool of workers. -func genHtlcSigValidationJobs(localCommitmentView *commitment, - keyRing *CommitmentKeyRing, htlcSigs []lnwire.Sig, - chanType channeldb.ChannelType, isLocalInitiator bool, leaseExpiry uint32, - localChanCfg, remoteChanCfg *channeldb.ChannelConfig) ([]VerifyJob, error) { +// +//nolint:funlen +func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, + localCommitmentView *commitment, keyRing *CommitmentKeyRing, + htlcSigs []lnwire.Sig, leaseExpiry uint32, + leafStore fn.Option[AuxLeafStore]) ([]VerifyJob, error) { + + var ( + isLocalInitiator = chanState.IsInitiator + localChanCfg = chanState.LocalChanCfg + chanType = chanState.ChanType + ) txHash := localCommitmentView.txn.TxHash() feePerKw := localCommitmentView.feePerKw @@ -4848,6 +5073,15 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment, len(localCommitmentView.outgoingHTLCs)) verifyJobs := make([]VerifyJob, 0, numHtlcs) + auxLeaves, err := AuxLeavesFromCommit( + chanState, *localCommitmentView.toDiskCommit(true), leafStore, + *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", + err) + } + // We'll iterate through each output in the commitment transaction, // populating the sigHash closure function if it's detected to be an // HLTC output. Given the sighash, and the signing key, we'll be able @@ -4881,11 +5115,20 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.MapOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.IncomingHtlcLeaves + idx := htlc.HtlcIndex + return leaves[idx].SecondLevelLeaf + })(auxLeaves) + successTx, err := CreateHtlcSuccessTx( chanType, isLocalInitiator, op, outputAmt, uint32(localChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -4965,12 +5208,21 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.MapOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.OutgoingHtlcLeaves + idx := htlc.HtlcIndex + return leaves[idx].SecondLevelLeaf + })(auxLeaves) + timeoutTx, err := CreateHtlcTimeoutTx( chanType, isLocalInitiator, op, outputAmt, htlc.Timeout, uint32(localChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -5227,10 +5479,8 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { leaseExpiry = lc.channelState.ThawHeight } verifyJobs, err := genHtlcSigValidationJobs( - localCommitmentView, keyRing, commitSigs.HtlcSigs, - lc.channelState.ChanType, lc.channelState.IsInitiator, - leaseExpiry, &lc.channelState.LocalChanCfg, - &lc.channelState.RemoteChanCfg, + lc.channelState, localCommitmentView, keyRing, + commitSigs.HtlcSigs, leaseExpiry, lc.leafStore, ) if err != nil { return err @@ -5814,7 +6064,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( // before the change since the indexes are meant for the current, // revoked remote commitment. ourOutputIndex, theirOutputIndex, err := findOutputIndexesFromRemote( - revocation, lc.channelState, + revocation, lc.channelState, lc.leafStore, ) if err != nil { return nil, nil, nil, nil, err @@ -6088,7 +6338,7 @@ func (lc *LightningChannel) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error { func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, openKey *models.CircuitKey) *PaymentDescriptor { - return &PaymentDescriptor{ + pd := &PaymentDescriptor{ EntryType: Add, RHash: PaymentHash(htlc.PaymentHash), Timeout: htlc.Expiry, @@ -6099,6 +6349,14 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, OpenCircuitKey: openKey, BlindingPoint: htlc.BlindingPoint, } + + // Copy over any extra data included to ensure we can forward and + // process this HTLC properly. + if len(htlc.ExtraData) != 0 { + pd.CustomRecords = fn.Some[tlv.Blob](htlc.ExtraData[:]) + } + + return pd } // validateAddHtlc validates the addition of an outgoing htlc to our local and @@ -6158,6 +6416,12 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err BlindingPoint: htlc.BlindingPoint, } + // Copy over any extra data included to ensure we can forward and + // process this HTLC properly. + if htlc.ExtraData != nil { + pd.CustomRecords = fn.Some(tlv.Blob(htlc.ExtraData[:])) + } + localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex // Clamp down on the number of HTLC's we can receive by checking the @@ -6663,10 +6927,10 @@ type UnilateralCloseSummary struct { // happen in case we have lost state) it should be set to an empty struct, in // which case we will attempt to sweep the non-HTLC output using the passed // commitPoint. -func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Signer, - commitSpend *chainntnfs.SpendDetail, - remoteCommit channeldb.ChannelCommitment, - commitPoint *btcec.PublicKey) (*UnilateralCloseSummary, error) { +func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, + signer input.Signer, commitSpend *chainntnfs.SpendDetail, + remoteCommit channeldb.ChannelCommitment, commitPoint *btcec.PublicKey, + leafStore fn.Option[AuxLeafStore]) (*UnilateralCloseSummary, error) { // First, we'll generate the commitment point and the revocation point // so we can re-construct the HTLC state and also our payment key. @@ -6676,6 +6940,13 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) + auxLeaves, err := AuxLeavesFromCommit( + chanState, remoteCommit, leafStore, *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // Next, we'll obtain HTLC resolutions for all the outgoing HTLC's we // had on their commitment transaction. var leaseExpiry uint32 @@ -6687,7 +6958,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si chainfee.SatPerKWeight(remoteCommit.FeePerKw), isOurCommit, signer, remoteCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitSpend.SpendingTx, - chanState.ChanType, isRemoteInitiator, leaseExpiry, + chanState.ChanType, isRemoteInitiator, leaseExpiry, auxLeaves, ) if err != nil { return nil, fmt.Errorf("unable to create htlc "+ @@ -6699,9 +6970,14 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si // Before we can generate the proper sign descriptor, we'll need to // locate the output index of our non-delayed output on the commitment // transaction. + // + // TODO(roasbeef): helper func to hide flatten + remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + })(auxLeaves) selfScript, maturityDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return nil, fmt.Errorf("unable to create self commit "+ @@ -6940,8 +7216,8 @@ func newOutgoingHtlcResolution(signer input.Signer, localChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, - localCommit, isCommitFromInitiator bool, - chanType channeldb.ChannelType) (*OutgoingHtlcResolution, error) { + localCommit, isCommitFromInitiator bool, chanType channeldb.ChannelType, + auxLeaves fn.Option[CommitAuxLeaves]) (*OutgoingHtlcResolution, error) { op := wire.OutPoint{ Hash: commitTx.TxHash(), @@ -6950,9 +7226,12 @@ func newOutgoingHtlcResolution(signer input.Signer, // First, we'll re-generate the script used to send the HTLC to the // remote party within their commitment transaction. + auxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.OutgoingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf + })(auxLeaves) htlcScriptInfo, err := genHtlcScript( chanType, false, localCommit, htlc.RefundTimeout, htlc.RHash, - keyRing, + keyRing, fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -7025,10 +7304,17 @@ func newOutgoingHtlcResolution(signer input.Signer, // With the fee calculated, re-construct the second level timeout // transaction. + secondLevelAuxLeaf := fn.MapOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) timeoutTx, err := CreateHtlcTimeoutTx( chanType, isCommitFromInitiator, op, secondLevelOutputAmt, - htlc.RefundTimeout, csvDelay, leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + htlc.RefundTimeout, csvDelay, leaseExpiry, + keyRing.RevocationKey, keyRing.ToLocalKey, + fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7111,6 +7397,7 @@ func newOutgoingHtlcResolution(signer input.Signer, htlcSweepScript, err = SecondLevelHtlcScript( chanType, isCommitFromInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, leaseExpiry, + fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7119,6 +7406,7 @@ func newOutgoingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, + fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7192,8 +7480,8 @@ func newIncomingHtlcResolution(signer input.Signer, localChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, - localCommit, isCommitFromInitiator bool, chanType channeldb.ChannelType) ( - *IncomingHtlcResolution, error) { + localCommit, isCommitFromInitiator bool, chanType channeldb.ChannelType, + auxLeaves fn.Option[CommitAuxLeaves]) (*IncomingHtlcResolution, error) { op := wire.OutPoint{ Hash: commitTx.TxHash(), @@ -7202,9 +7490,12 @@ func newIncomingHtlcResolution(signer input.Signer, // First, we'll re-generate the script the remote party used to // send the HTLC to us in their commitment transaction. + auxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.IncomingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf + })(auxLeaves) scriptInfo, err := genHtlcScript( chanType, true, localCommit, htlc.RefundTimeout, htlc.RHash, - keyRing, + keyRing, fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -7264,6 +7555,13 @@ func newIncomingHtlcResolution(signer input.Signer, }, nil } + secondLevelAuxLeaf := fn.MapOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.IncomingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) + // Otherwise, we'll need to go to the second level to sweep this HTLC. // // First, we'll reconstruct the original HTLC success transaction, @@ -7273,7 +7571,7 @@ func newIncomingHtlcResolution(signer input.Signer, successTx, err := CreateHtlcSuccessTx( chanType, isCommitFromInitiator, op, secondLevelOutputAmt, csvDelay, leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + keyRing.ToLocalKey, fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7356,6 +7654,7 @@ func newIncomingHtlcResolution(signer input.Signer, htlcSweepScript, err = SecondLevelHtlcScript( chanType, isCommitFromInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, leaseExpiry, + fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7364,6 +7663,7 @@ func newIncomingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, + fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7455,7 +7755,8 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, signer input.Signer, htlcs []channeldb.HTLC, keyRing *CommitmentKeyRing, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, chanType channeldb.ChannelType, - isCommitFromInitiator bool, leaseExpiry uint32) (*HtlcResolutions, error) { + isCommitFromInitiator bool, leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*HtlcResolutions, error) { // TODO(roasbeef): don't need to swap csv delay? dustLimit := remoteChanCfg.DustLimit @@ -7490,6 +7791,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, signer, localChanCfg, commitTx, &htlc, keyRing, feePerKw, uint32(csvDelay), leaseExpiry, ourCommit, isCommitFromInitiator, chanType, + auxLeaves, ) if err != nil { return nil, fmt.Errorf("incoming resolution "+ @@ -7503,7 +7805,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, ohr, err := newOutgoingHtlcResolution( signer, localChanCfg, commitTx, &htlc, keyRing, feePerKw, uint32(csvDelay), leaseExpiry, ourCommit, - isCommitFromInitiator, chanType, + isCommitFromInitiator, chanType, auxLeaves, ) if err != nil { return nil, fmt.Errorf("outgoing resolution "+ @@ -7603,7 +7905,7 @@ func (lc *LightningChannel) ForceClose() (*LocalForceCloseSummary, error) { localCommitment := lc.channelState.LocalCommitment summary, err := NewLocalForceCloseSummary( lc.channelState, lc.Signer, commitTx, - localCommitment.CommitHeight, + localCommitment.CommitHeight, lc.leafStore, ) if err != nil { return nil, fmt.Errorf("unable to gen force close "+ @@ -7620,8 +7922,8 @@ func (lc *LightningChannel) ForceClose() (*LocalForceCloseSummary, error) { // channel state. The passed commitTx must be a fully signed commitment // transaction corresponding to localCommit. func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, - signer input.Signer, commitTx *wire.MsgTx, stateNum uint64) ( - *LocalForceCloseSummary, error) { + signer input.Signer, commitTx *wire.MsgTx, stateNum uint64, + leafStore fn.Option[AuxLeafStore]) (*LocalForceCloseSummary, error) { // Re-derive the original pkScript for to-self output within the // commitment transaction. We'll need this to find the corresponding @@ -7642,6 +7944,8 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) + // TODO(roasbeef): fetch aux leave + var leaseExpiry uint32 if chanState.ChanType.HasLeaseExpiration() { leaseExpiry = chanState.ThawHeight @@ -7649,6 +7953,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, toLocalScript, err := CommitScriptToSelf( chanState.ChanType, chanState.IsInitiator, keyRing.ToLocalKey, keyRing.RevocationKey, csvTimeout, leaseExpiry, + input.NoneTapLeaf(), ) if err != nil { return nil, err @@ -7735,11 +8040,19 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, // recovery there is not much we can do with HTLCs, so we'll always // use what we have in our latest state when extracting resolutions. localCommit := chanState.LocalCommitment + + auxLeaves, err := AuxLeavesFromCommit( + chanState, localCommit, leafStore, *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + htlcResolutions, err := extractHtlcResolutions( chainfee.SatPerKWeight(localCommit.FeePerKw), true, signer, localCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitTx, chanState.ChanType, - chanState.IsInitiator, leaseExpiry, + chanState.IsInitiator, leaseExpiry, auxLeaves, ) if err != nil { return nil, fmt.Errorf("unable to gen htlc resolution: %w", err) @@ -8269,13 +8582,13 @@ func (lc *LightningChannel) availableBalance( } // availableCommitmentBalance attempts to calculate the balance we have -// available for HTLCs on the local/remote commitment given the htlcView. To +// available for HTLCs on the local/remote commitment given the HtlcView. To // account for sending HTLCs of different sizes, it will report the balance // available for sending non-dust HTLCs, which will be manifested on the // commitment, increasing the commitment fee we must pay as an initiator, // eating into our balance. It will make sure we won't violate the channel // reserve constraints for this amount. -func (lc *LightningChannel) availableCommitmentBalance(view *htlcView, +func (lc *LightningChannel) availableCommitmentBalance(view *HtlcView, remoteChain bool, buffer BufferType) (lnwire.MilliSatoshi, int64) { // Compute the current balances for this commitment. This will take @@ -8303,7 +8616,7 @@ func (lc *LightningChannel) availableCommitmentBalance(view *htlcView, // Calculate the commitment fee in the case where we would add another // HTLC to the commitment, as only the balance remaining after this fee // has been paid is actually available for sending. - feePerKw := filteredView.feePerKw + feePerKw := filteredView.FeePerKw additionalHtlcFee := lnwire.NewMSatFromSatoshis( feePerKw.FeeForWeight(input.HTLCWeight), ) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index c73abdbe0e3..4914fa13946 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -21,6 +21,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -5694,6 +5695,7 @@ func TestChannelUnilateralCloseHtlcResolution(t *testing.T) { spendDetail, aliceChannel.channelState.RemoteCommitment, aliceChannel.channelState.RemoteCurrentRevocation, + fn.None[AuxLeafStore](), ) require.NoError(t, err, "unable to create alice close summary") @@ -5843,6 +5845,7 @@ func TestChannelUnilateralClosePendingCommit(t *testing.T) { spendDetail, aliceChannel.channelState.RemoteCommitment, aliceChannel.channelState.RemoteCurrentRevocation, + fn.None[AuxLeafStore](), ) require.NoError(t, err, "unable to create alice close summary") @@ -5860,6 +5863,7 @@ func TestChannelUnilateralClosePendingCommit(t *testing.T) { spendDetail, aliceRemoteChainTip.Commitment, aliceChannel.channelState.RemoteNextRevocation, + fn.None[AuxLeafStore](), ) require.NoError(t, err, "unable to create alice close summary") @@ -6740,6 +6744,7 @@ func TestNewBreachRetributionSkipsDustHtlcs(t *testing.T) { breachTx := aliceChannel.channelState.RemoteCommitment.CommitTx breachRet, err := NewBreachRetribution( aliceChannel.channelState, revokedStateNum, 100, breachTx, + fn.None[AuxLeafStore](), ) require.NoError(t, err, "unable to create breach retribution") @@ -8540,10 +8545,10 @@ func TestEvaluateView(t *testing.T) { } } - view := &htlcView{ - ourUpdates: test.ourHtlcs, - theirUpdates: test.theirHtlcs, - feePerKw: feePerKw, + view := &HtlcView{ + OurUpdates: test.ourHtlcs, + TheirUpdates: test.theirHtlcs, + FeePerKw: feePerKw, } var ( @@ -8564,17 +8569,17 @@ func TestEvaluateView(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if result.feePerKw != test.expectedFee { + if result.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", - test.expectedFee, result.feePerKw) + test.expectedFee, result.FeePerKw) } checkExpectedHtlcs( - t, result.ourUpdates, test.ourExpectedHtlcs, + t, result.OurUpdates, test.ourExpectedHtlcs, ) checkExpectedHtlcs( - t, result.theirUpdates, test.theirExpectedHtlcs, + t, result.TheirUpdates, test.theirExpectedHtlcs, ) if lc.channelState.TotalMSatSent != test.expectSent { @@ -8797,15 +8802,15 @@ func TestProcessFeeUpdate(t *testing.T) { EntryType: FeeUpdate, } - view := &htlcView{ - feePerKw: chainfee.SatPerKWeight(feePerKw), + view := &HtlcView{ + FeePerKw: chainfee.SatPerKWeight(feePerKw), } processFeeUpdate( update, nextHeight, test.remoteChain, test.mutate, view, ) - if view.feePerKw != test.expectedFee { + if view.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", test.expectedFee, feePerKw) } @@ -9950,15 +9955,17 @@ func TestCreateHtlcRetribution(t *testing.T) { aliceChannel.channelState, ) htlc := &channeldb.HTLCEntry{ - Amt: testAmt, - Incoming: true, - OutputIndex: 1, + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(testAmt), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](1), } // Create the htlc retribution. hr, err := createHtlcRetribution( aliceChannel.channelState, keyRing, commitHash, - dummyPrivate, leaseExpiry, htlc, + dummyPrivate, leaseExpiry, htlc, fn.None[CommitAuxLeaves](), ) // Expect no error. require.NoError(t, err) @@ -9966,8 +9973,8 @@ func TestCreateHtlcRetribution(t *testing.T) { // Check the fields have expected values. require.EqualValues(t, testAmt, hr.SignDesc.Output.Value) require.Equal(t, commitHash, hr.OutPoint.Hash) - require.EqualValues(t, htlc.OutputIndex, hr.OutPoint.Index) - require.Equal(t, htlc.Incoming, hr.IsIncoming) + require.EqualValues(t, htlc.OutputIndex.Val, hr.OutPoint.Index) + require.Equal(t, htlc.Incoming.Val, hr.IsIncoming) } // TestCreateBreachRetribution checks that `createBreachRetribution` behaves as @@ -10007,30 +10014,31 @@ func TestCreateBreachRetribution(t *testing.T) { aliceChannel.channelState, ) htlc := &channeldb.HTLCEntry{ - Amt: btcutil.Amount(testAmt), - Incoming: true, - OutputIndex: uint16(htlcIndex), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(testAmt)), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlcIndex), + ), } // Create a dummy revocation log. ourAmtMsat := lnwire.MilliSatoshi(ourAmt * 1000) theirAmtMsat := lnwire.MilliSatoshi(theirAmt * 1000) - revokedLog := channeldb.RevocationLog{ - CommitTxHash: commitHash, - OurOutputIndex: uint16(localIndex), - TheirOutputIndex: uint16(remoteIndex), - HTLCEntries: []*channeldb.HTLCEntry{htlc}, - TheirBalance: &theirAmtMsat, - OurBalance: &ourAmtMsat, - } + revokedLog := channeldb.NewRevocationLog( + uint16(localIndex), uint16(remoteIndex), commitHash, + fn.Some(ourAmtMsat), fn.Some(theirAmtMsat), + []*channeldb.HTLCEntry{htlc}, fn.None[tlv.Blob](), + ) // Create a log with an empty local output index. revokedLogNoLocal := revokedLog - revokedLogNoLocal.OurOutputIndex = channeldb.OutputIndexEmpty + revokedLogNoLocal.OurOutputIndex.Val = channeldb.OutputIndexEmpty // Create a log with an empty remote output index. revokedLogNoRemote := revokedLog - revokedLogNoRemote.TheirOutputIndex = channeldb.OutputIndexEmpty + revokedLogNoRemote.TheirOutputIndex.Val = channeldb.OutputIndexEmpty testCases := []struct { name string @@ -10060,14 +10068,20 @@ func TestCreateBreachRetribution(t *testing.T) { { name: "fail due to our index too big", revocationLog: &channeldb.RevocationLog{ - OurOutputIndex: uint16(htlcIndex + 1), + //nolint:lll + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint16(htlcIndex + 1), + ), }, expectedErr: ErrOutputIndexOutOfRange, }, { name: "fail due to their index too big", revocationLog: &channeldb.RevocationLog{ - TheirOutputIndex: uint16(htlcIndex + 1), + //nolint:lll + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint16(htlcIndex + 1), + ), }, expectedErr: ErrOutputIndexOutOfRange, }, @@ -10136,11 +10150,12 @@ func TestCreateBreachRetribution(t *testing.T) { require.Equal(t, remote, br.RemoteOutpoint) for _, hr := range br.HtlcRetributions { - require.EqualValues(t, testAmt, - hr.SignDesc.Output.Value) + require.EqualValues( + t, testAmt, hr.SignDesc.Output.Value, + ) require.Equal(t, commitHash, hr.OutPoint.Hash) require.EqualValues(t, htlcIndex, hr.OutPoint.Index) - require.Equal(t, htlc.Incoming, hr.IsIncoming) + require.Equal(t, htlc.Incoming.Val, hr.IsIncoming) } } @@ -10156,6 +10171,7 @@ func TestCreateBreachRetribution(t *testing.T) { tc.revocationLog, tx, aliceChannel.channelState, keyRing, dummyPrivate, leaseExpiry, + fn.None[CommitAuxLeaves](), ) // Check the error if expected. @@ -10274,6 +10290,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error as there are no past delta state saved as revocation logs yet. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, + fn.None[AuxLeafStore](), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10281,6 +10298,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, + fn.None[AuxLeafStore](), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10326,6 +10344,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // successfully. br, err := NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, + fn.None[AuxLeafStore](), ) require.NoError(t, err) @@ -10337,6 +10356,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // since the necessary info should now be found in the revocation log. br, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, + fn.None[AuxLeafStore](), ) require.NoError(t, err) assertRetribution(br, 1, 0) @@ -10345,6 +10365,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum+1, breachHeight, breachTx, + fn.None[AuxLeafStore](), ) require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound) @@ -10352,6 +10373,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum+1, breachHeight, nil, + fn.None[AuxLeafStore](), ) require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound) } @@ -10389,7 +10411,7 @@ func TestExtractPayDescs(t *testing.T) { // NOTE: we use nil commitment key rings to avoid checking the htlc // scripts(`genHtlcScript`) as it should be tested independently. incomingPDs, outgoingPDs, err := lnChan.extractPayDescs( - 0, 0, htlcs, nil, nil, true, + 0, 0, htlcs, nil, nil, true, fn.None[CommitAuxLeaves](), ) require.NoError(t, err) diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 6ecef795b14..bdd10555a10 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -11,9 +11,11 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" ) // anchorSize is the constant anchor output size. @@ -225,8 +227,7 @@ func (w *WitnessScriptDesc) WitnessScriptForPath(_ input.ScriptPath, // the settled funds in the channel, plus the unsettled funds. func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, selfKey, revokeKey *btcec.PublicKey, csvDelay, leaseExpiry uint32, -) ( - input.ScriptDescriptor, error) { + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { switch { // For taproot scripts, we'll need to make a slightly modified script @@ -236,7 +237,7 @@ func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, // Our "redeem" script here is just the taproot witness program. case chanType.IsTaproot(): return input.NewLocalCommitScriptTree( - csvDelay, selfKey, revokeKey, + csvDelay, selfKey, revokeKey, auxLeaf, ) // If we are the initiator of a leased channel, then we have an @@ -290,8 +291,8 @@ func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, // script for. The second return value is the CSV delay of the output script, // what must be satisfied in order to spend the output. func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool, - remoteKey *btcec.PublicKey, - leaseExpiry uint32) (input.ScriptDescriptor, uint32, error) { + remoteKey *btcec.PublicKey, leaseExpiry uint32, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, uint32, error) { switch { // If we are not the initiator of a leased channel, then the remote @@ -320,7 +321,7 @@ func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool, // with the sole tap leaf enforcing the 1 CSV delay. case chanType.IsTaproot(): toRemoteScriptTree, err := input.NewRemoteCommitScriptTree( - remoteKey, + remoteKey, auxLeaf, ) if err != nil { return nil, 0, err @@ -419,14 +420,14 @@ func sweepSigHash(chanType channeldb.ChannelType) txscript.SigHashType { // argument should correspond to the owner of the commitment transaction which // we are generating the to_local script for. func SecondLevelHtlcScript(chanType channeldb.ChannelType, initiator bool, - revocationKey, delayKey *btcec.PublicKey, - csvDelay, leaseExpiry uint32) (input.ScriptDescriptor, error) { + revocationKey, delayKey *btcec.PublicKey, csvDelay, leaseExpiry uint32, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { switch { // For taproot channels, the pkScript is a segwit v1 p2tr output. case chanType.IsTaproot(): return input.TaprootSecondLevelScriptTree( - revocationKey, delayKey, csvDelay, + revocationKey, delayKey, csvDelay, auxLeaf, ) // If we are the initiator of a leased channel, then we have an @@ -610,11 +611,82 @@ func CommitScriptAnchors(chanType channeldb.ChannelType, return localAnchor, remoteAnchor, nil } +// CommitSortFunc is a function type alias for a function that sorts the +// commitment transaction outputs. The second parameter is a list of CLTV +// timeouts that must correspond to the number of transaction outputs, with the +// value of 0 for non-HTLC outputs. +type CommitSortFunc func(*wire.MsgTx, []uint32) error + +// CommitAuxLeaves stores two potential auxiliary leaves for the remote and +// local output that may be used to argument the final tapscript trees of the +// commitment transaction. +type CommitAuxLeaves struct { + // LocalAuxLeaf is the local party's auxiliary leaf. + LocalAuxLeaf input.AuxTapLeaf + + // RemoteAuxLeaf is the remote party's auxiliary leaf. + RemoteAuxLeaf input.AuxTapLeaf + + // OutgoingHTLCLeaves is the set of aux leaves for the outgoing HTLCs + // on this commitment transaction. + OutgoingHtlcLeaves input.AuxTapLeaves + + // IncomingHTLCLeaves is the set of aux leaves for the incoming HTLCs + // on this commitment transaction. + IncomingHtlcLeaves input.AuxTapLeaves +} + +// ForRemoteCommit returns the local+remote aux leaves from the PoV of the +// remote party's commitment. +func (c *CommitAuxLeaves) ForRemoteCommit() CommitAuxLeaves { + return CommitAuxLeaves{ + LocalAuxLeaf: c.RemoteAuxLeaf, + RemoteAuxLeaf: c.LocalAuxLeaf, + OutgoingHtlcLeaves: c.IncomingHtlcLeaves, + IncomingHtlcLeaves: c.OutgoingHtlcLeaves, + } +} + +// AuxLeafStore is used to optionally fetch auxiliary tapscript leaves for the +// commitment transaction given an opaque blob. This is also used to implement +// a state transition function for the blobs to allow them to be refreshed with +// each state. +type AuxLeafStore interface { + // FetchLeavesFromView attempts to fetch the auxiliary leaves that + // correspond to the passed aux blob, and pending original (unfiltered) + // HTLC view. + FetchLeavesFromView(chanState *channeldb.OpenChannel, prevBlob tlv.Blob, + unfilteredView *HtlcView, isOurCommit bool, ourBalance, + theirBalance lnwire.MilliSatoshi, + keyRing CommitmentKeyRing) (fn.Option[CommitAuxLeaves], + CommitSortFunc, error) + + // FetchLeavesFromCommit attempts to fetch the auxiliary leaves that + // correspond to the passed aux blob, and an existing channel + // commitment. + FetchLeavesFromCommit(chanState *channeldb.OpenChannel, + commit channeldb.ChannelCommitment, + keyRing CommitmentKeyRing) (fn.Option[CommitAuxLeaves], error) + + // FetchLeavesFromRevocation attempts to fetch the auxiliary leaves + // from a channel revocation that stores balance + blob information. + FetchLeavesFromRevocation(revocation *channeldb.RevocationLog, + ) (fn.Option[CommitAuxLeaves], error) + + // ApplyHtlcView serves as the state transition function for the custom + // channel's blob. Given the old blob, and an HTLC view, then a new + // blob should be returned that reflects the pending updates. + ApplyHtlcView(chanState *channeldb.OpenChannel, prevBlob tlv.Blob, + unfilteredView *HtlcView, isOurCommit bool, ourBalance, + theirBalance lnwire.MilliSatoshi, + keyRing CommitmentKeyRing) (fn.Option[tlv.Blob], error) +} + // CommitmentBuilder is a type that wraps the type of channel we are dealing // with, and abstracts the various ways of constructing commitment // transactions. type CommitmentBuilder struct { - // chanState is the underlying channels's state struct, used to + // chanState is the underlying channel's state struct, used to // determine the type of channel we are dealing with, and relevant // parameters. chanState *channeldb.OpenChannel @@ -622,18 +694,25 @@ type CommitmentBuilder struct { // obfuscator is a 48-bit state hint that's used to obfuscate the // current state number on the commitment transactions. obfuscator [StateHintSize]byte + + // auxLeafStore is an interface that allows us to fetch auxiliary + // tapscript leaves for the commitment output. + auxLeafStore fn.Option[AuxLeafStore] } // NewCommitmentBuilder creates a new CommitmentBuilder from chanState. -func NewCommitmentBuilder(chanState *channeldb.OpenChannel) *CommitmentBuilder { +func NewCommitmentBuilder(chanState *channeldb.OpenChannel, + leafStore fn.Option[AuxLeafStore]) *CommitmentBuilder { + // The anchor channel type MUST be tweakless. if chanState.ChanType.HasAnchors() && !chanState.ChanType.IsTweakless() { panic("invalid channel type combination") } return &CommitmentBuilder{ - chanState: chanState, - obfuscator: createStateHintObfuscator(chanState), + chanState: chanState, + obfuscator: createStateHintObfuscator(chanState), + auxLeafStore: leafStore, } } @@ -680,15 +759,89 @@ type unsignedCommitmentTx struct { cltvs []uint32 } +// AuxLeavesFromCommit is a helper function that attempts to fetch the +// auxiliary leaves given a finalized channel commitment, and a leaf store. +func AuxLeavesFromCommit(chanState *channeldb.OpenChannel, + commit channeldb.ChannelCommitment, leafStore fn.Option[AuxLeafStore], + keyRing CommitmentKeyRing) (fn.Option[CommitAuxLeaves], error) { + + if leafStore.IsNone() { + return fn.None[CommitAuxLeaves](), nil + } + + return leafStore.UnsafeFromSome().FetchLeavesFromCommit( + chanState, commit, keyRing, + ) +} + +// auxLeavesFromView is used to derive the set of commit aux leaves (if any), +// that are needed to create a new commitment transaction using the original +// (unfiltered) htlc view. +func auxLeavesFromView(chanState *channeldb.OpenChannel, + prevBlob fn.Option[tlv.Blob], originalView *HtlcView, isOurCommit bool, + ourBalance, theirBalance lnwire.MilliSatoshi, + leafStore fn.Option[AuxLeafStore], + keyRing CommitmentKeyRing) (fn.Option[CommitAuxLeaves], CommitSortFunc, + error) { + + if leafStore.IsNone() { + return fn.None[CommitAuxLeaves](), nil, nil + } + + if prevBlob.IsNone() { + return fn.None[CommitAuxLeaves](), nil, nil + } + + return leafStore.UnsafeFromSome().FetchLeavesFromView( + chanState, prevBlob.UnsafeFromSome(), originalView, isOurCommit, + ourBalance, theirBalance, keyRing, + ) +} + +// auxLeavesFromRevocation is a helper function that attempts to fetch the aux +// leaves given a revoked state. +func auxLeavesFromRevocation(_ *channeldb.OpenChannel, + revocation *channeldb.RevocationLog, leafStore fn.Option[AuxLeafStore], + _ CommitmentKeyRing) (fn.Option[CommitAuxLeaves], error) { + + if leafStore.IsNone() { + return fn.None[CommitAuxLeaves](), nil + } + + return leafStore.UnsafeFromSome().FetchLeavesFromRevocation(revocation) +} + +// updateAuxBlob is a helper function that attempts to update the aux blob +// given the prior and current state information. +func updateAuxBlob(chanState *channeldb.OpenChannel, + prevBlob fn.Option[tlv.Blob], nextViewUnfiltered *HtlcView, + isOurCommit bool, ourBalance, theirBalance lnwire.MilliSatoshi, + leafStore fn.Option[AuxLeafStore], + keyRing CommitmentKeyRing) (fn.Option[tlv.Blob], error) { + + if leafStore.IsNone() { + return fn.None[tlv.Blob](), nil + } + + if prevBlob.IsNone() { + return fn.None[tlv.Blob](), nil + } + + return leafStore.UnsafeFromSome().ApplyHtlcView( + chanState, prevBlob.UnsafeFromSome(), nextViewUnfiltered, + isOurCommit, ourBalance, theirBalance, keyRing, + ) +} + // createUnsignedCommitmentTx generates the unsigned commitment transaction for // a commitment view and returns it as part of the unsignedCommitmentTx. The // passed in balances should be balances *before* subtracting any commitment // fees, but after anchor outputs. func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, theirBalance lnwire.MilliSatoshi, isOurs bool, - feePerKw chainfee.SatPerKWeight, height uint64, - filteredHTLCView *htlcView, - keyRing *CommitmentKeyRing) (*unsignedCommitmentTx, error) { + feePerKw chainfee.SatPerKWeight, height uint64, originalHtlcView, + filteredHTLCView *HtlcView, keyRing *CommitmentKeyRing, + prevCommit *commitment) (*unsignedCommitmentTx, error) { dustLimit := cb.chanState.LocalChanCfg.DustLimit if !isOurs { @@ -696,7 +849,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } numHTLCs := int64(0) - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( cb.chanState.ChanType, false, isOurs, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -707,7 +860,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, numHTLCs++ } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( cb.chanState.ChanType, true, isOurs, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -749,10 +902,18 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, theirBalance -= commitFeeMSat } - var ( - commitTx *wire.MsgTx - err error + var commitTx *wire.MsgTx + + // Before we create the commitment transaction below, we'll try to see + // if there're any aux leaves that need to be a part of the tapscript + // tree. We'll only do this if we have a custom blob defined though. + auxLeaves, customCommitSort, err := auxLeavesFromView( + cb.chanState, prevCommit.customBlob, originalHtlcView, + isOurs, ourBalance, theirBalance, cb.auxLeafStore, *keyRing, ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } // Depending on whether the transaction is ours or not, we call // CreateCommitTx with parameters matching the perspective, to generate @@ -768,6 +929,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, &cb.chanState.LocalChanCfg, &cb.chanState.RemoteChanCfg, ourBalance.ToSatoshis(), theirBalance.ToSatoshis(), numHTLCs, cb.chanState.IsInitiator, leaseExpiry, + auxLeaves, ) } else { commitTx, err = CreateCommitTx( @@ -775,12 +937,26 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, &cb.chanState.RemoteChanCfg, &cb.chanState.LocalChanCfg, theirBalance.ToSatoshis(), ourBalance.ToSatoshis(), numHTLCs, !cb.chanState.IsInitiator, leaseExpiry, + auxLeaves, ) } if err != nil { return nil, err } + // Similarly, we'll now attempt to extract the set of aux leaves for + // the set of incoming and outgoing HTLCs. + incomingAuxLeaves := fn.MapOption( + func(leaves CommitAuxLeaves) input.AuxTapLeaves { + return leaves.IncomingHtlcLeaves + }, + )(auxLeaves) + outgoingAuxLeaves := fn.MapOption( + func(leaves CommitAuxLeaves) input.AuxTapLeaves { + return leaves.OutgoingHtlcLeaves + }, + )(auxLeaves) + // We'll now add all the HTLC outputs to the commitment transaction. // Each output includes an off-chain 2-of-2 covenant clause, so we'll // need the objective local/remote keys for this particular commitment @@ -791,7 +967,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, // commitment outputs and should correspond to zero values for the // purposes of sorting. cltvs := make([]uint32, len(commitTx.TxOut)) - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( cb.chanState.ChanType, false, isOurs, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -800,16 +976,22 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } + auxLeaf := fn.MapOption( + func(leaves input.AuxTapLeaves) input.AuxTapLeaf { + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(outgoingAuxLeaves) + err := addHTLC( commitTx, isOurs, false, htlc, keyRing, - cb.chanState.ChanType, + cb.chanState.ChanType, fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err } cltvs = append(cltvs, htlc.Timeout) // nolint:makezero } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( cb.chanState.ChanType, true, isOurs, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -818,9 +1000,15 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } + auxLeaf := fn.MapOption( + func(leaves input.AuxTapLeaves) input.AuxTapLeaf { + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(incomingAuxLeaves) + err := addHTLC( commitTx, isOurs, true, htlc, keyRing, - cb.chanState.ChanType, + cb.chanState.ChanType, fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -837,9 +1025,24 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } // Sort the transactions according to the agreed upon canonical - // ordering. This lets us skip sending the entire transaction over, - // instead we'll just send signatures. - InPlaceCommitSort(commitTx, cltvs) + // ordering (which might be customized for custom channel types, but + // deterministic and both parties will arrive at the same result). This + // lets us skip sending the entire transaction over, instead we'll just + // send signatures. + if auxLeaves.IsSome() { + if customCommitSort == nil { + return nil, fmt.Errorf("custom channel type requires " + + "sorting function") + } + + err = customCommitSort(commitTx, cltvs) + if err != nil { + return nil, fmt.Errorf("unable to sort commitment "+ + "transaction by custom order: %w", err) + } + } else { + InPlaceCommitSort(commitTx, cltvs) + } // Next, we'll ensure that we don't accidentally create a commitment // transaction which would be invalid by consensus. @@ -881,24 +1084,33 @@ func CreateCommitTx(chanType channeldb.ChannelType, fundingOutput wire.TxIn, keyRing *CommitmentKeyRing, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, amountToLocal, amountToRemote btcutil.Amount, - numHTLCs int64, initiator bool, leaseExpiry uint32) (*wire.MsgTx, error) { + numHTLCs int64, initiator bool, leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*wire.MsgTx, error) { // First, we create the script for the delayed "pay-to-self" output. // This output has 2 main redemption clauses: either we can redeem the // output after a relative block delay, or the remote node can claim // the funds with the revocation key if we broadcast a revoked // commitment transaction. + localAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + })(auxLeaves) toLocalScript, err := CommitScriptToSelf( chanType, initiator, keyRing.ToLocalKey, keyRing.RevocationKey, uint32(localChanCfg.CsvDelay), leaseExpiry, + fn.FlattenOption(localAuxLeaf), ) if err != nil { return nil, err } // Next, we create the script paying to the remote. + remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + })(auxLeaves) toRemoteScript, _, err := CommitScriptToRemote( chanType, initiator, keyRing.ToRemoteKey, leaseExpiry, + fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return nil, err @@ -1076,8 +1288,8 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // genTaprootHtlcScript generates the HTLC scripts for a taproot+musig2 // channel. func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, - rHash [32]byte, - keyRing *CommitmentKeyRing) (*input.HtlcScriptTree, error) { + rHash [32]byte, keyRing *CommitmentKeyRing, + auxLeaf input.AuxTapLeaf) (*input.HtlcScriptTree, error) { var ( htlcScriptTree *input.HtlcScriptTree @@ -1094,7 +1306,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case isIncoming && ourCommit: htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], ourCommit, auxLeaf, ) // We're being paid via an HTLC by the remote party, and the HTLC is @@ -1103,7 +1315,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case isIncoming && !ourCommit: htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], ourCommit, auxLeaf, ) // We're sending an HTLC which is being added to our commitment @@ -1112,7 +1324,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case !isIncoming && ourCommit: htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], ourCommit, auxLeaf, ) // Finally, we're paying the remote party via an HTLC, which is being @@ -1121,7 +1333,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case !isIncoming && !ourCommit: htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], ourCommit, auxLeaf, ) } @@ -1136,7 +1348,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, // along side the multiplexer. func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing, -) (input.ScriptDescriptor, error) { + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { if !chanType.IsTaproot() { return genSegwitV0HtlcScript( @@ -1146,7 +1358,7 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, } return genTaprootHtlcScript( - isIncoming, ourCommit, timeout, rHash, keyRing, + isIncoming, ourCommit, timeout, rHash, keyRing, auxLeaf, ) } @@ -1159,13 +1371,15 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, // the descriptor itself. func addHTLC(commitTx *wire.MsgTx, ourCommit bool, isIncoming bool, paymentDesc *PaymentDescriptor, - keyRing *CommitmentKeyRing, chanType channeldb.ChannelType) error { + keyRing *CommitmentKeyRing, chanType channeldb.ChannelType, + auxLeaf input.AuxTapLeaf) error { timeout := paymentDesc.Timeout rHash := paymentDesc.RHash scriptInfo, err := genHtlcScript( chanType, isIncoming, ourCommit, timeout, rHash, keyRing, + auxLeaf, ) if err != nil { return err @@ -1198,7 +1412,8 @@ func addHTLC(commitTx *wire.MsgTx, ourCommit bool, // output scripts and compares them against the outputs inside the commitment // to find the match. func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, - chanState *channeldb.OpenChannel) (uint32, uint32, error) { + chanState *channeldb.OpenChannel, + leafStore fn.Option[AuxLeafStore]) (uint32, uint32, error) { // Init the output indexes as empty. ourIndex := uint32(channeldb.OutputIndexEmpty) @@ -1228,6 +1443,16 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, leaseExpiry = chanState.ThawHeight } + // If we have a custom blob, then we'll attempt to fetch the aux leaves + // for this state. + auxLeaves, err := AuxLeavesFromCommit( + chanState, chanCommit, leafStore, *keyRing, + ) + if err != nil { + return ourIndex, theirIndex, fmt.Errorf("unable to fetch aux "+ + "leaves: %w", err) + } + // Map the scripts from our PoV. When facing a local commitment, the to // local output belongs to us and the to remote output belongs to them. // When facing a remote commitment, the to local output belongs to them @@ -1235,9 +1460,13 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, // Compute the to local script. From our PoV, when facing a remote // commitment, the to local output belongs to them. + localAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + })(auxLeaves) theirScript, err := CommitScriptToSelf( chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey, keyRing.RevocationKey, theirDelay, leaseExpiry, + fn.FlattenOption(localAuxLeaf), ) if err != nil { return ourIndex, theirIndex, err @@ -1245,9 +1474,12 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, // Compute the to remote script. From our PoV, when facing a remote // commitment, the to remote output belongs to us. + remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + })(auxLeaves) ourScript, _, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return ourIndex, theirIndex, err diff --git a/lnwallet/config.go b/lnwallet/config.go index 7eeacb6ea23..24961f38edb 100644 --- a/lnwallet/config.go +++ b/lnwallet/config.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcwallet/wallet" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -62,4 +63,8 @@ type Config struct { // CoinSelectionStrategy is the strategy that is used for selecting // coins when funding a transaction. CoinSelectionStrategy wallet.CoinSelectionStrategy + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[AuxLeafStore] } diff --git a/lnwallet/transactions.go b/lnwallet/transactions.go index 1cf954d3cb5..da86650bc63 100644 --- a/lnwallet/transactions.go +++ b/lnwallet/transactions.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/input" ) const ( @@ -50,8 +51,8 @@ var ( // - func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, htlcOutput wire.OutPoint, htlcAmt btcutil.Amount, csvDelay, - leaseExpiry uint32, revocationKey, delayKey *btcec.PublicKey) ( - *wire.MsgTx, error) { + leaseExpiry uint32, revocationKey, delayKey *btcec.PublicKey, + auxLeaf input.AuxTapLeaf) (*wire.MsgTx, error) { // Create a version two transaction (as the success version of this // spends an output with a CSV timeout). @@ -71,7 +72,7 @@ func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, // HTLC outputs. scriptInfo, err := SecondLevelHtlcScript( chanType, initiator, revocationKey, delayKey, csvDelay, - leaseExpiry, + leaseExpiry, auxLeaf, ) if err != nil { return nil, err @@ -110,7 +111,8 @@ func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, func CreateHtlcTimeoutTx(chanType channeldb.ChannelType, initiator bool, htlcOutput wire.OutPoint, htlcAmt btcutil.Amount, cltvExpiry, csvDelay, leaseExpiry uint32, - revocationKey, delayKey *btcec.PublicKey) (*wire.MsgTx, error) { + revocationKey, delayKey *btcec.PublicKey, + auxLeaf input.AuxTapLeaf) (*wire.MsgTx, error) { // Create a version two transaction (as the success version of this // spends an output with a CSV timeout), and set the lock-time to the @@ -134,7 +136,7 @@ func CreateHtlcTimeoutTx(chanType channeldb.ChannelType, initiator bool, // HTLC outputs. scriptInfo, err := SecondLevelHtlcScript( chanType, initiator, revocationKey, delayKey, csvDelay, - leaseExpiry, + leaseExpiry, auxLeaf, ) if err != nil { return nil, err diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index ab0ef73283f..14c11b7f764 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -21,6 +21,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -631,6 +632,7 @@ func testSpendValidation(t *testing.T, tweakless bool) { commitmentTx, err := CreateCommitTx( channelType, *fakeFundingTxIn, keyRing, aliceChanCfg, bobChanCfg, channelBalance, channelBalance, 0, true, 0, + fn.None[CommitAuxLeaves](), ) if err != nil { t.Fatalf("unable to create commitment transaction: %v", nil) diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index 1e8a644db91..c579363542c 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -1458,6 +1458,21 @@ func (l *LightningWallet) handleFundingCancelRequest(req *fundingReserveCancelMs req.err <- nil } +// createCommitOpts is a struct that holds the options for creating a new +// commitment transaction. +type createCommitOpts struct { + auxLeaves fn.Option[CommitAuxLeaves] +} + +// defaultCommitOpts returns a new createCommitOpts with default values. +func defaultCommitOpts() createCommitOpts { + return createCommitOpts{} +} + +// CreateCommitOpt is a functional option that can be used to modify the way a +// new commitment transaction is created. +type CreateCommitOpt func(*createCommitOpts) + // CreateCommitmentTxns is a helper function that creates the initial // commitment transaction for both parties. This function is used during the // initial funding workflow as both sides must generate a signature for the @@ -1467,7 +1482,13 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, ourChanCfg, theirChanCfg *channeldb.ChannelConfig, localCommitPoint, remoteCommitPoint *btcec.PublicKey, fundingTxIn wire.TxIn, chanType channeldb.ChannelType, initiator bool, - leaseExpiry uint32) (*wire.MsgTx, *wire.MsgTx, error) { + leaseExpiry uint32, opts ...CreateCommitOpt) (*wire.MsgTx, *wire.MsgTx, + error) { + + options := defaultCommitOpts() + for _, optFunc := range opts { + optFunc(&options) + } localCommitmentKeys := DeriveCommitmentKeys( localCommitPoint, true, chanType, ourChanCfg, theirChanCfg, @@ -1479,7 +1500,7 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, ourCommitTx, err := CreateCommitTx( chanType, fundingTxIn, localCommitmentKeys, ourChanCfg, theirChanCfg, localBalance, remoteBalance, 0, initiator, - leaseExpiry, + leaseExpiry, options.auxLeaves, ) if err != nil { return nil, nil, err @@ -1493,7 +1514,7 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, theirCommitTx, err := CreateCommitTx( chanType, fundingTxIn, remoteCommitmentKeys, theirChanCfg, ourChanCfg, remoteBalance, localBalance, 0, !initiator, - leaseExpiry, + leaseExpiry, options.auxLeaves, ) if err != nil { return nil, nil, err @@ -2484,9 +2505,16 @@ func TapscriptRootToTweak(root chainhash.Hash) input.MuSig2Tweaks { func (l *LightningWallet) ValidateChannel(channelState *channeldb.OpenChannel, fundingTx *wire.MsgTx) error { + var chanOpts []ChannelOpt + l.Cfg.AuxLeafStore.WhenSome(func(s AuxLeafStore) { + chanOpts = append(chanOpts, WithLeafStore(s)) + }) + // First, we'll obtain a fully signed commitment transaction so we can // pass into it on the chanvalidate package for verification. - channel, err := NewLightningChannel(l.Cfg.Signer, channelState, nil) + channel, err := NewLightningChannel( + l.Cfg.Signer, channelState, nil, chanOpts..., + ) if err != nil { return err } diff --git a/peer/brontide.go b/peer/brontide.go index 984a5280987..b9a9f68ca15 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -359,6 +359,10 @@ type Config struct { AddLocalAlias func(alias, base lnwire.ShortChannelID, gossip bool) error + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + // PongBuf is a slice we'll reuse instead of allocating memory on the // heap. Since only reads will occur and no writes, there is no need // for any synchronization primitives. As a result, it's safe to share @@ -869,8 +873,12 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( } } + var chanOpts []lnwallet.ChannelOpt + p.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) lnChan, err := lnwallet.NewLightningChannel( - p.cfg.Signer, dbChan, p.cfg.SigPool, + p.cfg.Signer, dbChan, p.cfg.SigPool, chanOpts..., ) if err != nil { return nil, err @@ -3977,6 +3985,10 @@ func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error { chanOpts = append(chanOpts, lnwallet.WithSkipNonceInit()) } + p.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // If not already active, we'll add this channel to the set of active // channels, so we can look it up later easily according to its channel // ID. diff --git a/server.go b/server.go index e2c0b831a21..650cdbced40 100644 --- a/server.go +++ b/server.go @@ -157,6 +157,8 @@ type server struct { cfg *Config + implCfg *ImplementationCfg + // identityECDH is an ECDH capable wrapper for the private key used // to authenticate any incoming connections. identityECDH keychain.SingleKeyECDH @@ -480,8 +482,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, nodeKeyDesc *keychain.KeyDescriptor, chansToRestore walletunlocker.ChannelsToRecover, chanPredicate chanacceptor.ChannelAcceptor, - torController *tor.Controller, tlsManager *TLSManager) (*server, - error) { + torController *tor.Controller, tlsManager *TLSManager, + implCfg *ImplementationCfg) (*server, error) { var ( err error @@ -567,6 +569,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s := &server{ cfg: cfg, + implCfg: implCfg, graphDB: dbs.GraphDB.ChannelGraph(), chanStateDB: dbs.ChanStateDB.ChannelStateDB(), addrSource: dbs.ChanStateDB, @@ -1245,6 +1248,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return &pc.Incoming }, + AuxLeafStore: implCfg.AuxLeafStore, }, dbs.ChanStateDB) // Select the configuration and funding parameters for Bitcoin. @@ -1578,6 +1582,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, br, err := lnwallet.NewBreachRetribution( channel, commitHeight, 0, nil, + implCfg.AuxLeafStore, ) if err != nil { return nil, 0, err @@ -3906,6 +3911,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, AddLocalAlias: s.aliasMgr.AddLocalAlias, DisallowRouteBlinding: s.cfg.ProtocolOptions.NoRouteBlinding(), Quit: s.quit, + AuxLeafStore: s.implCfg.AuxLeafStore, } copy(pCfg.PubKeyBytes[:], peerAddr.IdentityKey.SerializeCompressed()) diff --git a/watchtower/blob/justice_kit.go b/watchtower/blob/justice_kit.go index 8b6c20194f1..7780239f07a 100644 --- a/watchtower/blob/justice_kit.go +++ b/watchtower/blob/justice_kit.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -307,9 +308,11 @@ func newTaprootJusticeKit(sweepScript []byte, keyRing := breachInfo.KeyRing + // TODO(roasbeef): aux leaf tower updates needed + tree, err := input.NewLocalCommitScriptTree( breachInfo.RemoteDelay, keyRing.ToLocalKey, - keyRing.RevocationKey, + keyRing.RevocationKey, fn.None[txscript.TapLeaf](), ) if err != nil { return nil, err @@ -416,7 +419,9 @@ func (t *taprootJusticeKit) ToRemoteOutputSpendInfo() (*txscript.PkScript, return nil, nil, 0, err } - scriptTree, err := input.NewRemoteCommitScriptTree(toRemotePk) + scriptTree, err := input.NewRemoteCommitScriptTree( + toRemotePk, fn.None[txscript.TapLeaf](), + ) if err != nil { return nil, nil, 0, err } diff --git a/watchtower/blob/justice_kit_test.go b/watchtower/blob/justice_kit_test.go index fd12993a0a7..a1d6ec9f2c4 100644 --- a/watchtower/blob/justice_kit_test.go +++ b/watchtower/blob/justice_kit_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -304,7 +305,9 @@ func TestJusticeKitRemoteWitnessConstruction(t *testing.T) { name: "taproot commitment", blobType: TypeAltruistTaprootCommit, expWitnessScript: func(pk *btcec.PublicKey) []byte { - tree, _ := input.NewRemoteCommitScriptTree(pk) + tree, _ := input.NewRemoteCommitScriptTree( + pk, fn.None[txscript.TapLeaf](), + ) return tree.SettleLeaf.Script }, @@ -461,6 +464,7 @@ func TestJusticeKitToLocalWitnessConstruction(t *testing.T) { script, _ := input.NewLocalCommitScriptTree( csvDelay, delay, rev, + fn.None[txscript.TapLeaf](), ) return script.RevocationLeaf.Script diff --git a/watchtower/lookout/justice_descriptor_test.go b/watchtower/lookout/justice_descriptor_test.go index 2ca187fdc78..29c7f9a5a11 100644 --- a/watchtower/lookout/justice_descriptor_test.go +++ b/watchtower/lookout/justice_descriptor_test.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -123,7 +124,7 @@ func testJusticeDescriptor(t *testing.T, blobType blob.Type) { if isTaprootChannel { toLocalCommitTree, err = input.NewLocalCommitScriptTree( - csvDelay, toLocalPK, revPK, + csvDelay, toLocalPK, revPK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err) @@ -174,7 +175,7 @@ func testJusticeDescriptor(t *testing.T, blobType blob.Type) { toRemoteSequence = 1 commitScriptTree, err := input.NewRemoteCommitScriptTree( - toRemotePK, + toRemotePK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err) diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 695c4f9ecd7..7894631b8ff 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -136,6 +137,7 @@ func genTaskTest( if chanType.IsTaproot() { scriptTree, _ := input.NewLocalCommitScriptTree( csvDelay, toLocalPK, revPK, + fn.None[txscript.TapLeaf](), ) pkScript, _ := input.PayToTaprootScript( @@ -189,7 +191,7 @@ func genTaskTest( if chanType.IsTaproot() { scriptTree, _ := input.NewRemoteCommitScriptTree( - toRemotePK, + toRemotePK, fn.None[txscript.TapLeaf](), ) pkScript, _ := input.PayToTaprootScript( diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 38d9acd9f07..f3a4d5bf4ee 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -230,12 +231,14 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) { // Construct the to-local witness script. toLocalScriptTree, err := input.NewLocalCommitScriptTree( - c.csvDelay, c.toLocalPK, c.revPK, + c.csvDelay, c.toLocalPK, c.revPK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err, "unable to create to-local script") // Construct the to-remote witness script. - toRemoteScriptTree, err := input.NewRemoteCommitScriptTree(c.toRemotePK) + toRemoteScriptTree, err := input.NewRemoteCommitScriptTree( + c.toRemotePK, fn.None[txscript.TapLeaf](), + ) require.NoError(t, err, "unable to create to-remote script") // Compute the to-local witness script hash.