From 6cf449168ce19d9edbf37936c296b0d094d90ed5 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 15 Mar 2024 11:43:21 -0400 Subject: [PATCH 01/14] input+lnwallet: update taproot scripts to accept optional aux leaf In this commit, we update all the taproot scripts to also accept an optional aux leaf. This aux leaf can be used to add more redemption paths for advanced channels, or just as an extra commitment space. --- input/script_utils.go | 115 ++++++---- input/size_test.go | 20 +- input/taproot.go | 10 + input/taproot_test.go | 197 +++++++++++++----- lnwallet/channel.go | 2 + lnwallet/commitment.go | 19 +- watchtower/blob/justice_kit.go | 9 +- watchtower/blob/justice_kit_test.go | 6 +- watchtower/lookout/justice_descriptor_test.go | 5 +- .../wtclient/backup_task_internal_test.go | 4 +- watchtower/wtclient/client_test.go | 7 +- 11 files changed, 284 insertions(+), 110 deletions(-) diff --git a/input/script_utils.go b/input/script_utils.go index 9e639bca849..5ad0a0e90d9 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -673,6 +673,13 @@ type HtlcScriptTree struct { // TimeoutTapLeaf is the tapleaf for the timeout path. TimeoutTapLeaf txscript.TapLeaf + // AuxLeaf is an auxiliary leaf that can be used to extend the base + // HTLC script tree with new spend paths, or just as extra commitment + // space. When present, this leaf will always be in the left-most or + // right-most area of the tapscript tree. + AuxLeaf AuxTapLeaf + + // htlcType is the type of HTLC script this is. htlcType htlcType } @@ -748,8 +755,8 @@ var _ TapscriptDescriptor = (*HtlcScriptTree)(nil) // senderHtlcTapScriptTree builds the tapscript tree which is used to anchor // the HTLC key for HTLCs on the sender's commitment. func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, - revokeKey *btcec.PublicKey, payHash []byte, - hType htlcType) (*HtlcScriptTree, error) { + revokeKey *btcec.PublicKey, payHash []byte, hType htlcType, + auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { // First, we'll obtain the tap leaves for both the success and timeout // path. @@ -766,11 +773,14 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, return nil, err } + tapLeaves := []txscript.TapLeaf{successTapLeaf, timeoutTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With the two leaves obtained, we'll now make the tapscript tree, // then obtain the root from that - tapscriptTree := txscript.AssembleTaprootScriptTree( - successTapLeaf, timeoutTapLeaf, - ) + tapscriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapscriptTree.RootNode.TapHash() @@ -789,6 +799,7 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, }, SuccessTapLeaf: successTapLeaf, TimeoutTapLeaf: timeoutTapLeaf, + AuxLeaf: auxLeaf, htlcType: hType, }, nil } @@ -822,8 +833,8 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // The top level keyspend key is the revocation key, which allows a defender to // unilaterally spend the created output. func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, - revokeKey *btcec.PublicKey, payHash []byte, - localCommit bool) (*HtlcScriptTree, error) { + revokeKey *btcec.PublicKey, payHash []byte, localCommit bool, + auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { var hType htlcType if localCommit { @@ -836,8 +847,8 @@ func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, // tree that includes the top level output script, as well as the two // tap leaf paths. return senderHtlcTapScriptTree( - senderHtlcKey, receiverHtlcKey, revokeKey, payHash, - hType, + senderHtlcKey, receiverHtlcKey, revokeKey, payHash, hType, + auxLeaf, ) } @@ -1307,8 +1318,8 @@ func ReceiverHtlcTapLeafSuccess(receiverHtlcKey *btcec.PublicKey, // receiverHtlcTapScriptTree builds the tapscript tree which is used to anchor // the HTLC key for HTLCs on the receiver's commitment. func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, - revokeKey *btcec.PublicKey, payHash []byte, - cltvExpiry uint32, hType htlcType) (*HtlcScriptTree, error) { + revokeKey *btcec.PublicKey, payHash []byte, cltvExpiry uint32, + hType htlcType, auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) { // First, we'll obtain the tap leaves for both the success and timeout // path. @@ -1325,11 +1336,14 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, return nil, err } + tapLeaves := []txscript.TapLeaf{timeoutTapLeaf, successTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With the two leaves obtained, we'll now make the tapscript tree, // then obtain the root from that - tapscriptTree := txscript.AssembleTaprootScriptTree( - timeoutTapLeaf, successTapLeaf, - ) + tapscriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapscriptTree.RootNode.TapHash() @@ -1348,6 +1362,7 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, }, SuccessTapLeaf: successTapLeaf, TimeoutTapLeaf: timeoutTapLeaf, + AuxLeaf: auxLeaf, htlcType: hType, }, nil } @@ -1382,7 +1397,8 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // the tap leaf are returned. func ReceiverHTLCScriptTaproot(cltvExpiry uint32, senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey, - payHash []byte, ourCommit bool) (*HtlcScriptTree, error) { + payHash []byte, ourCommit bool, auxLeaf AuxTapLeaf) (*HtlcScriptTree, + error) { var hType htlcType if ourCommit { @@ -1396,7 +1412,7 @@ func ReceiverHTLCScriptTaproot(cltvExpiry uint32, // tap leaf paths. return receiverHtlcTapScriptTree( senderHtlcKey, receiverHtlcKey, revocationKey, payHash, - cltvExpiry, hType, + cltvExpiry, hType, auxLeaf, ) } @@ -1625,9 +1641,9 @@ func TaprootSecondLevelTapLeaf(delayKey *btcec.PublicKey, } // SecondLevelHtlcTapscriptTree construct the indexed tapscript tree needed to -// generate the taptweak to create the final output and also control block. -func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, - csvDelay uint32) (*txscript.IndexedTapScriptTree, error) { +// generate the tap tweak to create the final output and also control block. +func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, csvDelay uint32, + auxLeaf AuxTapLeaf) (*txscript.IndexedTapScriptTree, error) { // First grab the second level leaf script we need to create the top // level output. @@ -1636,9 +1652,14 @@ func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, return nil, err } + tapLeaves := []txscript.TapLeaf{secondLevelTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // Now that we have the sole second level script, we can create the // tapscript tree that commits to both the leaves. - return txscript.AssembleTaprootScriptTree(secondLevelTapLeaf), nil + return txscript.AssembleTaprootScriptTree(tapLeaves...), nil } // TaprootSecondLevelHtlcScript is the uniform script that's used as the output @@ -1658,12 +1679,12 @@ func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, // // The keyspend path require knowledge of the top level revocation private key. func TaprootSecondLevelHtlcScript(revokeKey, delayKey *btcec.PublicKey, - csvDelay uint32) (*btcec.PublicKey, error) { + csvDelay uint32, auxLeaf AuxTapLeaf) (*btcec.PublicKey, error) { // First, we'll make the tapscript tree that commits to the redemption // path. tapScriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey, csvDelay, + delayKey, csvDelay, auxLeaf, ) if err != nil { return nil, err @@ -1688,17 +1709,21 @@ type SecondLevelScriptTree struct { // SuccessTapLeaf is the tapleaf for the redemption path. SuccessTapLeaf txscript.TapLeaf + + // AuxLeaf is an optional leaf that can be used to extend the script + // tree. + AuxLeaf AuxTapLeaf } // TaprootSecondLevelScriptTree constructs the tapscript tree used to spend the // second level HTLC output. func TaprootSecondLevelScriptTree(revokeKey, delayKey *btcec.PublicKey, - csvDelay uint32) (*SecondLevelScriptTree, error) { + csvDelay uint32, auxLeaf AuxTapLeaf) (*SecondLevelScriptTree, error) { // First, we'll make the tapscript tree that commits to the redemption // path. tapScriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey, csvDelay, + delayKey, csvDelay, auxLeaf, ) if err != nil { return nil, err @@ -1719,6 +1744,7 @@ func TaprootSecondLevelScriptTree(revokeKey, delayKey *btcec.PublicKey, InternalKey: revokeKey, }, SuccessTapLeaf: tapScriptTree.LeafMerkleProofs[0].TapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2095,6 +2121,12 @@ type CommitScriptTree struct { // RevocationLeaf is the leaf used to spend the output with the // revocation key signature. RevocationLeaf txscript.TapLeaf + + // AuxLeaf is an auxiliary leaf that can be used to extend the base + // commitment script tree with new spend paths, or just as extra + // commitment space. When present, this leaf will always be in the + // left-most or right-most area of the tapscript tree. + AuxLeaf AuxTapLeaf } // A compile time check to ensure CommitScriptTree implements the @@ -2154,8 +2186,9 @@ func (c *CommitScriptTree) CtrlBlockForPath(path ScriptPath, // NewLocalCommitScriptTree returns a new CommitScript tree that can be used to // create and spend the commitment output for the local party. -func NewLocalCommitScriptTree(csvTimeout uint32, - selfKey, revokeKey *btcec.PublicKey) (*CommitScriptTree, error) { +func NewLocalCommitScriptTree(csvTimeout uint32, selfKey, + revokeKey *btcec.PublicKey, auxLeaf AuxTapLeaf) (*CommitScriptTree, + error) { // First, we'll need to construct the tapLeaf that'll be our delay CSV // clause. @@ -2175,9 +2208,13 @@ func NewLocalCommitScriptTree(csvTimeout uint32, // the two leaves, and then obtain a root from that. delayTapLeaf := txscript.NewBaseTapLeaf(delayScript) revokeTapLeaf := txscript.NewBaseTapLeaf(revokeScript) - tapScriptTree := txscript.AssembleTaprootScriptTree( - delayTapLeaf, revokeTapLeaf, - ) + + tapLeaves := []txscript.TapLeaf{delayTapLeaf, revokeTapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + + tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapScriptTree.RootNode.TapHash() // Now that we have our root, we can arrive at the final output script @@ -2195,6 +2232,7 @@ func NewLocalCommitScriptTree(csvTimeout uint32, }, SettleLeaf: delayTapLeaf, RevocationLeaf: revokeTapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2264,7 +2302,7 @@ func TaprootCommitScriptToSelf(csvTimeout uint32, selfKey, revokeKey *btcec.PublicKey) (*btcec.PublicKey, error) { commitScriptTree, err := NewLocalCommitScriptTree( - csvTimeout, selfKey, revokeKey, + csvTimeout, selfKey, revokeKey, NoneTapLeaf(), ) if err != nil { return nil, err @@ -2593,7 +2631,7 @@ func CommitScriptToRemoteConfirmed(key *btcec.PublicKey) ([]byte, error) { // NewRemoteCommitScriptTree constructs a new script tree for the remote party // to sweep their funds after a hard coded 1 block delay. func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, -) (*CommitScriptTree, error) { + auxLeaf AuxTapLeaf) (*CommitScriptTree, error) { // First, construct the remote party's tapscript they'll use to sweep // their outputs. @@ -2609,10 +2647,16 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, return nil, err } + tapLeaf := txscript.NewBaseTapLeaf(remoteScript) + + tapLeaves := []txscript.TapLeaf{tapLeaf} + auxLeaf.WhenSome(func(l txscript.TapLeaf) { + tapLeaves = append(tapLeaves, l) + }) + // With this script constructed, we'll map that into a tapLeaf, then // make a new tapscript root from that. - tapLeaf := txscript.NewBaseTapLeaf(remoteScript) - tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaf) + tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...) tapScriptRoot := tapScriptTree.RootNode.TapHash() // Now that we have our root, we can arrive at the final output script @@ -2629,6 +2673,7 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, InternalKey: &TaprootNUMSKey, }, SettleLeaf: tapLeaf, + AuxLeaf: auxLeaf, }, nil } @@ -2645,9 +2690,9 @@ func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey, // OP_CHECKSIG // 1 OP_CHECKSEQUENCEVERIFY OP_DROP func TaprootCommitScriptToRemote(remoteKey *btcec.PublicKey, -) (*btcec.PublicKey, error) { + auxLeaf AuxTapLeaf) (*btcec.PublicKey, error) { - commitScriptTree, err := NewRemoteCommitScriptTree(remoteKey) + commitScriptTree, err := NewRemoteCommitScriptTree(remoteKey, auxLeaf) if err != nil { return nil, err } diff --git a/input/size_test.go b/input/size_test.go index 1f447cf89fa..c4e6fae377a 100644 --- a/input/size_test.go +++ b/input/size_test.go @@ -851,7 +851,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} commitScriptTree, err := input.NewLocalCommitScriptTree( testCSVDelay, testKey.PubKey(), - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -885,7 +885,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} commitScriptTree, err := input.NewLocalCommitScriptTree( testCSVDelay, testKey.PubKey(), - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -919,7 +919,7 @@ var witnessSizeTests = []witnessSizeTest{ signer := &dummySigner{} //nolint:lll commitScriptTree, err := input.NewRemoteCommitScriptTree( - testKey.PubKey(), + testKey.PubKey(), input.NoneTapLeaf(), ) require.NoError(t, err) @@ -986,6 +986,7 @@ var witnessSizeTests = []witnessSizeTest{ scriptTree, err := input.SecondLevelHtlcTapscriptTree( testKey.PubKey(), testCSVDelay, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1025,6 +1026,7 @@ var witnessSizeTests = []witnessSizeTest{ scriptTree, err := input.SecondLevelHtlcTapscriptTree( testKey.PubKey(), testCSVDelay, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1073,6 +1075,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], false, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1114,7 +1117,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], false, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1156,7 +1159,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], false, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1203,6 +1206,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], false, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1263,6 +1267,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), payHash[:], false, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1308,7 +1313,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], false, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1394,6 +1399,7 @@ func genTimeoutTx(t *testing.T, if chanType.IsTaproot() { tapscriptTree, err = input.SenderHTLCScriptTaproot( testPubkey, testPubkey, testPubkey, testHash160, false, + input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1462,7 +1468,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { if chanType.IsTaproot() { tapscriptTree, err = input.ReceiverHTLCScriptTaproot( testCLTVExpiry, testPubkey, testPubkey, testPubkey, - testHash160, false, + testHash160, false, input.NoneTapLeaf(), ) require.NoError(t, err) diff --git a/input/taproot.go b/input/taproot.go index 34cdb974d5c..0fcd1df4e2a 100644 --- a/input/taproot.go +++ b/input/taproot.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/waddrmgr" + "github.com/lightningnetwork/lnd/fn" ) const ( @@ -21,6 +22,15 @@ const ( PubKeyFormatCompressedOdd byte = 0x03 ) +// AuxTapLeaf is a type alias for an optional tapscript leaf that may be added +// to the tapscript tree of HTLC and commitment outputs. +type AuxTapLeaf = fn.Option[txscript.TapLeaf] + +// NoneTapLeaf returns an empty optional tapscript leaf. +func NoneTapLeaf() AuxTapLeaf { + return fn.None[txscript.TapLeaf]() +} + // NewTxSigHashesV0Only returns a new txscript.TxSigHashes instance that will // only calculate the sighash midstate values for segwit v0 inputs and can // therefore never be used for transactions that want to spend segwit v1 diff --git a/input/taproot_test.go b/input/taproot_test.go index 801b0fef4d5..ad4f16c4be6 100644 --- a/input/taproot_test.go +++ b/input/taproot_test.go @@ -1,13 +1,16 @@ package input import ( + "bytes" "crypto/rand" + "fmt" "testing" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" @@ -31,7 +34,9 @@ type testSenderHtlcScriptTree struct { htlcAmt int64 } -func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { +func newTestSenderHtlcScriptTree(t *testing.T, + auxLeaf AuxTapLeaf) *testSenderHtlcScriptTree { + var preImage lntypes.Preimage _, err := rand.Read(preImage[:]) require.NoError(t, err) @@ -48,7 +53,7 @@ func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], false, auxLeaf, ) require.NoError(t, err) @@ -207,13 +212,9 @@ func htlcSenderTimeoutWitnessGen(sigHash txscript.SigHashType, } } -// TestTaprootSenderHtlcSpend tests that all the positive and negative paths -// for the sender HTLC tapscript tree work as expected. -func TestTaprootSenderHtlcSpend(t *testing.T) { - t.Parallel() - +func testTaprootSenderHtlcSpend(t *testing.T, auxLeaf AuxTapLeaf) { // First, create a new test script tree. - htlcScriptTree := newTestSenderHtlcScriptTree(t) + htlcScriptTree := newTestSenderHtlcScriptTree(t, auxLeaf) spendTx := wire.NewMsgTx(2) spendTx.AddTxIn(&wire.TxIn{}) @@ -432,6 +433,28 @@ func TestTaprootSenderHtlcSpend(t *testing.T) { } } +// TestTaprootSenderHtlcSpend tests that all the positive and negative paths +// for the sender HTLC tapscript tree work as expected. +func TestTaprootSenderHtlcSpend(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootSenderHtlcSpend(t, auxLeaf) + }) + } +} + type testReceiverHtlcScriptTree struct { preImage lntypes.Preimage @@ -452,7 +475,9 @@ type testReceiverHtlcScriptTree struct { lockTime int32 } -func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { +func newTestReceiverHtlcScriptTree(t *testing.T, + auxLeaf AuxTapLeaf) *testReceiverHtlcScriptTree { + var preImage lntypes.Preimage _, err := rand.Read(preImage[:]) require.NoError(t, err) @@ -471,7 +496,7 @@ func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := ReceiverHTLCScriptTaproot( cltvExpiry, senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], false, auxLeaf, ) require.NoError(t, err) @@ -629,15 +654,11 @@ func htlcReceiverSuccessWitnessGen(sigHash txscript.SigHashType, } } -// TestTaprootReceiverHtlcSpend tests that all possible paths for redeeming an -// accepted HTLC (on the commitment transaction) of the receiver work properly. -func TestTaprootReceiverHtlcSpend(t *testing.T) { - t.Parallel() - +func testTaprootReceiverHtlcSpend(t *testing.T, auxLeaf AuxTapLeaf) { // We'll start by creating the HTLC script tree (contains all 3 valid // spend paths), and also a mock spend transaction that we'll be // signing below. - htlcScriptTree := newTestReceiverHtlcScriptTree(t) + htlcScriptTree := newTestReceiverHtlcScriptTree(t, auxLeaf) // TODO(roasbeef): issue with revoke key??? ctrl block even/odd @@ -891,6 +912,28 @@ func TestTaprootReceiverHtlcSpend(t *testing.T) { } } +// TestTaprootReceiverHtlcSpend tests that all possible paths for redeeming an +// accepted HTLC (on the commitment transaction) of the receiver work properly. +func TestTaprootReceiverHtlcSpend(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootReceiverHtlcSpend(t, auxLeaf) + }) + } +} + type testCommitScriptTree struct { csvDelay uint32 @@ -905,7 +948,9 @@ type testCommitScriptTree struct { *CommitScriptTree } -func newTestCommitScriptTree(local bool) (*testCommitScriptTree, error) { +func newTestCommitScriptTree(local bool, + auxLeaf AuxTapLeaf) (*testCommitScriptTree, error) { + selfKey, err := btcec.NewPrivateKey() if err != nil { return nil, err @@ -925,10 +970,11 @@ func newTestCommitScriptTree(local bool) (*testCommitScriptTree, error) { if local { commitScriptTree, err = NewLocalCommitScriptTree( csvDelay, selfKey.PubKey(), revokeKey.PubKey(), + auxLeaf, ) } else { commitScriptTree, err = NewRemoteCommitScriptTree( - selfKey.PubKey(), + selfKey.PubKey(), auxLeaf, ) } if err != nil { @@ -1020,12 +1066,8 @@ func localCommitRevokeWitGen(sigHash txscript.SigHashType, } } -// TestTaprootCommitScriptToSelf tests that the taproot script for redeeming -// one's output after a force close behaves as expected. -func TestTaprootCommitScriptToSelf(t *testing.T) { - t.Parallel() - - commitScriptTree, err := newTestCommitScriptTree(true) +func testTaprootCommitScriptToSelf(t *testing.T, auxLeaf AuxTapLeaf) { + commitScriptTree, err := newTestCommitScriptTree(true, auxLeaf) require.NoError(t, err) spendTx := wire.NewMsgTx(2) @@ -1187,6 +1229,28 @@ func TestTaprootCommitScriptToSelf(t *testing.T) { } } +// TestTaprootCommitScriptToSelf tests that the taproot script for redeeming +// one's output after a force close behaves as expected. +func TestTaprootCommitScriptToSelf(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootCommitScriptToSelf(t, auxLeaf) + }) + } +} + func remoteCommitSweepWitGen(sigHash txscript.SigHashType, commitScriptTree *testCommitScriptTree) witnessGen { @@ -1220,12 +1284,8 @@ func remoteCommitSweepWitGen(sigHash txscript.SigHashType, } } -// TestTaprootCommitScriptRemote tests that the remote party can properly sweep -// their output after force close. -func TestTaprootCommitScriptRemote(t *testing.T) { - t.Parallel() - - commitScriptTree, err := newTestCommitScriptTree(false) +func testTaprootCommitScriptRemote(t *testing.T, auxLeaf AuxTapLeaf) { + commitScriptTree, err := newTestCommitScriptTree(false, auxLeaf) require.NoError(t, err) spendTx := wire.NewMsgTx(2) @@ -1364,6 +1424,28 @@ func TestTaprootCommitScriptRemote(t *testing.T) { } } +// TestTaprootCommitScriptRemote tests that the remote party can properly sweep +// their output after force close. +func TestTaprootCommitScriptRemote(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootCommitScriptRemote(t, auxLeaf) + }) + } +} + type testAnchorScriptTree struct { sweepKey *btcec.PrivateKey @@ -1599,25 +1681,21 @@ type testSecondLevelHtlcTree struct { tapScriptRoot []byte } -func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { +func newTestSecondLevelHtlcTree(t *testing.T, + auxLeaf AuxTapLeaf) *testSecondLevelHtlcTree { + delayKey, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } + require.NoError(t, err) revokeKey, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } + require.NoError(t, err) const csvDelay = 6 scriptTree, err := SecondLevelHtlcTapscriptTree( - delayKey.PubKey(), csvDelay, + delayKey.PubKey(), csvDelay, auxLeaf, ) - if err != nil { - return nil, err - } + require.NoError(t, err) tapScriptRoot := scriptTree.RootNode.TapHash() @@ -1626,9 +1704,7 @@ func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { ) pkScript, err := PayToTaprootScript(htlcKey) - if err != nil { - return nil, err - } + require.NoError(t, err) const amt = 100 @@ -1643,7 +1719,7 @@ func newTestSecondLevelHtlcTree() (*testSecondLevelHtlcTree, error) { amt: amt, scriptTree: scriptTree, tapScriptRoot: tapScriptRoot[:], - }, nil + } } func secondLevelHtlcSuccessWitGen(sigHash txscript.SigHashType, @@ -1713,13 +1789,8 @@ func secondLevelHtlcRevokeWitnessgen(sigHash txscript.SigHashType, } } -// TestTaprootSecondLevelHtlcScript tests that a channel peer can properly -// spend the second level HTLC script to resolve HTLCs. -func TestTaprootSecondLevelHtlcScript(t *testing.T) { - t.Parallel() - - htlcScriptTree, err := newTestSecondLevelHtlcTree() - require.NoError(t, err) +func testTaprootSecondLevelHtlcScript(t *testing.T, auxLeaf AuxTapLeaf) { + htlcScriptTree := newTestSecondLevelHtlcTree(t, auxLeaf) spendTx := wire.NewMsgTx(2) spendTx.AddTxIn(&wire.TxIn{}) @@ -1879,3 +1950,25 @@ func TestTaprootSecondLevelHtlcScript(t *testing.T) { }) } } + +// TestTaprootSecondLevelHtlcScript tests that a channel peer can properly +// spend the second level HTLC script to resolve HTLCs. +func TestTaprootSecondLevelHtlcScript(t *testing.T) { + t.Parallel() + + for _, hasAuxLeaf := range []bool{true, false} { + name := fmt.Sprintf("aux_leaf=%v", hasAuxLeaf) + t.Run(name, func(t *testing.T) { + var auxLeaf AuxTapLeaf + if hasAuxLeaf { + auxLeaf = fn.Some( + txscript.NewBaseTapLeaf( + bytes.Repeat([]byte{0x01}, 32), + ), + ) + } + + testTaprootSecondLevelHtlcScript(t, auxLeaf) + }) + } +} diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 67bf6506c26..5240d1bacc4 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -7119,6 +7119,7 @@ func newOutgoingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, + fn.None[txscript.TapLeaf](), ) if err != nil { return nil, err @@ -7364,6 +7365,7 @@ func newIncomingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, + fn.None[txscript.TapLeaf](), ) if err != nil { return nil, err diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 6ecef795b14..564fc7fa2fd 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -224,9 +224,8 @@ func (w *WitnessScriptDesc) WitnessScriptForPath(_ input.ScriptPath, // party learns of the preimage to the revocation hash, then they can claim all // the settled funds in the channel, plus the unsettled funds. func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, - selfKey, revokeKey *btcec.PublicKey, csvDelay, leaseExpiry uint32, -) ( - input.ScriptDescriptor, error) { + selfKey, revokeKey *btcec.PublicKey, csvDelay, + leaseExpiry uint32) (input.ScriptDescriptor, error) { switch { // For taproot scripts, we'll need to make a slightly modified script @@ -236,7 +235,7 @@ func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, // Our "redeem" script here is just the taproot witness program. case chanType.IsTaproot(): return input.NewLocalCommitScriptTree( - csvDelay, selfKey, revokeKey, + csvDelay, selfKey, revokeKey, input.NoneTapLeaf(), ) // If we are the initiator of a leased channel, then we have an @@ -320,7 +319,7 @@ func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool, // with the sole tap leaf enforcing the 1 CSV delay. case chanType.IsTaproot(): toRemoteScriptTree, err := input.NewRemoteCommitScriptTree( - remoteKey, + remoteKey, input.NoneTapLeaf(), ) if err != nil { return nil, 0, err @@ -426,7 +425,7 @@ func SecondLevelHtlcScript(chanType channeldb.ChannelType, initiator bool, // For taproot channels, the pkScript is a segwit v1 p2tr output. case chanType.IsTaproot(): return input.TaprootSecondLevelScriptTree( - revocationKey, delayKey, csvDelay, + revocationKey, delayKey, csvDelay, input.NoneTapLeaf(), ) // If we are the initiator of a leased channel, then we have an @@ -1095,6 +1094,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RevocationKey, rHash[:], ourCommit, + input.NoneTapLeaf(), ) // We're being paid via an HTLC by the remote party, and the HTLC is @@ -1104,6 +1104,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RevocationKey, rHash[:], ourCommit, + input.NoneTapLeaf(), ) // We're sending an HTLC which is being added to our commitment @@ -1113,6 +1114,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.RevocationKey, rHash[:], ourCommit, + input.NoneTapLeaf(), ) // Finally, we're paying the remote party via an HTLC, which is being @@ -1122,6 +1124,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.RevocationKey, rHash[:], ourCommit, + input.NoneTapLeaf(), ) } @@ -1135,8 +1138,8 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, // we need to sign for the remote party (2nd level HTLCs) is also returned // along side the multiplexer. func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, - timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing, -) (input.ScriptDescriptor, error) { + timeout uint32, rHash [32]byte, + keyRing *CommitmentKeyRing) (input.ScriptDescriptor, error) { if !chanType.IsTaproot() { return genSegwitV0HtlcScript( diff --git a/watchtower/blob/justice_kit.go b/watchtower/blob/justice_kit.go index 8b6c20194f1..7780239f07a 100644 --- a/watchtower/blob/justice_kit.go +++ b/watchtower/blob/justice_kit.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -307,9 +308,11 @@ func newTaprootJusticeKit(sweepScript []byte, keyRing := breachInfo.KeyRing + // TODO(roasbeef): aux leaf tower updates needed + tree, err := input.NewLocalCommitScriptTree( breachInfo.RemoteDelay, keyRing.ToLocalKey, - keyRing.RevocationKey, + keyRing.RevocationKey, fn.None[txscript.TapLeaf](), ) if err != nil { return nil, err @@ -416,7 +419,9 @@ func (t *taprootJusticeKit) ToRemoteOutputSpendInfo() (*txscript.PkScript, return nil, nil, 0, err } - scriptTree, err := input.NewRemoteCommitScriptTree(toRemotePk) + scriptTree, err := input.NewRemoteCommitScriptTree( + toRemotePk, fn.None[txscript.TapLeaf](), + ) if err != nil { return nil, nil, 0, err } diff --git a/watchtower/blob/justice_kit_test.go b/watchtower/blob/justice_kit_test.go index fd12993a0a7..a1d6ec9f2c4 100644 --- a/watchtower/blob/justice_kit_test.go +++ b/watchtower/blob/justice_kit_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -304,7 +305,9 @@ func TestJusticeKitRemoteWitnessConstruction(t *testing.T) { name: "taproot commitment", blobType: TypeAltruistTaprootCommit, expWitnessScript: func(pk *btcec.PublicKey) []byte { - tree, _ := input.NewRemoteCommitScriptTree(pk) + tree, _ := input.NewRemoteCommitScriptTree( + pk, fn.None[txscript.TapLeaf](), + ) return tree.SettleLeaf.Script }, @@ -461,6 +464,7 @@ func TestJusticeKitToLocalWitnessConstruction(t *testing.T) { script, _ := input.NewLocalCommitScriptTree( csvDelay, delay, rev, + fn.None[txscript.TapLeaf](), ) return script.RevocationLeaf.Script diff --git a/watchtower/lookout/justice_descriptor_test.go b/watchtower/lookout/justice_descriptor_test.go index 2ca187fdc78..29c7f9a5a11 100644 --- a/watchtower/lookout/justice_descriptor_test.go +++ b/watchtower/lookout/justice_descriptor_test.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" secp "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -123,7 +124,7 @@ func testJusticeDescriptor(t *testing.T, blobType blob.Type) { if isTaprootChannel { toLocalCommitTree, err = input.NewLocalCommitScriptTree( - csvDelay, toLocalPK, revPK, + csvDelay, toLocalPK, revPK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err) @@ -174,7 +175,7 @@ func testJusticeDescriptor(t *testing.T, blobType blob.Type) { toRemoteSequence = 1 commitScriptTree, err := input.NewRemoteCommitScriptTree( - toRemotePK, + toRemotePK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err) diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 695c4f9ecd7..7894631b8ff 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -136,6 +137,7 @@ func genTaskTest( if chanType.IsTaproot() { scriptTree, _ := input.NewLocalCommitScriptTree( csvDelay, toLocalPK, revPK, + fn.None[txscript.TapLeaf](), ) pkScript, _ := input.PayToTaprootScript( @@ -189,7 +191,7 @@ func genTaskTest( if chanType.IsTaproot() { scriptTree, _ := input.NewRemoteCommitScriptTree( - toRemotePK, + toRemotePK, fn.None[txscript.TapLeaf](), ) pkScript, _ := input.PayToTaprootScript( diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 38d9acd9f07..f3a4d5bf4ee 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -230,12 +231,14 @@ func (c *mockChannel) createRemoteCommitTx(t *testing.T) { // Construct the to-local witness script. toLocalScriptTree, err := input.NewLocalCommitScriptTree( - c.csvDelay, c.toLocalPK, c.revPK, + c.csvDelay, c.toLocalPK, c.revPK, fn.None[txscript.TapLeaf](), ) require.NoError(t, err, "unable to create to-local script") // Construct the to-remote witness script. - toRemoteScriptTree, err := input.NewRemoteCommitScriptTree(c.toRemotePK) + toRemoteScriptTree, err := input.NewRemoteCommitScriptTree( + c.toRemotePK, fn.None[txscript.TapLeaf](), + ) require.NoError(t, err, "unable to create to-remote script") // Compute the to-local witness script hash. From 51c2a1a9ebea3d8681d07027886b87d38846d62b Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Mon, 29 Apr 2024 11:24:50 +0200 Subject: [PATCH 02/14] mod: bump tlv to v1.2.5 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index b8a658edd76..75e43e5731e 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,7 @@ require ( github.com/lightningnetwork/lnd/queue v1.1.1 github.com/lightningnetwork/lnd/sqldb v1.0.1 github.com/lightningnetwork/lnd/ticker v1.1.1 - github.com/lightningnetwork/lnd/tlv v1.2.3 + github.com/lightningnetwork/lnd/tlv v1.2.5 github.com/lightningnetwork/lnd/tor v1.1.2 github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 github.com/miekg/dns v1.1.43 diff --git a/go.sum b/go.sum index c2f3fdb804b..c3f8b2e8265 100644 --- a/go.sum +++ b/go.sum @@ -457,8 +457,8 @@ github.com/lightningnetwork/lnd/sqldb v1.0.1 h1:lpNoJ6qRh3D02oeIUsKQLZUzjcgZ9ppM github.com/lightningnetwork/lnd/sqldb v1.0.1/go.mod h1:nSovU1U+gTPDWhfwmXu/kW8l8EJpwbvZQ05ijnkQzkA= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= -github.com/lightningnetwork/lnd/tlv v1.2.3 h1:If5ibokA/UoCBGuCKaY6Vn2SJU0l9uAbehCnhTZjEP8= -github.com/lightningnetwork/lnd/tlv v1.2.3/go.mod h1:zDkmqxOczP6LaLTvSFDQ1SJUfHcQRCMKFj93dn3eMB8= +github.com/lightningnetwork/lnd/tlv v1.2.5 h1:/VsoWw628t78OiDN90pHDbqwOcuZ9JMicxXZVQjBwX0= +github.com/lightningnetwork/lnd/tlv v1.2.5/go.mod h1:/CmY4VbItpOldksocmGT4lxiJqRP9oLxwSZOda2kzNQ= github.com/lightningnetwork/lnd/tor v1.1.2 h1:3zv9z/EivNFaMF89v3ciBjCS7kvCj4ZFG7XvD2Qq0/k= github.com/lightningnetwork/lnd/tor v1.1.2/go.mod h1:j7T9uJ2NLMaHwE7GiBGnpYLn4f7NRoTM6qj+ul6/ycA= github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 h1:sjOGyegMIhvgfq5oaue6Td+hxZuf3tDC8lAPrFldqFw= From cbaa0c780ddfb68cf5e137044add9b3c545bbf8d Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Sun, 17 Mar 2024 15:51:50 -0400 Subject: [PATCH 03/14] channeldb: new custom blob nested TLV In this commit, for each channel, we'll now start to store an optional custom blob. This can be used to store extra information for custom channels in an opauqe manner. --- channeldb/channel.go | 158 ++++++++++++++++++++++++++++++++++++-- channeldb/channel_test.go | 11 ++- 2 files changed, 163 insertions(+), 6 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index cb2490fcd7e..40bab3d700d 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -251,6 +251,10 @@ type chanAuxData struct { // tapscriptRoot is the optional Tapscript root the channel funding // output commits to. tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte] + + // customBlob is an optional TLV encoded blob of data representing + // custom channel funding information. + customBlob tlv.OptionalRecordT[tlv.TlvType7, tlv.Blob] } // encode serializes the chanAuxData to the given io.Writer. @@ -269,6 +273,9 @@ func (c *chanAuxData) encode(w io.Writer) error { tlvRecords = append(tlvRecords, root.Record()) }, ) + c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType7, tlv.Blob]) { + tlvRecords = append(tlvRecords, blob.Record()) + }) // Create the tlv stream. tlvStream, err := tlv.NewStream(tlvRecords...) @@ -283,6 +290,7 @@ func (c *chanAuxData) encode(w io.Writer) error { func (c *chanAuxData) decode(r io.Reader) error { memo := c.memo.Zero() tapscriptRoot := c.tapscriptRoot.Zero() + blob := c.customBlob.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( @@ -292,6 +300,7 @@ func (c *chanAuxData) decode(r io.Reader) error { c.realScid.Record(), memo.Record(), tapscriptRoot.Record(), + blob.Record(), ) if err != nil { return err @@ -308,6 +317,9 @@ func (c *chanAuxData) decode(r io.Reader) error { if _, ok := tlvs[tapscriptRoot.TlvType()]; ok { c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot) } + if _, ok := tlvs[c.customBlob.TlvType()]; ok { + c.customBlob = tlv.SomeRecordT(blob) + } return nil } @@ -325,6 +337,9 @@ func (c *chanAuxData) toOpenChan(o *OpenChannel) { c.tapscriptRoot.WhenSomeV(func(h [32]byte) { o.TapscriptRoot = fn.Some[chainhash.Hash](h) }) + c.customBlob.WhenSomeV(func(blob tlv.Blob) { + o.CustomBlob = fn.Some(blob) + }) } // newChanAuxDataFromChan creates a new chanAuxData from the given channel. @@ -354,6 +369,11 @@ func newChanAuxDataFromChan(openChan *OpenChannel) *chanAuxData { tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h), ) }) + openChan.CustomBlob.WhenSome(func(blob tlv.Blob) { + c.customBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType7](blob), + ) + }) return c } @@ -607,6 +627,74 @@ type ChannelConfig struct { HtlcBasePoint keychain.KeyDescriptor } +// commitAuxData stores all the optional data that may be stored as a TLV stream +// at the _end_ of the normal serialized commit on disk. +type commitAuxData struct { + // customBlob is a custom blob that may store extra data for custom + // channels. + customBlob tlv.OptionalRecordT[tlv.TlvType1, tlv.Blob] +} + +// encode encodes the aux data into the passed io.Writer. +func (c *commitAuxData) encode(w io.Writer) error { + var tlvRecords []tlv.Record + c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType1, tlv.Blob]) { + tlvRecords = append(tlvRecords, blob.Record()) + }) + + // Create the tlv stream. + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + return err + } + + return tlvStream.Encode(w) +} + +// decode attempts to ecode the aux data from the passed io.Reader. +func (c *commitAuxData) decode(r io.Reader) error { + blob := c.customBlob.Zero() + + tlvStream, err := tlv.NewStream( + blob.Record(), + ) + if err != nil { + return err + } + + tlvs, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if _, ok := tlvs[c.customBlob.TlvType()]; ok { + c.customBlob = tlv.SomeRecordT(blob) + } + + return nil +} + +// toChanCommit extracts the optional data stored in the commitAuxData struct +// and stores it in the ChannelCommitment. +func (c *commitAuxData) toChanCommit(commit *ChannelCommitment) { + c.customBlob.WhenSomeV(func(blob tlv.Blob) { + commit.CustomBlob = fn.Some(blob) + }) +} + +// newCommitAuxData creates an aux data struct from the normal chan commitment. +func newCommitAuxData(commit *ChannelCommitment) commitAuxData { + var c commitAuxData + + commit.CustomBlob.WhenSome(func(blob tlv.Blob) { + c.customBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](blob), + ) + }) + + return c +} + // ChannelCommitment is a snapshot of the commitment state at a particular // point in the commitment chain. With each state transition, a snapshot of the // current state along with all non-settled HTLCs are recorded. These snapshots @@ -673,6 +761,11 @@ type ChannelCommitment struct { // able by us. CommitTx *wire.MsgTx + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This may track some custom + // specific state for this given commitment. + CustomBlob fn.Option[tlv.Blob] + // CommitSig is one half of the signature required to fully complete // the script for the commitment transaction above. This is the // signature signed by the remote party for our version of the @@ -682,9 +775,6 @@ type ChannelCommitment struct { // Htlcs is the set of HTLC's that are pending at this particular // commitment height. Htlcs []HTLC - - // TODO(roasbeef): pending commit pointer? - // * lets just walk through } // ChannelStatus is a bit vector used to indicate whether an OpenChannel is in @@ -982,6 +1072,12 @@ type OpenChannel struct { // funding output. TapscriptRoot fn.Option[chainhash.Hash] + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob fn.Option[tlv.Blob] + // TODO(roasbeef): eww Db *ChannelStateDB @@ -2793,6 +2889,16 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl } } + // We'll also encode the commit aux data stream here. We do this here + // rather than above (at the call to serializeChanCommit), to ensure + // backwards compat for reads to existing non-custom channels. + // + // TODO(roasbeef): migrate it after all? + auxData := newCommitAuxData(&diff.Commitment) + if err := auxData.encode(w); err != nil { + return fmt.Errorf("unable to write aux data: %w", err) + } + return nil } @@ -2853,6 +2959,17 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { } } + // As a final step, we'll read out any aux commit data that we have at + // the end of this byte stream. We do this here to ensure backward + // compatibility, as otherwise we risk erroneously reading into the + // wrong field. + var auxData commitAuxData + if err := auxData.decode(r); err != nil { + return nil, fmt.Errorf("unable to decode aux data: %w", err) + } + + auxData.toChanCommit(&d.Commitment) + return &d, nil } @@ -3831,6 +3948,13 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot { }, } + localCommit.CustomBlob.WhenSome(func(blob tlv.Blob) { + blobCopy := make([]byte, len(blob)) + copy(blobCopy, blob) + + snapshot.ChannelCommitment.CustomBlob = fn.Some(blobCopy) + }) + // Copy over the current set of HTLCs to ensure the caller can't mutate // our internal state. snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs)) @@ -4222,6 +4346,12 @@ func putChanCommitment(chanBucket kvdb.RwBucket, c *ChannelCommitment, return err } + // Before we write to disk, we'll also write our aux data as well. + auxData := newCommitAuxData(c) + if err := auxData.encode(&b); err != nil { + return fmt.Errorf("unable to write aux data: %w", err) + } + return chanBucket.Put(commitKey, b.Bytes()) } @@ -4367,7 +4497,9 @@ func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { return c, nil } -func fetchChanCommitment(chanBucket kvdb.RBucket, local bool) (ChannelCommitment, error) { +func fetchChanCommitment(chanBucket kvdb.RBucket, + local bool) (ChannelCommitment, error) { + var commitKey []byte if local { commitKey = append(chanCommitmentKey, byte(0x00)) @@ -4381,7 +4513,23 @@ func fetchChanCommitment(chanBucket kvdb.RBucket, local bool) (ChannelCommitment } r := bytes.NewReader(commitBytes) - return deserializeChanCommit(r) + chanCommit, err := deserializeChanCommit(r) + if err != nil { + return ChannelCommitment{}, fmt.Errorf("unable to decode "+ + "chan commit: %w", err) + } + + // We'll also check to see if we have any aux data stored as the end of + // the stream. + var auxData commitAuxData + if err := auxData.decode(r); err != nil { + return ChannelCommitment{}, fmt.Errorf("unable to decode "+ + "chan aux data: %w", err) + } + + auxData.toChanCommit(&chanCommit) + + return chanCommit, nil } func fetchChanCommitments(chanBucket kvdb.RBucket, channel *OpenChannel) error { diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 2389015cfa3..e044cb2e4f2 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -337,6 +337,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { FeePerKw: btcutil.Amount(5000), CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), + CustomBlob: fn.Some([]byte{1, 2, 3}), }, RemoteCommitment: ChannelCommitment{ CommitHeight: 0, @@ -346,6 +347,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { FeePerKw: btcutil.Amount(5000), CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), + CustomBlob: fn.Some([]byte{4, 5, 6}), }, NumConfsRequired: 4, RemoteCurrentRevocation: privKey.PubKey(), @@ -360,6 +362,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { InitialRemoteBalance: lnwire.MilliSatoshi(3000), Memo: []byte("test"), TapscriptRoot: fn.Some(tapscriptRoot), + CustomBlob: fn.Some([]byte{1, 2, 3}), } } @@ -649,6 +652,7 @@ func TestChannelStateTransition(t *testing.T) { CommitTx: newTx, CommitSig: newSig, Htlcs: htlcs, + CustomBlob: fn.Some([]byte{4, 5, 6}), } // First update the local node's broadcastable state and also add a @@ -686,9 +690,14 @@ func TestChannelStateTransition(t *testing.T) { // have been updated. updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub) require.NoError(t, err, "unable to fetch updated channel") - assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) + + assertCommitmentEqual( + t, &commitment, &updatedChannel[0].LocalCommitment, + ) + numDiskUpdates, err := updatedChannel[0].CommitmentHeight() require.NoError(t, err, "unable to read commitment height from disk") + if numDiskUpdates != uint64(commitment.CommitHeight) { t.Fatalf("num disk updates doesn't match: %v vs %v", numDiskUpdates, commitment.CommitHeight) From 5733bbb55c1d6e3b206c34ae43a525427308ef6f Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 20:00:29 -0700 Subject: [PATCH 04/14] lnwallet: export the HtlcView struct We'll need this later on to ensure we can always interact with the new aux blobs at all stages of commitment transaction construction. --- lnwallet/channel.go | 135 ++++++++++++++++++++++++--------------- lnwallet/channel_test.go | 22 +++---- lnwallet/commitment.go | 10 +-- 3 files changed, 100 insertions(+), 67 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 5240d1bacc4..d8845e8e74c 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2934,18 +2934,29 @@ func HtlcIsDust(chanType channeldb.ChannelType, return (htlcAmt - htlcFee) < dustLimit } -// htlcView represents the "active" HTLCs at a particular point within the +// HtlcView represents the "active" HTLCs at a particular point within the // history of the HTLC update log. -type htlcView struct { - ourUpdates []*PaymentDescriptor - theirUpdates []*PaymentDescriptor - feePerKw chainfee.SatPerKWeight +type HtlcView struct { + // NextHeight is the height of the commitment transaction that will be + // created using this view. + NextHeight uint64 + + // OurUpdates are our outgoing HTLCs. + OurUpdates []*PaymentDescriptor + + // TheirUpdates are their incoming HTLCs. + TheirUpdates []*PaymentDescriptor + + // FeePerKw is the fee rate in sat/kw of the commitment transaction. + FeePerKw chainfee.SatPerKWeight } // fetchHTLCView returns all the candidate HTLC updates which should be // considered for inclusion within a commitment based on the passed HTLC log // indexes. -func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *htlcView { +func (lc *LightningChannel) fetchHTLCView(theirLogIndex, + ourLogIndex uint64) *HtlcView { + var ourHTLCs []*PaymentDescriptor for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() { htlc := e.Value.(*PaymentDescriptor) @@ -2970,9 +2981,9 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *ht } } - return &htlcView{ - ourUpdates: ourHTLCs, - theirUpdates: theirHTLCs, + return &HtlcView{ + OurUpdates: ourHTLCs, + TheirUpdates: theirHTLCs, } } @@ -3007,7 +3018,10 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, if err != nil { return nil, err } - feePerKw := filteredHTLCView.feePerKw + feePerKw := filteredHTLCView.FeePerKw + + htlcView.NextHeight = nextHeight + filteredHTLCView.NextHeight = nextHeight // Actually generate unsigned commitment transaction for this view. commitTx, err := lc.commitBuilder.createUnsignedCommitmentTx( @@ -3067,12 +3081,16 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, // In order to ensure _none_ of the HTLC's associated with this new // commitment are mutated, we'll manually copy over each HTLC to its // respective slice. - c.outgoingHTLCs = make([]PaymentDescriptor, len(filteredHTLCView.ourUpdates)) - for i, htlc := range filteredHTLCView.ourUpdates { + c.outgoingHTLCs = make( + []PaymentDescriptor, len(filteredHTLCView.OurUpdates), + ) + for i, htlc := range filteredHTLCView.OurUpdates { c.outgoingHTLCs[i] = *htlc } - c.incomingHTLCs = make([]PaymentDescriptor, len(filteredHTLCView.theirUpdates)) - for i, htlc := range filteredHTLCView.theirUpdates { + c.incomingHTLCs = make( + []PaymentDescriptor, len(filteredHTLCView.TheirUpdates), + ) + for i, htlc := range filteredHTLCView.TheirUpdates { c.incomingHTLCs[i] = *htlc } @@ -3107,15 +3125,16 @@ func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn { // once for each height, and only in concert with signing a new commitment. // TODO(halseth): return htlcs to mutate instead of mutating inside // method. -func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, +func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - remoteChain, mutateState bool) (*htlcView, error) { + remoteChain, mutateState bool) (*HtlcView, error) { // We initialize the view's fee rate to the fee rate of the unfiltered // view. If any fee updates are found when evaluating the view, it will // be updated. - newView := &htlcView{ - feePerKw: view.feePerKw, + newView := &HtlcView{ + FeePerKw: view.FeePerKw, + NextHeight: nextHeight, } // We use two maps, one for the local log and one for the remote log to @@ -3128,7 +3147,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // First we run through non-add entries in both logs, populating the // skip sets and mutating the current chain state (crediting balances, // etc) to reflect the settle/timeout entry encountered. - for _, entry := range view.ourUpdates { + for _, entry := range view.OurUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -3157,10 +3176,13 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } skipThem[addEntry.HtlcIndex] = struct{}{} - processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, remoteChain, true, mutateState) + + processRemoveEntry( + entry, ourBalance, theirBalance, nextHeight, + remoteChain, true, mutateState, + ) } - for _, entry := range view.theirUpdates { + for _, entry := range view.TheirUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -3190,32 +3212,41 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } skipUs[addEntry.HtlcIndex] = struct{}{} - processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, remoteChain, false, mutateState) + + processRemoveEntry( + entry, ourBalance, theirBalance, nextHeight, + remoteChain, false, mutateState, + ) } // Next we take a second pass through all the log entries, skipping any // settled HTLCs, and debiting the chain state balance due to any newly // added HTLCs. - for _, entry := range view.ourUpdates { + for _, entry := range view.OurUpdates { isAdd := entry.EntryType == Add if _, ok := skipUs[entry.HtlcIndex]; !isAdd || ok { continue } - processAddEntry(entry, ourBalance, theirBalance, nextHeight, - remoteChain, false, mutateState) - newView.ourUpdates = append(newView.ourUpdates, entry) + processAddEntry( + entry, ourBalance, theirBalance, nextHeight, + remoteChain, false, mutateState, + ) + + newView.OurUpdates = append(newView.OurUpdates, entry) } - for _, entry := range view.theirUpdates { + for _, entry := range view.TheirUpdates { isAdd := entry.EntryType == Add if _, ok := skipThem[entry.HtlcIndex]; !isAdd || ok { continue } - processAddEntry(entry, ourBalance, theirBalance, nextHeight, - remoteChain, true, mutateState) - newView.theirUpdates = append(newView.theirUpdates, entry) + processAddEntry( + entry, ourBalance, theirBalance, nextHeight, + remoteChain, true, mutateState, + ) + + newView.TheirUpdates = append(newView.TheirUpdates, entry) } return newView, nil @@ -3363,7 +3394,7 @@ func processRemoveEntry(htlc *PaymentDescriptor, ourBalance, // processFeeUpdate processes a log update that updates the current commitment // fee. func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, - remoteChain bool, mutateState bool, view *htlcView) { + remoteChain bool, mutateState bool, view *HtlcView) { // Fee updates are applied for all commitments after they are // sent/received, so we consider them being added and removed at the @@ -3384,7 +3415,7 @@ func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, // If the update wasn't already locked in, update the current fee rate // to reflect this update. - view.feePerKw = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) + view.FeePerKw = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) if mutateState { *addHeight = nextHeight @@ -3959,10 +3990,10 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // appropriate update log, in order to validate the sanity of the // commitment resulting from _actually adding_ this HTLC to the state. if predictOurAdd != nil { - view.ourUpdates = append(view.ourUpdates, predictOurAdd) + view.OurUpdates = append(view.OurUpdates, predictOurAdd) } if predictTheirAdd != nil { - view.theirUpdates = append(view.theirUpdates, predictTheirAdd) + view.TheirUpdates = append(view.TheirUpdates, predictTheirAdd) } ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView( @@ -3972,7 +4003,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, return err } - feePerKw := filteredView.feePerKw + feePerKw := filteredView.FeePerKw // Ensure that the fee being applied is enough to be relayed across the // network in a reasonable time frame. @@ -4116,7 +4147,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // First check that the remote updates won't violate it's channel // constraints. err = validateUpdates( - filteredView.theirUpdates, &lc.channelState.RemoteChanCfg, + filteredView.TheirUpdates, &lc.channelState.RemoteChanCfg, ) if err != nil { return err @@ -4125,7 +4156,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // Secondly check that our updates won't violate our channel // constraints. err = validateUpdates( - filteredView.ourUpdates, &lc.channelState.LocalChanCfg, + filteredView.OurUpdates, &lc.channelState.LocalChanCfg, ) if err != nil { return err @@ -4727,7 +4758,7 @@ func (lc *LightningChannel) ProcessChanSyncMsg( return updates, openedCircuits, closedCircuits, nil } -// computeView takes the given htlcView, and calculates the balances, filtered +// computeView takes the given HtlcView, and calculates the balances, filtered // view (settling unsettled HTLCs), commitment weight and feePerKw, after // applying the HTLCs to the latest commitment. The returned balances are the // balances *before* subtracting the commitment fee from the initiator's @@ -4735,9 +4766,9 @@ func (lc *LightningChannel) ProcessChanSyncMsg( // // If the updateState boolean is set true, the add and remove heights of the // HTLCs will be set to the next commitment height. -func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, +func (lc *LightningChannel) computeView(view *HtlcView, remoteChain bool, updateState bool) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, int64, - *htlcView, error) { + *HtlcView, error) { commitChain := lc.localCommitChain dustLimit := lc.channelState.LocalChanCfg.DustLimit @@ -4768,7 +4799,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, // Initiate feePerKw to the last committed fee for this chain as we'll // need this to determine which HTLCs are dust, and also the final fee // rate. - view.feePerKw = commitChain.tip().feePerKw + view.FeePerKw = commitChain.tip().feePerKw // We evaluate the view at this stage, meaning settled and failed HTLCs // will remove their corresponding added HTLCs. The resulting filtered @@ -4776,12 +4807,14 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, // channel constraints to the final commitment state. If any fee // updates are found in the logs, the commitment fee rate should be // changed, so we'll also set the feePerKw to this new value. - filteredHTLCView, err := lc.evaluateHTLCView(view, &ourBalance, - &theirBalance, nextHeight, remoteChain, updateState) + filteredHTLCView, err := lc.evaluateHTLCView( + view, &ourBalance, &theirBalance, nextHeight, remoteChain, + updateState, + ) if err != nil { return 0, 0, 0, nil, err } - feePerKw := filteredHTLCView.feePerKw + feePerKw := filteredHTLCView.FeePerKw // We need to first check ourBalance and theirBalance to be negative // because MilliSathoshi is a unsigned type and can underflow in @@ -4799,7 +4832,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, // Now go through all HTLCs at this stage, to calculate the total // weight, needed to calculate the transaction fee. var totalHtlcWeight int64 - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( lc.channelState.ChanType, false, !remoteChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -4810,7 +4843,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, totalHtlcWeight += input.HTLCWeight } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( lc.channelState.ChanType, true, !remoteChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -8271,13 +8304,13 @@ func (lc *LightningChannel) availableBalance( } // availableCommitmentBalance attempts to calculate the balance we have -// available for HTLCs on the local/remote commitment given the htlcView. To +// available for HTLCs on the local/remote commitment given the HtlcView. To // account for sending HTLCs of different sizes, it will report the balance // available for sending non-dust HTLCs, which will be manifested on the // commitment, increasing the commitment fee we must pay as an initiator, // eating into our balance. It will make sure we won't violate the channel // reserve constraints for this amount. -func (lc *LightningChannel) availableCommitmentBalance(view *htlcView, +func (lc *LightningChannel) availableCommitmentBalance(view *HtlcView, remoteChain bool, buffer BufferType) (lnwire.MilliSatoshi, int64) { // Compute the current balances for this commitment. This will take @@ -8305,7 +8338,7 @@ func (lc *LightningChannel) availableCommitmentBalance(view *htlcView, // Calculate the commitment fee in the case where we would add another // HTLC to the commitment, as only the balance remaining after this fee // has been paid is actually available for sending. - feePerKw := filteredView.feePerKw + feePerKw := filteredView.FeePerKw additionalHtlcFee := lnwire.NewMSatFromSatoshis( feePerKw.FeeForWeight(input.HTLCWeight), ) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index c73abdbe0e3..d08eeea29bc 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -8540,10 +8540,10 @@ func TestEvaluateView(t *testing.T) { } } - view := &htlcView{ - ourUpdates: test.ourHtlcs, - theirUpdates: test.theirHtlcs, - feePerKw: feePerKw, + view := &HtlcView{ + OurUpdates: test.ourHtlcs, + TheirUpdates: test.theirHtlcs, + FeePerKw: feePerKw, } var ( @@ -8564,17 +8564,17 @@ func TestEvaluateView(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if result.feePerKw != test.expectedFee { + if result.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", - test.expectedFee, result.feePerKw) + test.expectedFee, result.FeePerKw) } checkExpectedHtlcs( - t, result.ourUpdates, test.ourExpectedHtlcs, + t, result.OurUpdates, test.ourExpectedHtlcs, ) checkExpectedHtlcs( - t, result.theirUpdates, test.theirExpectedHtlcs, + t, result.TheirUpdates, test.theirExpectedHtlcs, ) if lc.channelState.TotalMSatSent != test.expectSent { @@ -8797,15 +8797,15 @@ func TestProcessFeeUpdate(t *testing.T) { EntryType: FeeUpdate, } - view := &htlcView{ - feePerKw: chainfee.SatPerKWeight(feePerKw), + view := &HtlcView{ + FeePerKw: chainfee.SatPerKWeight(feePerKw), } processFeeUpdate( update, nextHeight, test.remoteChain, test.mutate, view, ) - if view.feePerKw != test.expectedFee { + if view.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", test.expectedFee, feePerKw) } diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 564fc7fa2fd..5dc6ddfbb20 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -686,7 +686,7 @@ type unsignedCommitmentTx struct { func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, theirBalance lnwire.MilliSatoshi, isOurs bool, feePerKw chainfee.SatPerKWeight, height uint64, - filteredHTLCView *htlcView, + filteredHTLCView *HtlcView, keyRing *CommitmentKeyRing) (*unsignedCommitmentTx, error) { dustLimit := cb.chanState.LocalChanCfg.DustLimit @@ -695,7 +695,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } numHTLCs := int64(0) - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( cb.chanState.ChanType, false, isOurs, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -706,7 +706,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, numHTLCs++ } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( cb.chanState.ChanType, true, isOurs, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -790,7 +790,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, // commitment outputs and should correspond to zero values for the // purposes of sorting. cltvs := make([]uint32, len(commitTx.TxOut)) - for _, htlc := range filteredHTLCView.ourUpdates { + for _, htlc := range filteredHTLCView.OurUpdates { if HtlcIsDust( cb.chanState.ChanType, false, isOurs, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -808,7 +808,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } cltvs = append(cltvs, htlc.Timeout) // nolint:makezero } - for _, htlc := range filteredHTLCView.theirUpdates { + for _, htlc := range filteredHTLCView.TheirUpdates { if HtlcIsDust( cb.chanState.ChanType, true, isOurs, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, From 12acbac16abdbd4c27e439beba623fe15c3a44a7 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 20:07:15 -0700 Subject: [PATCH 05/14] lnwallet: add custom tlv blob to internal commitment struct In this commit, we also add the custom TLV blob to the internal commitment struct that we use within the in-memory commitment linked list. This'll be useful to ensure that we're tracking the current blob for our in memory commitment for when we need to write it to disk. --- lnwallet/channel.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index d8845e8e74c..716e583f0c1 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -32,6 +32,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -546,6 +547,10 @@ type commitment struct { // on this commitment transaction. incomingHTLCs []PaymentDescriptor + // customBlob stores opaque bytes that may be used by custom channels + // to store extra data for a given commitment state. + customBlob fn.Option[tlv.Blob] + // [outgoing|incoming]HTLCIndex is an index that maps an output index // on the commitment transaction to the payment descriptor that // represents the HTLC output. @@ -726,6 +731,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { CommitTx: c.txn, CommitSig: c.sig, Htlcs: make([]channeldb.HTLC, 0, numHtlcs), + CustomBlob: c.customBlob, } for _, htlc := range c.outgoingHTLCs { @@ -971,6 +977,7 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, feePerKw: chainfee.SatPerKWeight(diskCommit.FeePerKw), incomingHTLCs: incomingHtlcs, outgoingHTLCs: outgoingHtlcs, + customBlob: diskCommit.CustomBlob, } if isLocal { commit.dustLimit = lc.channelState.LocalChanCfg.DustLimit From fc44ff4bffe7b346f0dc95f93e818b38e3c25d8d Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 19:52:21 -0700 Subject: [PATCH 06/14] input: add some utility type definitions for aux leaves In this commit, we add some useful type definitions for the aux leaf. --- input/taproot.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/input/taproot.go b/input/taproot.go index 0fcd1df4e2a..6edaf5de621 100644 --- a/input/taproot.go +++ b/input/taproot.go @@ -31,6 +31,24 @@ func NoneTapLeaf() AuxTapLeaf { return fn.None[txscript.TapLeaf]() } +// HtlcIndex represents the monotonically increasing counter that is used to +// identify HTLCs created a peer. +type HtlcIndex = uint64 + +// HtlcAuxLeaf is a type that represents an auxiliary leaf for an HTLC output. +// An HTLC may have up to two aux leaves: one for the output on the commitment +// transaction, and one for the second level HTLC. +type HtlcAuxLeaf struct { + AuxTapLeaf + + // SecondLevelLeaf is the auxiliary leaf for the second level HTLC + // success or timeout transaction. + SecondLevelLeaf AuxTapLeaf +} + +// AuxTapLeaves is a type alias for a slice of optional tapscript leaves. +type AuxTapLeaves = map[HtlcIndex]HtlcAuxLeaf + // NewTxSigHashesV0Only returns a new txscript.TxSigHashes instance that will // only calculate the sighash midstate values for segwit v0 inputs and can // therefore never be used for transactions that want to spend segwit v1 From 258dd5c96e391fae755f137f65c7a27a86e67f0f Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Sun, 17 Mar 2024 16:53:38 -0400 Subject: [PATCH 07/14] lnwallet+channeldb: add new AuxLeafStore for dynamic aux leaves In this commit, we add a new AuxLeafStore which can be used to dynamically fetch the latest aux leaves for a given state. This is useful for custom channel types that will store some extra information in the form of a custom blob, then will use that information to derive the new leaf tapscript leaves that may be attached to reach state. --- contractcourt/chain_watcher.go | 3 +- lnwallet/channel.go | 58 ++++++-- lnwallet/commitment.go | 247 ++++++++++++++++++++++++++++++--- lnwallet/transactions_test.go | 2 + lnwallet/wallet.go | 27 +++- 5 files changed, 303 insertions(+), 34 deletions(-) diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 5372d8c0dde..28022650cc8 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -429,7 +429,7 @@ func (c *chainWatcher) handleUnknownLocalState( } remoteScript, _, err := lnwallet.CommitScriptToRemote( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, - commitKeyRing.ToRemoteKey, leaseExpiry, + commitKeyRing.ToRemoteKey, leaseExpiry, input.NoneTapLeaf(), ) if err != nil { return false, err @@ -442,6 +442,7 @@ func (c *chainWatcher) handleUnknownLocalState( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, commitKeyRing.ToLocalKey, commitKeyRing.RevocationKey, uint32(c.cfg.chanState.LocalChanCfg.CsvDelay), leaseExpiry, + input.NoneTapLeaf(), ) if err != nil { return false, err diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 716e583f0c1..96f99987ea7 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -1269,6 +1269,10 @@ type LightningChannel struct { // machine. Signer input.Signer + // leafStore is used to retrieve extra tapscript leaves for special + // custom channel types. + leafStore fn.Option[AuxLeafStore] + // signDesc is the primary sign descriptor that is capable of signing // the commitment transaction that spends the multi-sig output. signDesc *input.SignDescriptor @@ -1344,6 +1348,8 @@ type channelOpts struct { localNonce *musig2.Nonces remoteNonce *musig2.Nonces + leafStore fn.Option[AuxLeafStore] + skipNonceInit bool } @@ -1374,6 +1380,13 @@ func WithSkipNonceInit() ChannelOpt { } } +// WithLeafStore is used to specify a custom leaf store for the channel. +func WithLeafStore(store AuxLeafStore) ChannelOpt { + return func(o *channelOpts) { + o.leafStore = fn.Some[AuxLeafStore](store) + } +} + // defaultChannelOpts returns the set of default options for a new channel. func defaultChannelOpts() *channelOpts { return &channelOpts{} @@ -1415,13 +1428,16 @@ func NewLightningChannel(signer input.Signer, } lc := &LightningChannel{ - Signer: signer, - sigPool: sigPool, - currentHeight: localCommit.CommitHeight, - remoteCommitChain: newCommitmentChain(), - localCommitChain: newCommitmentChain(), - channelState: state, - commitBuilder: NewCommitmentBuilder(state), + Signer: signer, + leafStore: opts.leafStore, + sigPool: sigPool, + currentHeight: localCommit.CommitHeight, + remoteCommitChain: newCommitmentChain(), + localCommitChain: newCommitmentChain(), + channelState: state, + commitBuilder: NewCommitmentBuilder( + state, opts.leafStore, + ), localUpdateLog: localUpdateLog, remoteUpdateLog: remoteUpdateLog, Capacity: state.Capacity, @@ -2464,12 +2480,14 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, leaseExpiry = chanState.ThawHeight } + // TODO(roasbeef): fetch aux leave + // Since it is the remote breach we are reconstructing, the output // going to us will be a to-remote script with our local params. isRemoteInitiator := !chanState.IsInitiator ourScript, ourDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, input.NoneTapLeaf(), ) if err != nil { return nil, err @@ -2479,6 +2497,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, theirScript, err := CommitScriptToSelf( chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey, keyRing.RevocationKey, theirDelay, leaseExpiry, + input.NoneTapLeaf(), ) if err != nil { return nil, err @@ -3033,7 +3052,7 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, // Actually generate unsigned commitment transaction for this view. commitTx, err := lc.commitBuilder.createUnsignedCommitmentTx( ourBalance, theirBalance, !remoteChain, feePerKw, nextHeight, - filteredHTLCView, keyRing, + htlcView, filteredHTLCView, keyRing, commitChain.tip(), ) if err != nil { return nil, err @@ -3068,6 +3087,16 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, effFeeRate, spew.Sdump(commitTx)) } + // Given the custom blob of the past state, and this new HTLC view, + // we'll generate a new blob for the latest commitment. + newCommitBlob, err := updateAuxBlob( + lc.channelState, commitChain.tip().customBlob, htlcView, + !remoteChain, ourBalance, theirBalance, lc.leafStore, *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the commitment view created, store the resulting balances and // transaction with the other parameters for this height. c := &commitment{ @@ -3083,6 +3112,7 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, feePerKw: feePerKw, dustLimit: dustLimit, isOurs: !remoteChain, + customBlob: newCommitBlob, } // In order to ensure _none_ of the HTLC's associated with this new @@ -3174,6 +3204,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // number of satoshis we've received within the channel. if mutateState && entry.EntryType == Settle && !remoteChain && entry.removeCommitHeightLocal == 0 { + lc.channelState.TotalMSatReceived += entry.Amount } @@ -5854,7 +5885,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( // before the change since the indexes are meant for the current, // revoked remote commitment. ourOutputIndex, theirOutputIndex, err := findOutputIndexesFromRemote( - revocation, lc.channelState, + revocation, lc.channelState, lc.leafStore, ) if err != nil { return nil, nil, nil, nil, err @@ -6736,12 +6767,14 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si commitTxBroadcast := commitSpend.SpendingTx + // TODO(roasbeef): fetch aux leave + // Before we can generate the proper sign descriptor, we'll need to // locate the output index of our non-delayed output on the commitment // transaction. selfScript, maturityDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, input.NoneTapLeaf(), ) if err != nil { return nil, fmt.Errorf("unable to create self commit "+ @@ -7684,6 +7717,8 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) + // TODO(roasbeef): fetch aux leave + var leaseExpiry uint32 if chanState.ChanType.HasLeaseExpiration() { leaseExpiry = chanState.ThawHeight @@ -7691,6 +7726,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, toLocalScript, err := CommitScriptToSelf( chanState.ChanType, chanState.IsInitiator, keyRing.ToLocalKey, keyRing.RevocationKey, csvTimeout, leaseExpiry, + input.NoneTapLeaf(), ) if err != nil { return nil, err diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 5dc6ddfbb20..fa8e8e08330 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -11,9 +11,11 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" ) // anchorSize is the constant anchor output size. @@ -224,8 +226,8 @@ func (w *WitnessScriptDesc) WitnessScriptForPath(_ input.ScriptPath, // party learns of the preimage to the revocation hash, then they can claim all // the settled funds in the channel, plus the unsettled funds. func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, - selfKey, revokeKey *btcec.PublicKey, csvDelay, - leaseExpiry uint32) (input.ScriptDescriptor, error) { + selfKey, revokeKey *btcec.PublicKey, csvDelay, leaseExpiry uint32, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { switch { // For taproot scripts, we'll need to make a slightly modified script @@ -235,7 +237,7 @@ func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, // Our "redeem" script here is just the taproot witness program. case chanType.IsTaproot(): return input.NewLocalCommitScriptTree( - csvDelay, selfKey, revokeKey, input.NoneTapLeaf(), + csvDelay, selfKey, revokeKey, auxLeaf, ) // If we are the initiator of a leased channel, then we have an @@ -289,8 +291,8 @@ func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool, // script for. The second return value is the CSV delay of the output script, // what must be satisfied in order to spend the output. func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool, - remoteKey *btcec.PublicKey, - leaseExpiry uint32) (input.ScriptDescriptor, uint32, error) { + remoteKey *btcec.PublicKey, leaseExpiry uint32, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, uint32, error) { switch { // If we are not the initiator of a leased channel, then the remote @@ -319,7 +321,7 @@ func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool, // with the sole tap leaf enforcing the 1 CSV delay. case chanType.IsTaproot(): toRemoteScriptTree, err := input.NewRemoteCommitScriptTree( - remoteKey, input.NoneTapLeaf(), + remoteKey, auxLeaf, ) if err != nil { return nil, 0, err @@ -609,11 +611,82 @@ func CommitScriptAnchors(chanType channeldb.ChannelType, return localAnchor, remoteAnchor, nil } +// CommitSortFunc is a function type alias for a function that sorts the +// commitment transaction outputs. The second parameter is a list of CLTV +// timeouts that must correspond to the number of transaction outputs, with the +// value of 0 for non-HTLC outputs. +type CommitSortFunc func(*wire.MsgTx, []uint32) error + +// CommitAuxLeaves stores two potential auxiliary leaves for the remote and +// local output that may be used to argument the final tapscript trees of the +// commitment transaction. +type CommitAuxLeaves struct { + // LocalAuxLeaf is the local party's auxiliary leaf. + LocalAuxLeaf input.AuxTapLeaf + + // RemoteAuxLeaf is the remote party's auxiliary leaf. + RemoteAuxLeaf input.AuxTapLeaf + + // OutgoingHTLCLeaves is the set of aux leaves for the outgoing HTLCs + // on this commitment transaction. + OutgoingHtlcLeaves input.AuxTapLeaves + + // IncomingHTLCLeaves is the set of aux leaves for the incoming HTLCs + // on this commitment transaction. + IncomingHtlcLeaves input.AuxTapLeaves +} + +// ForRemoteCommit returns the local+remote aux leaves from the PoV of the +// remote party's commitment. +func (c *CommitAuxLeaves) ForRemoteCommit() CommitAuxLeaves { + return CommitAuxLeaves{ + LocalAuxLeaf: c.RemoteAuxLeaf, + RemoteAuxLeaf: c.LocalAuxLeaf, + OutgoingHtlcLeaves: c.IncomingHtlcLeaves, + IncomingHtlcLeaves: c.OutgoingHtlcLeaves, + } +} + +// AuxLeafStore is used to optionally fetch auxiliary tapscript leaves for the +// commitment transaction given an opaque blob. This is also used to implement +// a state transition function for the blobs to allow them to be refreshed with +// each state. +type AuxLeafStore interface { + // FetchLeavesFromView attempts to fetch the auxiliary leaves that + // correspond to the passed aux blob, and pending original (unfiltered) + // HTLC view. + FetchLeavesFromView(chanState *channeldb.OpenChannel, prevBlob tlv.Blob, + unfilteredView *HtlcView, isOurCommit bool, ourBalance, + theirBalance lnwire.MilliSatoshi, + keyRing CommitmentKeyRing) (fn.Option[CommitAuxLeaves], + CommitSortFunc, error) + + // FetchLeavesFromCommit attempts to fetch the auxiliary leaves that + // correspond to the passed aux blob, and an existing channel + // commitment. + FetchLeavesFromCommit(chanState *channeldb.OpenChannel, + commit channeldb.ChannelCommitment, + keyRing CommitmentKeyRing) (fn.Option[CommitAuxLeaves], error) + + // FetchLeavesFromRevocation attempts to fetch the auxiliary leaves + // from a channel revocation that stores balance + blob information. + FetchLeavesFromRevocation(revocation *channeldb.RevocationLog, + ) (fn.Option[CommitAuxLeaves], error) + + // ApplyHtlcView serves as the state transition function for the custom + // channel's blob. Given the old blob, and an HTLC view, then a new + // blob should be returned that reflects the pending updates. + ApplyHtlcView(chanState *channeldb.OpenChannel, prevBlob tlv.Blob, + unfilteredView *HtlcView, isOurCommit bool, ourBalance, + theirBalance lnwire.MilliSatoshi, + keyRing CommitmentKeyRing) (fn.Option[tlv.Blob], error) +} + // CommitmentBuilder is a type that wraps the type of channel we are dealing // with, and abstracts the various ways of constructing commitment // transactions. type CommitmentBuilder struct { - // chanState is the underlying channels's state struct, used to + // chanState is the underlying channel's state struct, used to // determine the type of channel we are dealing with, and relevant // parameters. chanState *channeldb.OpenChannel @@ -621,18 +694,25 @@ type CommitmentBuilder struct { // obfuscator is a 48-bit state hint that's used to obfuscate the // current state number on the commitment transactions. obfuscator [StateHintSize]byte + + // auxLeafStore is an interface that allows us to fetch auxiliary + // tapscript leaves for the commitment output. + auxLeafStore fn.Option[AuxLeafStore] } // NewCommitmentBuilder creates a new CommitmentBuilder from chanState. -func NewCommitmentBuilder(chanState *channeldb.OpenChannel) *CommitmentBuilder { +func NewCommitmentBuilder(chanState *channeldb.OpenChannel, + leafStore fn.Option[AuxLeafStore]) *CommitmentBuilder { + // The anchor channel type MUST be tweakless. if chanState.ChanType.HasAnchors() && !chanState.ChanType.IsTweakless() { panic("invalid channel type combination") } return &CommitmentBuilder{ - chanState: chanState, - obfuscator: createStateHintObfuscator(chanState), + chanState: chanState, + obfuscator: createStateHintObfuscator(chanState), + auxLeafStore: leafStore, } } @@ -679,15 +759,89 @@ type unsignedCommitmentTx struct { cltvs []uint32 } +// AuxLeavesFromCommit is a helper function that attempts to fetch the +// auxiliary leaves given a finalized channel commitment, and a leaf store. +func AuxLeavesFromCommit(chanState *channeldb.OpenChannel, + commit channeldb.ChannelCommitment, leafStore fn.Option[AuxLeafStore], + keyRing CommitmentKeyRing) (fn.Option[CommitAuxLeaves], error) { + + if leafStore.IsNone() { + return fn.None[CommitAuxLeaves](), nil + } + + return leafStore.UnsafeFromSome().FetchLeavesFromCommit( + chanState, commit, keyRing, + ) +} + +// auxLeavesFromView is used to derive the set of commit aux leaves (if any), +// that are needed to create a new commitment transaction using the original +// (unfiltered) htlc view. +func auxLeavesFromView(chanState *channeldb.OpenChannel, + prevBlob fn.Option[tlv.Blob], originalView *HtlcView, isOurCommit bool, + ourBalance, theirBalance lnwire.MilliSatoshi, + leafStore fn.Option[AuxLeafStore], + keyRing CommitmentKeyRing) (fn.Option[CommitAuxLeaves], CommitSortFunc, + error) { + + if leafStore.IsNone() { + return fn.None[CommitAuxLeaves](), nil, nil + } + + if prevBlob.IsNone() { + return fn.None[CommitAuxLeaves](), nil, nil + } + + return leafStore.UnsafeFromSome().FetchLeavesFromView( + chanState, prevBlob.UnsafeFromSome(), originalView, isOurCommit, + ourBalance, theirBalance, keyRing, + ) +} + +// auxLeavesFromRevocation is a helper function that attempts to fetch the aux +// leaves given a revoked state. +func auxLeavesFromRevocation(_ *channeldb.OpenChannel, + revocation *channeldb.RevocationLog, leafStore fn.Option[AuxLeafStore], + _ CommitmentKeyRing) (fn.Option[CommitAuxLeaves], error) { + + if leafStore.IsNone() { + return fn.None[CommitAuxLeaves](), nil + } + + return leafStore.UnsafeFromSome().FetchLeavesFromRevocation(revocation) +} + +// updateAuxBlob is a helper function that attempts to update the aux blob +// given the prior and current state information. +func updateAuxBlob(chanState *channeldb.OpenChannel, + prevBlob fn.Option[tlv.Blob], nextViewUnfiltered *HtlcView, + isOurCommit bool, ourBalance, theirBalance lnwire.MilliSatoshi, + leafStore fn.Option[AuxLeafStore], + keyRing CommitmentKeyRing) (fn.Option[tlv.Blob], error) { + + if leafStore.IsNone() { + return fn.None[tlv.Blob](), nil + } + + if prevBlob.IsNone() { + return fn.None[tlv.Blob](), nil + } + + return leafStore.UnsafeFromSome().ApplyHtlcView( + chanState, prevBlob.UnsafeFromSome(), nextViewUnfiltered, + isOurCommit, ourBalance, theirBalance, keyRing, + ) +} + // createUnsignedCommitmentTx generates the unsigned commitment transaction for // a commitment view and returns it as part of the unsignedCommitmentTx. The // passed in balances should be balances *before* subtracting any commitment // fees, but after anchor outputs. func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, theirBalance lnwire.MilliSatoshi, isOurs bool, - feePerKw chainfee.SatPerKWeight, height uint64, - filteredHTLCView *HtlcView, - keyRing *CommitmentKeyRing) (*unsignedCommitmentTx, error) { + feePerKw chainfee.SatPerKWeight, height uint64, originalHtlcView, + filteredHTLCView *HtlcView, keyRing *CommitmentKeyRing, + prevCommit *commitment) (*unsignedCommitmentTx, error) { dustLimit := cb.chanState.LocalChanCfg.DustLimit if !isOurs { @@ -753,6 +907,17 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, err error ) + // Before we create the commitment transaction below, we'll try to see + // if there're any aux leaves that need to be a part of the tapscript + // tree. We'll only do this if we have a custom blob defined though. + auxLeaves, customCommitSort, err := auxLeavesFromView( + cb.chanState, prevCommit.customBlob, originalHtlcView, + isOurs, ourBalance, theirBalance, cb.auxLeafStore, *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // Depending on whether the transaction is ours or not, we call // CreateCommitTx with parameters matching the perspective, to generate // a new commitment transaction with all the latest unsettled/un-timed @@ -767,6 +932,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, &cb.chanState.LocalChanCfg, &cb.chanState.RemoteChanCfg, ourBalance.ToSatoshis(), theirBalance.ToSatoshis(), numHTLCs, cb.chanState.IsInitiator, leaseExpiry, + auxLeaves, ) } else { commitTx, err = CreateCommitTx( @@ -774,6 +940,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, &cb.chanState.RemoteChanCfg, &cb.chanState.LocalChanCfg, theirBalance.ToSatoshis(), ourBalance.ToSatoshis(), numHTLCs, !cb.chanState.IsInitiator, leaseExpiry, + auxLeaves, ) } if err != nil { @@ -836,9 +1003,24 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } // Sort the transactions according to the agreed upon canonical - // ordering. This lets us skip sending the entire transaction over, - // instead we'll just send signatures. - InPlaceCommitSort(commitTx, cltvs) + // ordering (which might be customized for custom channel types, but + // deterministic and both parties will arrive at the same result). This + // lets us skip sending the entire transaction over, instead we'll just + // send signatures. + if auxLeaves.IsSome() { + if customCommitSort == nil { + return nil, fmt.Errorf("custom channel type requires " + + "sorting function") + } + + err = customCommitSort(commitTx, cltvs) + if err != nil { + return nil, fmt.Errorf("unable to sort commitment "+ + "transaction by custom order: %w", err) + } + } else { + InPlaceCommitSort(commitTx, cltvs) + } // Next, we'll ensure that we don't accidentally create a commitment // transaction which would be invalid by consensus. @@ -880,24 +1062,33 @@ func CreateCommitTx(chanType channeldb.ChannelType, fundingOutput wire.TxIn, keyRing *CommitmentKeyRing, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, amountToLocal, amountToRemote btcutil.Amount, - numHTLCs int64, initiator bool, leaseExpiry uint32) (*wire.MsgTx, error) { + numHTLCs int64, initiator bool, leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*wire.MsgTx, error) { // First, we create the script for the delayed "pay-to-self" output. // This output has 2 main redemption clauses: either we can redeem the // output after a relative block delay, or the remote node can claim // the funds with the revocation key if we broadcast a revoked // commitment transaction. + localAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + })(auxLeaves) toLocalScript, err := CommitScriptToSelf( chanType, initiator, keyRing.ToLocalKey, keyRing.RevocationKey, uint32(localChanCfg.CsvDelay), leaseExpiry, + fn.FlattenOption(localAuxLeaf), ) if err != nil { return nil, err } // Next, we create the script paying to the remote. + remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + })(auxLeaves) toRemoteScript, _, err := CommitScriptToRemote( chanType, initiator, keyRing.ToRemoteKey, leaseExpiry, + fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return nil, err @@ -1201,7 +1392,8 @@ func addHTLC(commitTx *wire.MsgTx, ourCommit bool, // output scripts and compares them against the outputs inside the commitment // to find the match. func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, - chanState *channeldb.OpenChannel) (uint32, uint32, error) { + chanState *channeldb.OpenChannel, + leafStore fn.Option[AuxLeafStore]) (uint32, uint32, error) { // Init the output indexes as empty. ourIndex := uint32(channeldb.OutputIndexEmpty) @@ -1231,6 +1423,16 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, leaseExpiry = chanState.ThawHeight } + // If we have a custom blob, then we'll attempt to fetch the aux leaves + // for this state. + auxLeaves, err := AuxLeavesFromCommit( + chanState, chanCommit, leafStore, *keyRing, + ) + if err != nil { + return ourIndex, theirIndex, fmt.Errorf("unable to fetch aux "+ + "leaves: %w", err) + } + // Map the scripts from our PoV. When facing a local commitment, the to // local output belongs to us and the to remote output belongs to them. // When facing a remote commitment, the to local output belongs to them @@ -1238,9 +1440,13 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, // Compute the to local script. From our PoV, when facing a remote // commitment, the to local output belongs to them. + localAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + })(auxLeaves) theirScript, err := CommitScriptToSelf( chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey, keyRing.RevocationKey, theirDelay, leaseExpiry, + fn.FlattenOption(localAuxLeaf), ) if err != nil { return ourIndex, theirIndex, err @@ -1248,9 +1454,12 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, // Compute the to remote script. From our PoV, when facing a remote // commitment, the to remote output belongs to us. + remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + })(auxLeaves) ourScript, _, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, + leaseExpiry, fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return ourIndex, theirIndex, err diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index ab0ef73283f..14c11b7f764 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -21,6 +21,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" @@ -631,6 +632,7 @@ func testSpendValidation(t *testing.T, tweakless bool) { commitmentTx, err := CreateCommitTx( channelType, *fakeFundingTxIn, keyRing, aliceChanCfg, bobChanCfg, channelBalance, channelBalance, 0, true, 0, + fn.None[CommitAuxLeaves](), ) if err != nil { t.Fatalf("unable to create commitment transaction: %v", nil) diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index 1e8a644db91..77dabfcba78 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -1458,6 +1458,21 @@ func (l *LightningWallet) handleFundingCancelRequest(req *fundingReserveCancelMs req.err <- nil } +// createCommitOpts is a struct that holds the options for creating a new +// commitment transaction. +type createCommitOpts struct { + auxLeaves fn.Option[CommitAuxLeaves] +} + +// defaultCommitOpts returns a new createCommitOpts with default values. +func defaultCommitOpts() createCommitOpts { + return createCommitOpts{} +} + +// CreateCommitOpt is a functional option that can be used to modify the way a +// new commitment transaction is created. +type CreateCommitOpt func(*createCommitOpts) + // CreateCommitmentTxns is a helper function that creates the initial // commitment transaction for both parties. This function is used during the // initial funding workflow as both sides must generate a signature for the @@ -1467,7 +1482,13 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, ourChanCfg, theirChanCfg *channeldb.ChannelConfig, localCommitPoint, remoteCommitPoint *btcec.PublicKey, fundingTxIn wire.TxIn, chanType channeldb.ChannelType, initiator bool, - leaseExpiry uint32) (*wire.MsgTx, *wire.MsgTx, error) { + leaseExpiry uint32, opts ...CreateCommitOpt) (*wire.MsgTx, *wire.MsgTx, + error) { + + options := defaultCommitOpts() + for _, optFunc := range opts { + optFunc(&options) + } localCommitmentKeys := DeriveCommitmentKeys( localCommitPoint, true, chanType, ourChanCfg, theirChanCfg, @@ -1479,7 +1500,7 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, ourCommitTx, err := CreateCommitTx( chanType, fundingTxIn, localCommitmentKeys, ourChanCfg, theirChanCfg, localBalance, remoteBalance, 0, initiator, - leaseExpiry, + leaseExpiry, options.auxLeaves, ) if err != nil { return nil, nil, err @@ -1493,7 +1514,7 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, theirCommitTx, err := CreateCommitTx( chanType, fundingTxIn, remoteCommitmentKeys, theirChanCfg, ourChanCfg, remoteBalance, localBalance, 0, !initiator, - leaseExpiry, + leaseExpiry, options.auxLeaves, ) if err != nil { return nil, nil, err From f3655da0705a6442911d769756aedfd14d9ebc9d Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Sat, 30 Mar 2024 17:28:35 -0700 Subject: [PATCH 08/14] channeldb: convert HTLCEntry to use tlv.RecordT --- channeldb/channel_test.go | 24 ++-- channeldb/revocation_log.go | 201 +++++++++++++++---------------- channeldb/revocation_log_test.go | 100 +++++++++++---- lnwallet/channel.go | 18 +-- lnwallet/channel_test.go | 29 +++-- 5 files changed, 209 insertions(+), 163 deletions(-) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index e044cb2e4f2..8e2984f8778 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -579,15 +579,21 @@ func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment, require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch") for i, rHtlc := range r.HTLCEntries { cHtlc := c.Htlcs[i] - require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch") - require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(), - "Amt mismatch") - require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout, - "RefundTimeout mismatch") - require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex, - "OutputIndex mismatch") - require.Equal(t, rHtlc.Incoming, cHtlc.Incoming, - "Incoming mismatch") + require.Equal(t, rHtlc.RHash.Val[:], cHtlc.RHash[:], "RHash") + require.Equal( + t, rHtlc.Amt.Val.Int(), cHtlc.Amt.ToSatoshis(), "Amt", + ) + require.Equal( + t, rHtlc.RefundTimeout.Val, cHtlc.RefundTimeout, + "RefundTimeout", + ) + require.EqualValues( + t, rHtlc.OutputIndex.Val, cHtlc.OutputIndex, + "OutputIndex", + ) + require.Equal( + t, rHtlc.Incoming.Val, cHtlc.Incoming, "Incoming", + ) } } diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index f062ac0860e..26e4b9644ac 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -54,6 +54,74 @@ var ( ErrOutputIndexTooBig = errors.New("output index is over uint16") ) +// SparsePayHash is a type alias for a 32 byte array, which when serialized is +// able to save some space by not including an empty payment hash on disk. +type SparsePayHash [32]byte + +// NewSparsePayHash creates a new SparsePayHash from a 32 byte array. +func NewSparsePayHash(rHash [32]byte) SparsePayHash { + return SparsePayHash(rHash) +} + +// Record returns a tlv record for the SparsePayHash. +func (s *SparsePayHash) Record() tlv.Record { + // We use a zero for the type here, as this'll be used along with the + // RecordT type. + return tlv.MakeDynamicRecord( + 0, s, s.hashLen, + sparseHashEncoder, sparseHashDecoder, + ) +} + +// hashLen is used by MakeDynamicRecord to return the size of the RHash. +// +// NOTE: for zero hash, we return a length 0. +func (s *SparsePayHash) hashLen() uint64 { + if bytes.Equal(s[:], lntypes.ZeroHash[:]) { + return 0 + } + + return 32 +} + +// sparseHashEncoder is the customized encoder which skips encoding the empty +// hash. +func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the value is an empty hash, we will skip encoding it. + if bytes.Equal(v[:], lntypes.ZeroHash[:]) { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.EBytes32(w, vArray, buf) +} + +// sparseHashDecoder is the customized decoder which skips decoding the empty +// hash. +func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the length is zero, we will skip encoding the empty hash. + if l == 0 { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.DBytes32(r, vArray, buf, 32) +} + // HTLCEntry specifies the minimal info needed to be stored on disk for ALL the // historical HTLCs, which is useful for constructing RevocationLog when a // breach is detected. @@ -72,116 +140,60 @@ var ( // made into tlv records without further conversion. type HTLCEntry struct { // RHash is the payment hash of the HTLC. - RHash [32]byte + RHash tlv.RecordT[tlv.TlvType0, SparsePayHash] // RefundTimeout is the absolute timeout on the HTLC that the sender // must wait before reclaiming the funds in limbo. - RefundTimeout uint32 + RefundTimeout tlv.RecordT[tlv.TlvType1, uint32] // OutputIndex is the output index for this particular HTLC output // within the commitment transaction. // // NOTE: we use uint16 instead of int32 here to save us 2 bytes, which // gives us a max number of HTLCs of 65K. - OutputIndex uint16 + OutputIndex tlv.RecordT[tlv.TlvType2, uint16] // Incoming denotes whether we're the receiver or the sender of this // HTLC. // // NOTE: this field is the memory representation of the field // incomingUint. - Incoming bool + Incoming tlv.RecordT[tlv.TlvType3, bool] // Amt is the amount of satoshis this HTLC escrows. // // NOTE: this field is the memory representation of the field amtUint. - Amt btcutil.Amount - - // amtTlv is the uint64 format of Amt. This field is created so we can - // easily make it into a tlv record and save it to disk. - // - // NOTE: we keep this field for accounting purpose only. If the disk - // space becomes an issue, we could delete this field to save us extra - // 8 bytes. - amtTlv uint64 - - // incomingTlv is the uint8 format of Incoming. This field is created - // so we can easily make it into a tlv record and save it to disk. - incomingTlv uint8 -} - -// RHashLen is used by MakeDynamicRecord to return the size of the RHash. -// -// NOTE: for zero hash, we return a length 0. -func (h *HTLCEntry) RHashLen() uint64 { - if h.RHash == lntypes.ZeroHash { - return 0 - } - return 32 -} - -// RHashEncoder is the customized encoder which skips encoding the empty hash. -func RHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { - v, ok := val.(*[32]byte) - if !ok { - return tlv.NewTypeForEncodingErr(val, "RHash") - } - - // If the value is an empty hash, we will skip encoding it. - if *v == lntypes.ZeroHash { - return nil - } - - return tlv.EBytes32(w, v, buf) -} - -// RHashDecoder is the customized decoder which skips decoding the empty hash. -func RHashDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { - v, ok := val.(*[32]byte) - if !ok { - return tlv.NewTypeForEncodingErr(val, "RHash") - } - - // If the length is zero, we will skip encoding the empty hash. - if l == 0 { - return nil - } - - return tlv.DBytes32(r, v, buf, 32) + Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] } // toTlvStream converts an HTLCEntry record into a tlv representation. func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { - const ( - // A set of tlv type definitions used to serialize htlc entries - // to the database. We define it here instead of the head of - // the file to avoid naming conflicts. - // - // NOTE: A migration should be added whenever this list - // changes. - rHashType tlv.Type = 0 - refundTimeoutType tlv.Type = 1 - outputIndexType tlv.Type = 2 - incomingType tlv.Type = 3 - amtType tlv.Type = 4 + return tlv.NewStream( + h.RHash.Record(), + h.RefundTimeout.Record(), + h.OutputIndex.Record(), + h.Incoming.Record(), + h.Amt.Record(), ) +} - return tlv.NewStream( - tlv.MakeDynamicRecord( - rHashType, &h.RHash, h.RHashLen, - RHashEncoder, RHashDecoder, +// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC. +func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { + return &HTLCEntry{ + RHash: tlv.NewRecordT[tlv.TlvType0]( + NewSparsePayHash(htlc.RHash), ), - tlv.MakePrimitiveRecord( - refundTimeoutType, &h.RefundTimeout, + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1]( + htlc.RefundTimeout, ), - tlv.MakePrimitiveRecord( - outputIndexType, &h.OutputIndex, + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlc.OutputIndex), ), - tlv.MakePrimitiveRecord(incomingType, &h.incomingTlv), - // We will save 3 bytes if the amount is less or equal to - // 4,294,967,295 msat, or roughly 0.043 bitcoin. - tlv.MakeBigSizeRecord(amtType, &h.amtTlv), - ) + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), + ), + } } // RevocationLog stores the info needed to construct a breach retribution. Its @@ -265,13 +277,7 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, return ErrOutputIndexTooBig } - entry := &HTLCEntry{ - RHash: htlc.RHash, - RefundTimeout: htlc.RefundTimeout, - Incoming: htlc.Incoming, - OutputIndex: uint16(htlc.OutputIndex), - Amt: htlc.Amt.ToSatoshis(), - } + entry := NewHTLCEntryFromHTLC(htlc) rl.HTLCEntries = append(rl.HTLCEntries, entry) } @@ -351,14 +357,6 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { // format. func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { for _, htlc := range htlcs { - // Patch the incomingTlv field. - if htlc.Incoming { - htlc.incomingTlv = 1 - } - - // Patch the amtTlv field. - htlc.amtTlv = uint64(htlc.Amt) - // Create the tlv stream. tlvStream, err := htlc.toTlvStream() if err != nil { @@ -447,14 +445,6 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { return nil, err } - // Patch the Incoming field. - if htlc.incomingTlv == 1 { - htlc.Incoming = true - } - - // Patch the Amt field. - htlc.Amt = btcutil.Amount(htlc.amtTlv) - // Append the entry. htlcs = append(htlcs, &htlc) } @@ -469,6 +459,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error { if err := s.Encode(&b); err != nil { return err } + // Write the stream's length as a varint. err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{}) if err != nil { diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index fc5303a48dc..c9d2a16f708 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -34,12 +34,16 @@ var ( } testHTLCEntry = HTLCEntry{ - RefundTimeout: 740_000, - OutputIndex: 10, - Incoming: true, - Amt: 1000_000, - amtTlv: 1000_000, - incomingTlv: 1, + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( + 740_000, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16]( + 10, + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(1_000_000)), + ), } testHTLCEntryBytes = []byte{ // Body length 23. @@ -56,6 +60,40 @@ var ( 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, } + testHTLCEntryHash = HTLCEntry{ + RHash: tlv.NewPrimitiveRecord[tlv.TlvType0](NewSparsePayHash( + [32]byte{0x33, 0x44, 0x55}, + )), + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( + 740_000, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16]( + 10, + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(1_000_000)), + ), + } + testHTLCEntryHashBytes = []byte{ + // Body length 54. + 0x36, + // Rhash tlv. + 0x0, 0x20, + 0x33, 0x44, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // RefundTimeout tlv. + 0x1, 0x4, 0x0, 0xb, 0x4a, 0xa0, + // OutputIndex tlv. + 0x2, 0x2, 0x0, 0xa, + // Incoming tlv. + 0x3, 0x1, 0x1, + // Amt tlv. + 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, + } + localBalance = lnwire.MilliSatoshi(9000) remoteBalance = lnwire.MilliSatoshi(3000) @@ -68,11 +106,11 @@ var ( CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), Htlcs: []HTLC{{ - RefundTimeout: testHTLCEntry.RefundTimeout, - OutputIndex: int32(testHTLCEntry.OutputIndex), - Incoming: testHTLCEntry.Incoming, + RefundTimeout: testHTLCEntry.RefundTimeout.Val, + OutputIndex: int32(testHTLCEntry.OutputIndex.Val), + Incoming: testHTLCEntry.Incoming.Val, Amt: lnwire.NewMSatFromSatoshis( - testHTLCEntry.Amt, + testHTLCEntry.Amt.Val.Int(), ), }}, } @@ -193,11 +231,6 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) { // Copy the testHTLCEntry. entry := testHTLCEntry - // Set the internal fields to empty values so we can test the bytes are - // padded. - entry.incomingTlv = 0 - entry.amtTlv = 0 - // Write the tlv stream. buf := bytes.NewBuffer([]byte{}) err := serializeHTLCEntries(buf, []*HTLCEntry{&entry}) @@ -207,6 +240,21 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) { require.Equal(t, testHTLCEntryBytes, buf.Bytes()) } +func TestSerializeHTLCEntriesWithRHash(t *testing.T) { + t.Parallel() + + // Copy the testHTLCEntry. + entry := testHTLCEntryHash + + // Write the tlv stream. + buf := bytes.NewBuffer([]byte{}) + err := serializeHTLCEntries(buf, []*HTLCEntry{&entry}) + require.NoError(t, err) + + // Check the bytes are read as expected. + require.Equal(t, testHTLCEntryHashBytes, buf.Bytes()) +} + func TestSerializeHTLCEntries(t *testing.T) { t.Parallel() @@ -215,7 +263,7 @@ func TestSerializeHTLCEntries(t *testing.T) { // Create a fake rHash. rHashBytes := bytes.Repeat([]byte{10}, 32) - copy(entry.RHash[:], rHashBytes) + copy(entry.RHash.Val[:], rHashBytes) // Construct the serialized bytes. // @@ -269,7 +317,7 @@ func TestSerializeAndDeserializeRevLog(t *testing.T) { t, &test.revLog, test.revLogBytes, ) - testDerializeRevocationLog( + testDeserializeRevocationLog( t, &test.revLog, test.revLogBytes, ) }) @@ -293,7 +341,7 @@ func testSerializeRevocationLog(t *testing.T, rl *RevocationLog, require.Equal(t, revLogBytes, buf.Bytes()[:bodyIndex]) } -func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog, +func testDeserializeRevocationLog(t *testing.T, revLog *RevocationLog, revLogBytes []byte) { // Construct the full bytes. @@ -309,7 +357,7 @@ func testDerializeRevocationLog(t *testing.T, revLog *RevocationLog, require.Equal(t, *revLog, rl) } -func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) { +func TestDeserializeHTLCEntriesEmptyRHash(t *testing.T) { t.Parallel() // Read the tlv stream. @@ -322,7 +370,7 @@ func TestDerializeHTLCEntriesEmptyRHash(t *testing.T) { require.Equal(t, &testHTLCEntry, htlcs[0]) } -func TestDerializeHTLCEntries(t *testing.T) { +func TestDeserializeHTLCEntries(t *testing.T) { t.Parallel() // Copy the testHTLCEntry. @@ -330,7 +378,7 @@ func TestDerializeHTLCEntries(t *testing.T) { // Create a fake rHash. rHashBytes := bytes.Repeat([]byte{10}, 32) - copy(entry.RHash[:], rHashBytes) + copy(entry.RHash.Val[:], rHashBytes) // Construct the serialized bytes. // @@ -398,11 +446,11 @@ func TestDeleteLogBucket(t *testing.T) { err = kvdb.Update(backend, func(tx kvdb.RwTx) error { // Create the buckets. - chanBucket, _, err := createTestRevocatoinLogBuckets(tx) + chanBucket, _, err := createTestRevocationLogBuckets(tx) require.NoError(t, err) // Create the buckets again should give us an error. - _, _, err = createTestRevocatoinLogBuckets(tx) + _, _, err = createTestRevocationLogBuckets(tx) require.ErrorIs(t, err, kvdb.ErrBucketExists) // Delete both buckets. @@ -410,7 +458,7 @@ func TestDeleteLogBucket(t *testing.T) { require.NoError(t, err) // Create the buckets again should give us NO error. - _, _, err = createTestRevocatoinLogBuckets(tx) + _, _, err = createTestRevocationLogBuckets(tx) return err }, func() {}) require.NoError(t, err) @@ -516,7 +564,7 @@ func TestPutRevocationLog(t *testing.T) { // Construct the testing db transaction. dbTx := func(tx kvdb.RwTx) (RevocationLog, error) { // Create the buckets. - _, bucket, err := createTestRevocatoinLogBuckets(tx) + _, bucket, err := createTestRevocationLogBuckets(tx) require.NoError(t, err) // Save the log. @@ -686,7 +734,7 @@ func TestFetchRevocationLogCompatible(t *testing.T) { } } -func createTestRevocatoinLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket, +func createTestRevocationLogBuckets(tx kvdb.RwTx) (kvdb.RwBucket, kvdb.RwBucket, error) { chanBucket, err := tx.CreateTopLevelBucket(openChannelBucket) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 96f99987ea7..61c9fc6bf31 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2674,8 +2674,8 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. scriptInfo, err := genHtlcScript( - chanState.ChanType, htlc.Incoming, false, - htlc.RefundTimeout, htlc.RHash, keyRing, + chanState.ChanType, htlc.Incoming.Val, false, + htlc.RefundTimeout.Val, htlc.RHash.Val, keyRing, ) if err != nil { return emptyRetribution, err @@ -2688,7 +2688,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, WitnessScript: scriptInfo.WitnessScriptToSign(), Output: &wire.TxOut{ PkScript: scriptInfo.PkScript(), - Value: int64(htlc.Amt), + Value: int64(htlc.Amt.Val.Int()), }, HashType: sweepSigHash(chanState.ChanType), } @@ -2721,10 +2721,10 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, SignDesc: signDesc, OutPoint: wire.OutPoint{ Hash: commitHash, - Index: uint32(htlc.OutputIndex), + Index: uint32(htlc.OutputIndex.Val), }, SecondLevelWitnessScript: secondLevelWitnessScript, - IsIncoming: htlc.Incoming, + IsIncoming: htlc.Incoming.Val, SecondLevelTapTweak: secondLevelTapTweak, }, nil } @@ -2886,13 +2886,7 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, continue } - entry := &channeldb.HTLCEntry{ - RHash: htlc.RHash, - RefundTimeout: htlc.RefundTimeout, - OutputIndex: uint16(htlc.OutputIndex), - Incoming: htlc.Incoming, - Amt: htlc.Amt.ToSatoshis(), - } + entry := channeldb.NewHTLCEntryFromHTLC(htlc) hr, err := createHtlcRetribution( chanState, keyRing, commitHash, commitmentSecret, leaseExpiry, entry, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index d08eeea29bc..84c694c6116 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9950,9 +9950,11 @@ func TestCreateHtlcRetribution(t *testing.T) { aliceChannel.channelState, ) htlc := &channeldb.HTLCEntry{ - Amt: testAmt, - Incoming: true, - OutputIndex: 1, + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(testAmt), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](1), } // Create the htlc retribution. @@ -9966,8 +9968,8 @@ func TestCreateHtlcRetribution(t *testing.T) { // Check the fields have expected values. require.EqualValues(t, testAmt, hr.SignDesc.Output.Value) require.Equal(t, commitHash, hr.OutPoint.Hash) - require.EqualValues(t, htlc.OutputIndex, hr.OutPoint.Index) - require.Equal(t, htlc.Incoming, hr.IsIncoming) + require.EqualValues(t, htlc.OutputIndex.Val, hr.OutPoint.Index) + require.Equal(t, htlc.Incoming.Val, hr.IsIncoming) } // TestCreateBreachRetribution checks that `createBreachRetribution` behaves as @@ -10007,9 +10009,13 @@ func TestCreateBreachRetribution(t *testing.T) { aliceChannel.channelState, ) htlc := &channeldb.HTLCEntry{ - Amt: btcutil.Amount(testAmt), - Incoming: true, - OutputIndex: uint16(htlcIndex), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(testAmt)), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlcIndex), + ), } // Create a dummy revocation log. @@ -10136,11 +10142,12 @@ func TestCreateBreachRetribution(t *testing.T) { require.Equal(t, remote, br.RemoteOutpoint) for _, hr := range br.HtlcRetributions { - require.EqualValues(t, testAmt, - hr.SignDesc.Output.Value) + require.EqualValues( + t, testAmt, hr.SignDesc.Output.Value, + ) require.Equal(t, commitHash, hr.OutPoint.Hash) require.EqualValues(t, htlcIndex, hr.OutPoint.Index) - require.Equal(t, htlc.Incoming, hr.IsIncoming) + require.Equal(t, htlc.Incoming.Val, hr.IsIncoming) } } From 0bf20500a78396432a8060bf25e2bfd42a581f4b Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Sat, 30 Mar 2024 18:13:57 -0700 Subject: [PATCH 09/14] channeldb: convert RevocationLog to use RecordT --- channeldb/channel_test.go | 12 ++- channeldb/revocation_log.go | 158 ++++++++++++++++++------------- channeldb/revocation_log_test.go | 25 +++-- lnwallet/channel.go | 36 ++++--- lnwallet/channel_test.go | 28 +++--- 5 files changed, 146 insertions(+), 113 deletions(-) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 8e2984f8778..8eca3e7dd4c 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -570,9 +570,11 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment, r *RevocationLog) { + t.Helper() + // Check the common fields. require.EqualValues( - t, r.CommitTxHash, c.CommitTx.TxHash(), "CommitTx mismatch", + t, r.CommitTxHash.Val, c.CommitTx.TxHash(), "CommitTx mismatch", ) // Now check the common fields from the HTLCs. @@ -806,10 +808,10 @@ func TestChannelStateTransition(t *testing.T) { // Check the output indexes are saved as expected. require.EqualValues( - t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex, + t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex.Val, ) require.EqualValues( - t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex, + t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex.Val, ) // The two deltas (the original vs the on-disk version) should @@ -851,10 +853,10 @@ func TestChannelStateTransition(t *testing.T) { // Check the output indexes are saved as expected. require.EqualValues( - t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex, + t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex.Val, ) require.EqualValues( - t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex, + t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex.Val, ) assertRevocationLogEntryEqual(t, &oldRemoteCommit, prevCommit) diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index 26e4b9644ac..3fc4163c112 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -7,6 +7,7 @@ import ( "math" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -16,16 +17,15 @@ import ( const ( // OutputIndexEmpty is used when the output index doesn't exist. OutputIndexEmpty = math.MaxUint16 +) - // A set of tlv type definitions used to serialize the body of - // revocation logs to the database. - // - // NOTE: A migration should be added whenever this list changes. - revLogOurOutputIndexType tlv.Type = 0 - revLogTheirOutputIndexType tlv.Type = 1 - revLogCommitTxHashType tlv.Type = 2 - revLogOurBalanceType tlv.Type = 3 - revLogTheirBalanceType tlv.Type = 4 +type ( + // BigSizeAmount is a type alias for a TLV record of a btcutil.Amount. + BigSizeAmount = tlv.BigSizeT[btcutil.Amount] + + // BigSizeMilliSatoshi is a type alias for a TLV record of a + // lnwire.MilliSatoshi. + BigSizeMilliSatoshi = tlv.BigSizeT[lnwire.MilliSatoshi] ) var ( @@ -203,15 +203,15 @@ func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { type RevocationLog struct { // OurOutputIndex specifies our output index in this commitment. In a // remote commitment transaction, this is the to remote output index. - OurOutputIndex uint16 + OurOutputIndex tlv.RecordT[tlv.TlvType0, uint16] // TheirOutputIndex specifies their output index in this commitment. In // a remote commitment transaction, this is the to local output index. - TheirOutputIndex uint16 + TheirOutputIndex tlv.RecordT[tlv.TlvType1, uint16] // CommitTxHash is the hash of the latest version of the commitment // state, broadcast able by us. - CommitTxHash [32]byte + CommitTxHash tlv.RecordT[tlv.TlvType2, [32]byte] // HTLCEntries is the set of HTLCEntry's that are pending at this // particular commitment height. @@ -221,21 +221,53 @@ type RevocationLog struct { // directly spendable by us. In other words, it is the value of the // to_remote output on the remote parties' commitment transaction. // - // NOTE: this is a pointer so that it is clear if the value is zero or + // NOTE: this is an option so that it is clear if the value is zero or // nil. Since migration 30 of the channeldb initially did not include // this field, it could be the case that the field is not present for // all revocation logs. - OurBalance *lnwire.MilliSatoshi + OurBalance tlv.OptionalRecordT[tlv.TlvType3, BigSizeMilliSatoshi] // TheirBalance is the current available balance within the channel // directly spendable by the remote node. In other words, it is the // value of the to_local output on the remote parties' commitment. // - // NOTE: this is a pointer so that it is clear if the value is zero or + // NOTE: this is an option so that it is clear if the value is zero or // nil. Since migration 30 of the channeldb initially did not include // this field, it could be the case that the field is not present for // all revocation logs. - TheirBalance *lnwire.MilliSatoshi + TheirBalance tlv.OptionalRecordT[tlv.TlvType4, BigSizeMilliSatoshi] +} + +// NewRevocationLog creates a new RevocationLog from the given parameters. +func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, + commitHash [32]byte, ourBalance, + theirBalance fn.Option[lnwire.MilliSatoshi], + htlcs []*HTLCEntry) RevocationLog { + + rl := RevocationLog{ + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + ourOutputIndex, + ), + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + theirOutputIndex, + ), + CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2](commitHash), + HTLCEntries: htlcs, + } + + ourBalance.WhenSome(func(balance lnwire.MilliSatoshi) { + rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( + tlv.NewBigSizeT(balance), + )) + }) + + theirBalance.WhenSome(func(balance lnwire.MilliSatoshi) { + rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(balance), + )) + }) + + return rl } // putRevocationLog uses the fields `CommitTx` and `Htlcs` from a @@ -254,15 +286,26 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, } rl := &RevocationLog{ - OurOutputIndex: uint16(ourOutputIndex), - TheirOutputIndex: uint16(theirOutputIndex), - CommitTxHash: commit.CommitTx.TxHash(), - HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint16(ourOutputIndex), + ), + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint16(theirOutputIndex), + ), + CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2, [32]byte]( + commit.CommitTx.TxHash(), + ), + HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), } if !noAmtData { - rl.OurBalance = &commit.LocalBalance - rl.TheirBalance = &commit.RemoteBalance + rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( + tlv.NewBigSizeT(commit.LocalBalance), + )) + + rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(commit.RemoteBalance), + )) } for _, htlc := range commit.Htlcs { @@ -312,31 +355,23 @@ func fetchRevocationLog(log kvdb.RBucket, func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { // Add the tlv records for all non-optional fields. records := []tlv.Record{ - tlv.MakePrimitiveRecord( - revLogOurOutputIndexType, &rl.OurOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogTheirOutputIndexType, &rl.TheirOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogCommitTxHashType, &rl.CommitTxHash, - ), + rl.OurOutputIndex.Record(), + rl.TheirOutputIndex.Record(), + rl.CommitTxHash.Record(), } // Now we add any optional fields that are non-nil. - if rl.OurBalance != nil { - lb := uint64(*rl.OurBalance) - records = append(records, tlv.MakeBigSizeRecord( - revLogOurBalanceType, &lb, - )) - } + rl.OurBalance.WhenSome( + func(r tlv.RecordT[tlv.TlvType3, BigSizeMilliSatoshi]) { + records = append(records, r.Record()) + }, + ) - if rl.TheirBalance != nil { - rb := uint64(*rl.TheirBalance) - records = append(records, tlv.MakeBigSizeRecord( - revLogTheirBalanceType, &rb, - )) - } + rl.TheirBalance.WhenSome( + func(r tlv.RecordT[tlv.TlvType4, BigSizeMilliSatoshi]) { + records = append(records, r.Record()) + }, + ) // Create the tlv stream. tlvStream, err := tlv.NewStream(records...) @@ -374,27 +409,18 @@ func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { // deserializeRevocationLog deserializes a RevocationLog based on tlv format. func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { - var ( - rl RevocationLog - ourBalance uint64 - theirBalance uint64 - ) + var rl RevocationLog + + ourBalance := rl.OurBalance.Zero() + theirBalance := rl.TheirBalance.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( - tlv.MakePrimitiveRecord( - revLogOurOutputIndexType, &rl.OurOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogTheirOutputIndexType, &rl.TheirOutputIndex, - ), - tlv.MakePrimitiveRecord( - revLogCommitTxHashType, &rl.CommitTxHash, - ), - tlv.MakeBigSizeRecord(revLogOurBalanceType, &ourBalance), - tlv.MakeBigSizeRecord( - revLogTheirBalanceType, &theirBalance, - ), + rl.OurOutputIndex.Record(), + rl.TheirOutputIndex.Record(), + rl.CommitTxHash.Record(), + ourBalance.Record(), + theirBalance.Record(), ) if err != nil { return rl, err @@ -406,14 +432,12 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { return rl, err } - if t, ok := parsedTypes[revLogOurBalanceType]; ok && t == nil { - lb := lnwire.MilliSatoshi(ourBalance) - rl.OurBalance = &lb + if t, ok := parsedTypes[ourBalance.TlvType()]; ok && t == nil { + rl.OurBalance = tlv.SomeRecordT(ourBalance) } - if t, ok := parsedTypes[revLogTheirBalanceType]; ok && t == nil { - rb := lnwire.MilliSatoshi(theirBalance) - rl.TheirBalance = &rb + if t, ok := parsedTypes[theirBalance.TlvType()]; ok && t == nil { + rl.TheirBalance = tlv.SomeRecordT(theirBalance) } // Read the HTLC entries. diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index c9d2a16f708..813f70ac99a 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lnwire" @@ -115,12 +116,11 @@ var ( }}, } - testRevocationLogNoAmts = RevocationLog{ - OurOutputIndex: 0, - TheirOutputIndex: 1, - CommitTxHash: testChannelCommit.CommitTx.TxHash(), - HTLCEntries: []*HTLCEntry{&testHTLCEntry}, - } + testRevocationLogNoAmts = NewRevocationLog( + 0, 1, testChannelCommit.CommitTx.TxHash(), + fn.None[lnwire.MilliSatoshi](), fn.None[lnwire.MilliSatoshi](), + []*HTLCEntry{&testHTLCEntry}, + ) testRevocationLogNoAmtsBytes = []byte{ // Body length 42. 0x2a, @@ -136,14 +136,11 @@ var ( 0xc8, 0x22, 0x51, 0xb1, 0x5b, 0xa0, 0xbf, 0xd, } - testRevocationLogWithAmts = RevocationLog{ - OurOutputIndex: 0, - TheirOutputIndex: 1, - CommitTxHash: testChannelCommit.CommitTx.TxHash(), - HTLCEntries: []*HTLCEntry{&testHTLCEntry}, - OurBalance: &localBalance, - TheirBalance: &remoteBalance, - } + testRevocationLogWithAmts = NewRevocationLog( + 0, 1, testChannelCommit.CommitTx.TxHash(), + fn.Some(localBalance), fn.Some(remoteBalance), + []*HTLCEntry{&testHTLCEntry}, + ) testRevocationLogWithAmtsBytes = []byte{ // Body length 52. 0x34, diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 61c9fc6bf31..a2fe35c2bd6 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2747,7 +2747,7 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, htlcRetributions := make([]HtlcRetribution, len(revokedLog.HTLCEntries)) for i, htlc := range revokedLog.HTLCEntries { hr, err := createHtlcRetribution( - chanState, keyRing, commitHash, + chanState, keyRing, commitHash.Val, commitmentSecret, leaseExpiry, htlc, ) if err != nil { @@ -2760,10 +2760,10 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // Construct the our outpoint. ourOutpoint := wire.OutPoint{ - Hash: commitHash, + Hash: commitHash.Val, } - if revokedLog.OurOutputIndex != channeldb.OutputIndexEmpty { - ourOutpoint.Index = uint32(revokedLog.OurOutputIndex) + if revokedLog.OurOutputIndex.Val != channeldb.OutputIndexEmpty { + ourOutpoint.Index = uint32(revokedLog.OurOutputIndex.Val) // If the spend transaction is provided, then we use it to get // the value of our output. @@ -2786,26 +2786,29 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // contains our output amount. Due to a previous // migration, this field may be empty in which case an // error will be returned. - if revokedLog.OurBalance == nil { - return nil, 0, 0, ErrRevLogDataMissing + b, err := revokedLog.OurBalance.ValOpt().UnwrapOrErr( + ErrRevLogDataMissing, + ) + if err != nil { + return nil, 0, 0, err } - ourAmt = int64(revokedLog.OurBalance.ToSatoshis()) + ourAmt = int64(b.Int().ToSatoshis()) } } // Construct the their outpoint. theirOutpoint := wire.OutPoint{ - Hash: commitHash, + Hash: commitHash.Val, } - if revokedLog.TheirOutputIndex != channeldb.OutputIndexEmpty { - theirOutpoint.Index = uint32(revokedLog.TheirOutputIndex) + if revokedLog.TheirOutputIndex.Val != channeldb.OutputIndexEmpty { + theirOutpoint.Index = uint32(revokedLog.TheirOutputIndex.Val) // If the spend transaction is provided, then we use it to get // the value of the remote parties' output. if spendTx != nil { // Sanity check that TheirOutputIndex is within range. - if int(revokedLog.TheirOutputIndex) >= + if int(revokedLog.TheirOutputIndex.Val) >= len(spendTx.TxOut) { return nil, 0, 0, fmt.Errorf("%w: theirs=%v, "+ @@ -2823,16 +2826,19 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, // contains remote parties' output amount. Due to a // previous migration, this field may be empty in which // case an error will be returned. - if revokedLog.TheirBalance == nil { - return nil, 0, 0, ErrRevLogDataMissing + b, err := revokedLog.TheirBalance.ValOpt().UnwrapOrErr( + ErrRevLogDataMissing, + ) + if err != nil { + return nil, 0, 0, err } - theirAmt = int64(revokedLog.TheirBalance.ToSatoshis()) + theirAmt = int64(b.Int().ToSatoshis()) } } return &BreachRetribution{ - BreachTxHash: commitHash, + BreachTxHash: commitHash.Val, ChainHash: chanState.ChainHash, LocalOutpoint: ourOutpoint, RemoteOutpoint: theirOutpoint, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 84c694c6116..bd6b048481f 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -21,6 +21,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -10021,22 +10022,19 @@ func TestCreateBreachRetribution(t *testing.T) { // Create a dummy revocation log. ourAmtMsat := lnwire.MilliSatoshi(ourAmt * 1000) theirAmtMsat := lnwire.MilliSatoshi(theirAmt * 1000) - revokedLog := channeldb.RevocationLog{ - CommitTxHash: commitHash, - OurOutputIndex: uint16(localIndex), - TheirOutputIndex: uint16(remoteIndex), - HTLCEntries: []*channeldb.HTLCEntry{htlc}, - TheirBalance: &theirAmtMsat, - OurBalance: &ourAmtMsat, - } + revokedLog := channeldb.NewRevocationLog( + uint16(localIndex), uint16(remoteIndex), commitHash, + fn.Some(ourAmtMsat), fn.Some(theirAmtMsat), + []*channeldb.HTLCEntry{htlc}, + ) // Create a log with an empty local output index. revokedLogNoLocal := revokedLog - revokedLogNoLocal.OurOutputIndex = channeldb.OutputIndexEmpty + revokedLogNoLocal.OurOutputIndex.Val = channeldb.OutputIndexEmpty // Create a log with an empty remote output index. revokedLogNoRemote := revokedLog - revokedLogNoRemote.TheirOutputIndex = channeldb.OutputIndexEmpty + revokedLogNoRemote.TheirOutputIndex.Val = channeldb.OutputIndexEmpty testCases := []struct { name string @@ -10066,14 +10064,20 @@ func TestCreateBreachRetribution(t *testing.T) { { name: "fail due to our index too big", revocationLog: &channeldb.RevocationLog{ - OurOutputIndex: uint16(htlcIndex + 1), + //nolint:lll + OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint16(htlcIndex + 1), + ), }, expectedErr: ErrOutputIndexOutOfRange, }, { name: "fail due to their index too big", revocationLog: &channeldb.RevocationLog{ - TheirOutputIndex: uint16(htlcIndex + 1), + //nolint:lll + TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint16(htlcIndex + 1), + ), }, expectedErr: ErrOutputIndexOutOfRange, }, From d5f595e64166329cc15e3e13ccbcfbbec98bbe8f Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 17:00:55 -0700 Subject: [PATCH 10/14] channeldb: add custom blobs to RevocationLog+HTLCEntry This'll be useful for custom channel types that want to store extra information that'll be useful to help handle channel revocation cases. --- channeldb/revocation_log.go | 78 +++++++++++++++++++++++++++++--- channeldb/revocation_log_test.go | 35 ++++++++++---- lnwallet/channel_test.go | 2 +- 3 files changed, 97 insertions(+), 18 deletions(-) diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index 3fc4163c112..0b51400c440 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -164,22 +164,32 @@ type HTLCEntry struct { // // NOTE: this field is the memory representation of the field amtUint. Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] + + // CustomBlob is an optional blob that can be used to store information + // specific to revocation handling for a custom channel type. + CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] } // toTlvStream converts an HTLCEntry record into a tlv representation. func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { - return tlv.NewStream( + records := []tlv.Record{ h.RHash.Record(), h.RefundTimeout.Record(), h.OutputIndex.Record(), h.Incoming.Record(), h.Amt.Record(), - ) + } + + h.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { + records = append(records, r.Record()) + }) + + return tlv.NewStream(records...) } // NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC. func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { - return &HTLCEntry{ + h := &HTLCEntry{ RHash: tlv.NewRecordT[tlv.TlvType0]( NewSparsePayHash(htlc.RHash), ), @@ -194,6 +204,16 @@ func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), ), } + + if len(htlc.ExtraData) != 0 { + h.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob]( + htlc.ExtraData, + ), + ) + } + + return h } // RevocationLog stores the info needed to construct a breach retribution. Its @@ -236,13 +256,19 @@ type RevocationLog struct { // this field, it could be the case that the field is not present for // all revocation logs. TheirBalance tlv.OptionalRecordT[tlv.TlvType4, BigSizeMilliSatoshi] + + // CustomBlob is an optional blob that can be used to store information + // specific to a custom channel type. This information is only created + // at channel funding time, and after wards is to be considered + // immutable. + CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] } // NewRevocationLog creates a new RevocationLog from the given parameters. func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, commitHash [32]byte, ourBalance, - theirBalance fn.Option[lnwire.MilliSatoshi], - htlcs []*HTLCEntry) RevocationLog { + theirBalance fn.Option[lnwire.MilliSatoshi], htlcs []*HTLCEntry, + customBlob fn.Option[tlv.Blob]) RevocationLog { rl := RevocationLog{ OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0]( @@ -267,6 +293,12 @@ func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16, )) }) + customBlob.WhenSome(func(blob tlv.Blob) { + rl.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + }) + return rl } @@ -298,6 +330,12 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)), } + commit.CustomBlob.WhenSome(func(blob tlv.Blob) { + rl.CustomBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob), + ) + }) + if !noAmtData { rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3]( tlv.NewBigSizeT(commit.LocalBalance), @@ -373,6 +411,10 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { }, ) + rl.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { + records = append(records, r.Record()) + }) + // Create the tlv stream. tlvStream, err := tlv.NewStream(records...) if err != nil { @@ -413,6 +455,7 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { ourBalance := rl.OurBalance.Zero() theirBalance := rl.TheirBalance.Zero() + customBlob := rl.CustomBlob.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( @@ -421,6 +464,7 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { rl.CommitTxHash.Record(), ourBalance.Record(), theirBalance.Record(), + customBlob.Record(), ) if err != nil { return rl, err @@ -440,6 +484,10 @@ func deserializeRevocationLog(r io.Reader) (RevocationLog, error) { rl.TheirBalance = tlv.SomeRecordT(theirBalance) } + if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { + rl.CustomBlob = tlv.SomeRecordT(customBlob) + } + // Read the HTLC entries. rl.HTLCEntries, err = deserializeHTLCEntries(r) @@ -454,14 +502,26 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { for { var htlc HTLCEntry + customBlob := htlc.CustomBlob.Zero() + // Create the tlv stream. - tlvStream, err := htlc.toTlvStream() + records := []tlv.Record{ + htlc.RHash.Record(), + htlc.RefundTimeout.Record(), + htlc.OutputIndex.Record(), + htlc.Incoming.Record(), + htlc.Amt.Record(), + customBlob.Record(), + } + + tlvStream, err := tlv.NewStream(records...) if err != nil { return nil, err } // Read the HTLC entry. - if _, err := readTlvStream(r, tlvStream); err != nil { + parsedTypes, err := readTlvStream(r, tlvStream) + if err != nil { // We've reached the end when hitting an EOF. if err == io.ErrUnexpectedEOF { break @@ -469,6 +529,10 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { return nil, err } + if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil { + htlc.CustomBlob = tlv.SomeRecordT(customBlob) + } + // Append the entry. htlcs = append(htlcs, &htlc) } diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index 813f70ac99a..76b5d8b7804 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -34,6 +34,10 @@ var ( 0xff, // value = 255 } + blobBytes = tlv.Blob{ + 0x01, 0x02, 0x03, 0x04, + } + testHTLCEntry = HTLCEntry{ RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( 740_000, @@ -45,10 +49,13 @@ var ( Amt: tlv.NewRecordT[tlv.TlvType4]( tlv.NewBigSizeT(btcutil.Amount(1_000_000)), ), + CustomBlob: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5](blobBytes), + ), } testHTLCEntryBytes = []byte{ - // Body length 23. - 0x16, + // Body length 28. + 0x1c, // Rhash tlv. 0x0, 0x0, // RefundTimeout tlv. @@ -59,6 +66,8 @@ var ( 0x3, 0x1, 0x1, // Amt tlv. 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, + // Custom blob tlv. + 0x5, 0x4, 0x1, 0x2, 0x3, 0x4, } testHTLCEntryHash = HTLCEntry{ @@ -113,17 +122,19 @@ var ( Amt: lnwire.NewMSatFromSatoshis( testHTLCEntry.Amt.Val.Int(), ), + ExtraData: blobBytes, }}, + CustomBlob: fn.Some(blobBytes), } testRevocationLogNoAmts = NewRevocationLog( 0, 1, testChannelCommit.CommitTx.TxHash(), fn.None[lnwire.MilliSatoshi](), fn.None[lnwire.MilliSatoshi](), - []*HTLCEntry{&testHTLCEntry}, + []*HTLCEntry{&testHTLCEntry}, fn.Some(blobBytes), ) testRevocationLogNoAmtsBytes = []byte{ - // Body length 42. - 0x2a, + // Body length 48. + 0x30, // OurOutputIndex tlv. 0x0, 0x2, 0x0, 0x0, // TheirOutputIndex tlv. @@ -134,16 +145,18 @@ var ( 0x6e, 0x60, 0x29, 0x23, 0x1d, 0x5e, 0xc5, 0xe6, 0xbd, 0xf7, 0xd3, 0x9b, 0x16, 0x7d, 0x0, 0xff, 0xc8, 0x22, 0x51, 0xb1, 0x5b, 0xa0, 0xbf, 0xd, + // Custom blob tlv. + 0x5, 0x4, 0x1, 0x2, 0x3, 0x4, } testRevocationLogWithAmts = NewRevocationLog( 0, 1, testChannelCommit.CommitTx.TxHash(), fn.Some(localBalance), fn.Some(remoteBalance), - []*HTLCEntry{&testHTLCEntry}, + []*HTLCEntry{&testHTLCEntry}, fn.Some(blobBytes), ) testRevocationLogWithAmtsBytes = []byte{ - // Body length 52. - 0x34, + // Body length 58. + 0x3a, // OurOutputIndex tlv. 0x0, 0x2, 0x0, 0x0, // TheirOutputIndex tlv. @@ -158,6 +171,8 @@ var ( 0x3, 0x3, 0xfd, 0x23, 0x28, // Remote Balance. 0x4, 0x3, 0xfd, 0x0b, 0xb8, + // Custom blob tlv. + 0x5, 0x4, 0x1, 0x2, 0x3, 0x4, } ) @@ -269,7 +284,7 @@ func TestSerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - expectedBytes := []byte{0x36, 0x0, 0x20} + expectedBytes := []byte{0x3c, 0x0, 0x20} expectedBytes = append(expectedBytes, rHashBytes...) // Append the rest. @@ -384,7 +399,7 @@ func TestDeserializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - testBytes := append([]byte{0x36, 0x0, 0x20}, rHashBytes...) + testBytes := append([]byte{0x3c, 0x0, 0x20}, rHashBytes...) // Append the rest. testBytes = append(testBytes, partialBytes...) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index bd6b048481f..33f8f636c67 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -10025,7 +10025,7 @@ func TestCreateBreachRetribution(t *testing.T) { revokedLog := channeldb.NewRevocationLog( uint16(localIndex), uint16(remoteIndex), commitHash, fn.Some(ourAmtMsat), fn.Some(theirAmtMsat), - []*channeldb.HTLCEntry{htlc}, + []*channeldb.HTLCEntry{htlc}, fn.None[tlv.Blob](), ) // Create a log with an empty local output index. From 2cf38540e00e266e5b5122d7ba02df434b051f73 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 17:44:14 -0700 Subject: [PATCH 11/14] channeldb: add HtlcIndex to HTLCEntry This may be useful for custom channel types that base everything off the index (a global value) rather than the output index (can change with each state). --- channeldb/revocation_log.go | 15 ++++++++++----- channeldb/revocation_log_test.go | 18 ++++++++++++------ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index 0b51400c440..4334ec64dd9 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -155,19 +155,17 @@ type HTLCEntry struct { // Incoming denotes whether we're the receiver or the sender of this // HTLC. - // - // NOTE: this field is the memory representation of the field - // incomingUint. Incoming tlv.RecordT[tlv.TlvType3, bool] // Amt is the amount of satoshis this HTLC escrows. - // - // NOTE: this field is the memory representation of the field amtUint. Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] // CustomBlob is an optional blob that can be used to store information // specific to revocation handling for a custom channel type. CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] + + // HtlcIndex is the index of the HTLC in the channel. + HtlcIndex tlv.RecordT[tlv.TlvType6, uint16] } // toTlvStream converts an HTLCEntry record into a tlv representation. @@ -178,12 +176,15 @@ func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { h.OutputIndex.Record(), h.Incoming.Record(), h.Amt.Record(), + h.HtlcIndex.Record(), } h.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { records = append(records, r.Record()) }) + tlv.SortRecords(records) + return tlv.NewStream(records...) } @@ -203,6 +204,9 @@ func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { Amt: tlv.NewRecordT[tlv.TlvType4]( tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), ), + HtlcIndex: tlv.NewPrimitiveRecord[tlv.TlvType6]( + uint16(htlc.HtlcIndex), + ), } if len(htlc.ExtraData) != 0 { @@ -512,6 +516,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { htlc.Incoming.Record(), htlc.Amt.Record(), customBlob.Record(), + htlc.HtlcIndex.Record(), } tlvStream, err := tlv.NewStream(records...) diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index 76b5d8b7804..d21c41fbf93 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -52,10 +52,11 @@ var ( CustomBlob: tlv.SomeRecordT( tlv.NewPrimitiveRecord[tlv.TlvType5](blobBytes), ), + HtlcIndex: tlv.NewPrimitiveRecord[tlv.TlvType6, uint16](3), } testHTLCEntryBytes = []byte{ - // Body length 28. - 0x1c, + // Body length 32. + 0x20, // Rhash tlv. 0x0, 0x0, // RefundTimeout tlv. @@ -68,6 +69,8 @@ var ( 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, // Custom blob tlv. 0x5, 0x4, 0x1, 0x2, 0x3, 0x4, + // HLTC index tlv. + 0x6, 0x2, 0x0, 0x03, } testHTLCEntryHash = HTLCEntry{ @@ -86,8 +89,8 @@ var ( ), } testHTLCEntryHashBytes = []byte{ - // Body length 54. - 0x36, + // Body length 58. + 0x3a, // Rhash tlv. 0x0, 0x20, 0x33, 0x44, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -102,6 +105,8 @@ var ( 0x3, 0x1, 0x1, // Amt tlv. 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, + // HLTC index tlv. + 0x6, 0x2, 0x0, 0x00, } localBalance = lnwire.MilliSatoshi(9000) @@ -118,6 +123,7 @@ var ( Htlcs: []HTLC{{ RefundTimeout: testHTLCEntry.RefundTimeout.Val, OutputIndex: int32(testHTLCEntry.OutputIndex.Val), + HtlcIndex: uint64(testHTLCEntry.HtlcIndex.Val), Incoming: testHTLCEntry.Incoming.Val, Amt: lnwire.NewMSatFromSatoshis( testHTLCEntry.Amt.Val.Int(), @@ -284,7 +290,7 @@ func TestSerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - expectedBytes := []byte{0x3c, 0x0, 0x20} + expectedBytes := []byte{0x40, 0x0, 0x20} expectedBytes = append(expectedBytes, rHashBytes...) // Append the rest. @@ -399,7 +405,7 @@ func TestDeserializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - testBytes := append([]byte{0x3c, 0x0, 0x20}, rHashBytes...) + testBytes := append([]byte{0x40, 0x0, 0x20}, rHashBytes...) // Append the rest. testBytes = append(testBytes, partialBytes...) From 2ac360a1b6b504166ebc1d938a7663785dd09ab9 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 1 Apr 2024 19:56:50 -0700 Subject: [PATCH 12/14] lnwallet: add TLV blob to PaymentDescriptor + htlc add In this commit, we add a TLV blob to the PaymentDescriptor struct. We also now thread through this value from the UpdateAddHTLC message to the PaymentDescriptor mapping, and the other way around. --- lnwallet/channel.go | 45 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index a2fe35c2bd6..b38fadcbec5 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -335,6 +335,10 @@ type PaymentDescriptor struct { // NOTE: Populated only on add payment descriptor entry types. OnionBlob []byte + // CustomRecrods also stores the set of optional custom records that + // may have been attached to a sent HTLC. + CustomRecords fn.Option[tlv.Blob] + // ShaOnionBlob is a sha of the onion blob. // // NOTE: Populated only in payment descriptor with MalformedFail type. @@ -752,6 +756,12 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { } copy(h.OnionBlob[:], htlc.OnionBlob) + // If the HTLC had custom records, then we'll copy that over so + // we restore with the same information. + htlc.CustomRecords.WhenSome(func(b tlv.Blob) { + copy(h.ExtraData, b) + }) + if ourCommit && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -776,6 +786,13 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { BlindingPoint: htlc.BlindingPoint, } copy(h.OnionBlob[:], htlc.OnionBlob) + + // If the HTLC had custom records, then we'll copy that over so + // we restore with the same information. + htlc.CustomRecords.WhenSome(func(b tlv.Blob) { + copy(h.ExtraData, b) + }) + if ourCommit && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -858,7 +875,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // With the scripts reconstructed (depending on if this is our commit // vs theirs or a pending commit for the remote party), we can now // re-create the original payment descriptor. - return PaymentDescriptor{ + pd = PaymentDescriptor{ RHash: htlc.RHash, Timeout: htlc.RefundTimeout, Amount: htlc.Amt, @@ -873,7 +890,15 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, theirPkScript: theirP2WSH, theirWitnessScript: theirWitnessScript, BlindingPoint: htlc.BlindingPoint, - }, nil + } + + // Ensure that we'll restore any custom records which were stored as + // extra data on disk. + if len(htlc.ExtraData) != 0 { + pd.CustomRecords = fn.Some[tlv.Blob](htlc.ExtraData) + } + + return pd, nil } // extractPayDescs will convert all HTLC's present within a disk commit state @@ -6159,7 +6184,7 @@ func (lc *LightningChannel) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error { func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, openKey *models.CircuitKey) *PaymentDescriptor { - return &PaymentDescriptor{ + pd := &PaymentDescriptor{ EntryType: Add, RHash: PaymentHash(htlc.PaymentHash), Timeout: htlc.Expiry, @@ -6170,6 +6195,14 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, OpenCircuitKey: openKey, BlindingPoint: htlc.BlindingPoint, } + + // Copy over any extra data included to ensure we can forward and + // process this HTLC properly. + if len(htlc.ExtraData) != 0 { + pd.CustomRecords = fn.Some[tlv.Blob](htlc.ExtraData[:]) + } + + return pd } // validateAddHtlc validates the addition of an outgoing htlc to our local and @@ -6229,6 +6262,12 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err BlindingPoint: htlc.BlindingPoint, } + // Copy over any extra data included to ensure we can forward and + // process this HTLC properly. + if htlc.ExtraData != nil { + pd.CustomRecords = fn.Some(tlv.Blob(htlc.ExtraData[:])) + } + localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex // Clamp down on the number of HTLC's we can receive by checking the From 33f2db1c25120d09c70fc5e49486289f1d482643 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Thu, 25 Apr 2024 19:00:42 +0200 Subject: [PATCH 13/14] multi: thread thru the AuxLeafStore everywhere --- chainreg/chainregistry.go | 5 ++ config_builder.go | 28 +++++++++-- contractcourt/breach_arbitrator_test.go | 3 ++ contractcourt/chain_arbitrator.go | 18 ++++++- contractcourt/chain_watcher.go | 11 +++-- funding/manager.go | 11 ++++- lnd.go | 5 +- lnwallet/channel.go | 63 +++++++++++++++---------- lnwallet/channel_test.go | 10 ++++ lnwallet/config.go | 5 ++ lnwallet/wallet.go | 9 +++- peer/brontide.go | 14 +++++- server.go | 10 +++- 13 files changed, 150 insertions(+), 42 deletions(-) diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index 036da71a5b5..934a84996b8 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs/neutrinonotify" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -63,6 +64,10 @@ type Config struct { // state. ChanStateDB *channeldb.ChannelStateDB + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + // BlockCache is the main cache for storing block information. BlockCache *blockcache.BlockCache diff --git a/config_builder.go b/config_builder.go index bbf291506b5..7af32738243 100644 --- a/config_builder.go +++ b/config_builder.go @@ -34,6 +34,7 @@ import ( "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -104,7 +105,7 @@ type DatabaseBuilder interface { type WalletConfigBuilder interface { // BuildWalletConfig is responsible for creating or unlocking and then // fully initializing a wallet. - BuildWalletConfig(context.Context, *DatabaseInstances, + BuildWalletConfig(context.Context, *DatabaseInstances, *AuxComponents, *rpcperms.InterceptorChain, []*ListenerWithSignal) (*chainreg.PartialChainControl, *btcwallet.Config, func(), error) @@ -145,6 +146,17 @@ type ImplementationCfg struct { // ChainControlBuilder is a type that can provide a custom wallet // implementation. ChainControlBuilder + // AuxComponents is a set of auxiliary components that can be used by + // lnd for certain custom channel types. + AuxComponents +} + +// AuxComponents is a set of auxiliary components that can be used by lnd for +// certain custom channel types. +type AuxComponents struct { + // AuxLeafStore is an optional data source that can be used by custom + // channels to fetch+store various data. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // DefaultWalletImpl is the default implementation of our normal, btcwallet @@ -229,7 +241,8 @@ func (d *DefaultWalletImpl) Permissions() map[string][]bakery.Op { // // NOTE: This is part of the WalletConfigBuilder interface. func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, - dbs *DatabaseInstances, interceptorChain *rpcperms.InterceptorChain, + dbs *DatabaseInstances, aux *AuxComponents, + interceptorChain *rpcperms.InterceptorChain, grpcListeners []*ListenerWithSignal) (*chainreg.PartialChainControl, *btcwallet.Config, func(), error) { @@ -549,6 +562,7 @@ func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, HeightHintDB: dbs.HeightHintDB, ChanStateDB: dbs.ChanStateDB.ChannelStateDB(), NeutrinoCS: neutrinoCS, + AuxLeafStore: aux.AuxLeafStore, ActiveNetParams: d.cfg.ActiveNetParams, FeeURL: d.cfg.FeeURL, Dialer: func(addr string) (net.Conn, error) { @@ -607,8 +621,9 @@ func (d *DefaultWalletImpl) BuildWalletConfig(ctx context.Context, // proxyBlockEpoch proxies a block epoch subsections to the underlying neutrino // rebroadcaster client. -func proxyBlockEpoch(notifier chainntnfs.ChainNotifier, -) func() (*blockntfns.Subscription, error) { +func proxyBlockEpoch( + notifier chainntnfs.ChainNotifier) func() (*blockntfns.Subscription, + error) { return func() (*blockntfns.Subscription, error) { blockEpoch, err := notifier.RegisterBlockEpochNtfn( @@ -699,6 +714,7 @@ func (d *DefaultWalletImpl) BuildChainControl( ChainIO: walletController, NetParams: *walletConfig.NetParams, CoinSelectionStrategy: walletConfig.CoinSelectionStrategy, + AuxLeafStore: partialChainControl.Cfg.AuxLeafStore, } // The broadcast is already always active for neutrino nodes, so we @@ -878,6 +894,10 @@ type DatabaseInstances struct { // for native SQL queries for tables that already support it. This may // be nil if the use-native-sql flag was not set. NativeSQLStore *sqldb.BaseDB + + // AuxLeafStore is an optional data source that can be used by custom + // channels to fetch+store various data. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // DefaultDatabaseBuilder is a type that builds the default database backends diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index 2fe4644db96..babb427ea29 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -22,6 +22,7 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntest/channels" @@ -1590,6 +1591,7 @@ func testBreachSpends(t *testing.T, test breachTest) { // Notify the breach arbiter about the breach. retribution, err := lnwallet.NewBreachRetribution( alice.State(), height, 1, forceCloseTx, + fn.None[lnwallet.AuxLeafStore](), ) require.NoError(t, err, "unable to create breach retribution") @@ -1799,6 +1801,7 @@ func TestBreachDelayedJusticeConfirmation(t *testing.T) { // Notify the breach arbiter about the breach. retribution, err := lnwallet.NewBreachRetribution( alice.State(), height, uint32(blockHeight), forceCloseTx, + fn.None[lnwallet.AuxLeafStore](), ) require.NoError(t, err, "unable to create breach retribution") diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index fbddd81f0ee..3245162cb59 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -217,6 +217,10 @@ type ChainArbitratorConfig struct { // meanwhile, turn `PaymentCircuit` into an interface or bring it to a // lower package. QueryIncomingCircuit func(circuit models.CircuitKey) *models.CircuitKey + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // ChainArbitrator is a sub-system that oversees the on-chain resolution of all @@ -299,8 +303,13 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions, return nil, err } + var chanOpts []lnwallet.ChannelOpt + a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + chanMachine, err := lnwallet.NewLightningChannel( - a.c.cfg.Signer, channel, nil, + a.c.cfg.Signer, channel, nil, chanOpts..., ) if err != nil { return nil, err @@ -344,10 +353,15 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) return nil, err } + var chanOpts []lnwallet.ChannelOpt + a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // Finally, we'll force close the channel completing // the force close workflow. chanMachine, err := lnwallet.NewLightningChannel( - a.c.cfg.Signer, channel, nil, + a.c.cfg.Signer, channel, nil, chanOpts..., ) if err != nil { return nil, err diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 28022650cc8..aa14ae05106 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -188,6 +188,9 @@ type chainWatcherConfig struct { // obfuscater. This is used by the chain watcher to identify which // state was broadcast and confirmed on-chain. extractStateNumHint func(*wire.MsgTx, [lnwallet.StateHintSize]byte) uint64 + + // auxLeafStore can be used to fetch information for custom channels. + auxLeafStore fn.Option[lnwallet.AuxLeafStore] } // chainWatcher is a system that's assigned to every active channel. The duty @@ -862,7 +865,7 @@ func (c *chainWatcher) handlePossibleBreach(commitSpend *chainntnfs.SpendDetail, spendHeight := uint32(commitSpend.SpendingHeight) retribution, err := lnwallet.NewBreachRetribution( c.cfg.chanState, broadcastStateNum, spendHeight, - commitSpend.SpendingTx, + commitSpend.SpendingTx, c.cfg.auxLeafStore, ) switch { @@ -1073,8 +1076,8 @@ func (c *chainWatcher) dispatchLocalForceClose( "detected", c.cfg.chanState.FundingOutpoint) forceClose, err := lnwallet.NewLocalForceCloseSummary( - c.cfg.chanState, c.cfg.signer, - commitSpend.SpendingTx, stateNum, + c.cfg.chanState, c.cfg.signer, commitSpend.SpendingTx, stateNum, + c.cfg.auxLeafStore, ) if err != nil { return err @@ -1167,7 +1170,7 @@ func (c *chainWatcher) dispatchRemoteForceClose( // channel on-chain. uniClose, err := lnwallet.NewUnilateralCloseSummary( c.cfg.chanState, c.cfg.signer, commitSpend, - remoteCommit, commitPoint, + remoteCommit, commitPoint, c.cfg.auxLeafStore, ) if err != nil { return err diff --git a/funding/manager.go b/funding/manager.go index 227bb0e6632..4a9c0122d61 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -538,6 +538,10 @@ type Config struct { // AliasManager is an implementation of the aliasHandler interface that // abstracts away the handling of many alias functions. AliasManager aliasHandler + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] } // Manager acts as an orchestrator/bridge between the wallet's @@ -1056,9 +1060,14 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, } } + var chanOpts []lnwallet.ChannelOpt + f.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // We create the state-machine object which wraps the database state. lnChannel, err := lnwallet.NewLightningChannel( - nil, channel, nil, + nil, channel, nil, chanOpts..., ) if err != nil { log.Errorf("Unable to create LightningChannel(%v): %v", diff --git a/lnd.go b/lnd.go index f4ef03a078b..9f628ff8125 100644 --- a/lnd.go +++ b/lnd.go @@ -437,7 +437,8 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, defer cleanUp() partialChainControl, walletConfig, cleanUp, err := implCfg.BuildWalletConfig( - ctx, dbs, interceptorChain, grpcListeners, + ctx, dbs, &implCfg.AuxComponents, interceptorChain, + grpcListeners, ) if err != nil { return mkErr("error creating wallet config: %v", err) @@ -580,7 +581,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, server, err := newServer( cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, activeChainControl.Cfg.WalletUnlockParams.ChansToRestore, - multiAcceptor, torController, tlsManager, + multiAcceptor, torController, tlsManager, implCfg, ) if err != nil { return mkErr("unable to create server: %v", err) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index b38fadcbec5..b01c2cde304 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2462,7 +2462,8 @@ type BreachRetribution struct { // required to construct the BreachRetribution. If the revocation log is missing // the required fields then ErrRevLogDataMissing will be returned. func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, - breachHeight uint32, spendTx *wire.MsgTx) (*BreachRetribution, error) { + breachHeight uint32, spendTx *wire.MsgTx, + leafStore fn.Option[AuxLeafStore]) (*BreachRetribution, error) { // Query the on-disk revocation log for the snapshot which was recorded // at this particular state num. Based on whether a legacy revocation @@ -3493,9 +3494,16 @@ func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, // signature can be submitted to the sigPool to generate all the signatures // asynchronously and in parallel. func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, - chanType channeldb.ChannelType, isRemoteInitiator bool, - leaseExpiry uint32, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, - remoteCommitView *commitment) ([]SignJob, chan struct{}, error) { + chanState *channeldb.OpenChannel, leaseExpiry uint32, + remoteCommitView *commitment, + leafStore fn.Option[AuxLeafStore]) ([]SignJob, chan struct{}, error) { + + var ( + isRemoteInitiator = !chanState.IsInitiator + localChanCfg = chanState.LocalChanCfg + remoteChanCfg = chanState.RemoteChanCfg + chanType = chanState.ChanType + ) txHash := remoteCommitView.txn.TxHash() dustLimit := remoteChanCfg.DustLimit @@ -3661,9 +3669,9 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // validate this new state. This function is called right before sending the // new commitment to the remote party. The commit diff returned contains all // information necessary for retransmission. -func (lc *LightningChannel) createCommitDiff( - newCommit *commitment, commitSig lnwire.Sig, - htlcSigs []lnwire.Sig) (*channeldb.CommitDiff, error) { +func (lc *LightningChannel) createCommitDiff(newCommit *commitment, + commitSig lnwire.Sig, htlcSigs []lnwire.Sig) (*channeldb.CommitDiff, + error) { // First, we need to convert the funding outpoint into the ID that's // used on the wire to identify this channel. We'll use this shortly @@ -4362,9 +4370,8 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { leaseExpiry = lc.channelState.ThawHeight } sigBatch, cancelChan, err := genRemoteHtlcSigJobs( - keyRing, lc.channelState.ChanType, !lc.channelState.IsInitiator, - leaseExpiry, &lc.channelState.LocalChanCfg, - &lc.channelState.RemoteChanCfg, newCommitView, + keyRing, lc.channelState, leaseExpiry, newCommitView, + lc.leafStore, ) if err != nil { return nil, err @@ -4927,10 +4934,18 @@ func (lc *LightningChannel) computeView(view *HtlcView, remoteChain bool, // meant to verify all the signatures for HTLC's attached to a newly created // commitment state. The jobs generated are fully populated, and can be sent // directly into the pool of workers. -func genHtlcSigValidationJobs(localCommitmentView *commitment, - keyRing *CommitmentKeyRing, htlcSigs []lnwire.Sig, - chanType channeldb.ChannelType, isLocalInitiator bool, leaseExpiry uint32, - localChanCfg, remoteChanCfg *channeldb.ChannelConfig) ([]VerifyJob, error) { +// +//nolint:funlen +func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, + localCommitmentView *commitment, keyRing *CommitmentKeyRing, + htlcSigs []lnwire.Sig, leaseExpiry uint32, + leafStore fn.Option[AuxLeafStore]) ([]VerifyJob, error) { + + var ( + isLocalInitiator = chanState.IsInitiator + localChanCfg = chanState.LocalChanCfg + chanType = chanState.ChanType + ) txHash := localCommitmentView.txn.TxHash() feePerKw := localCommitmentView.feePerKw @@ -5323,10 +5338,8 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { leaseExpiry = lc.channelState.ThawHeight } verifyJobs, err := genHtlcSigValidationJobs( - localCommitmentView, keyRing, commitSigs.HtlcSigs, - lc.channelState.ChanType, lc.channelState.IsInitiator, - leaseExpiry, &lc.channelState.LocalChanCfg, - &lc.channelState.RemoteChanCfg, + lc.channelState, localCommitmentView, keyRing, + commitSigs.HtlcSigs, leaseExpiry, lc.leafStore, ) if err != nil { return err @@ -6773,10 +6786,10 @@ type UnilateralCloseSummary struct { // happen in case we have lost state) it should be set to an empty struct, in // which case we will attempt to sweep the non-HTLC output using the passed // commitPoint. -func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Signer, - commitSpend *chainntnfs.SpendDetail, - remoteCommit channeldb.ChannelCommitment, - commitPoint *btcec.PublicKey) (*UnilateralCloseSummary, error) { +func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, + signer input.Signer, commitSpend *chainntnfs.SpendDetail, + remoteCommit channeldb.ChannelCommitment, commitPoint *btcec.PublicKey, + leafStore fn.Option[AuxLeafStore]) (*UnilateralCloseSummary, error) { // First, we'll generate the commitment point and the revocation point // so we can re-construct the HTLC state and also our payment key. @@ -7717,7 +7730,7 @@ func (lc *LightningChannel) ForceClose() (*LocalForceCloseSummary, error) { localCommitment := lc.channelState.LocalCommitment summary, err := NewLocalForceCloseSummary( lc.channelState, lc.Signer, commitTx, - localCommitment.CommitHeight, + localCommitment.CommitHeight, lc.leafStore, ) if err != nil { return nil, fmt.Errorf("unable to gen force close "+ @@ -7734,8 +7747,8 @@ func (lc *LightningChannel) ForceClose() (*LocalForceCloseSummary, error) { // channel state. The passed commitTx must be a fully signed commitment // transaction corresponding to localCommit. func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, - signer input.Signer, commitTx *wire.MsgTx, stateNum uint64) ( - *LocalForceCloseSummary, error) { + signer input.Signer, commitTx *wire.MsgTx, stateNum uint64, + leafStore fn.Option[AuxLeafStore]) (*LocalForceCloseSummary, error) { // Re-derive the original pkScript for to-self output within the // commitment transaction. We'll need this to find the corresponding diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 33f8f636c67..24caef62d63 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -5695,6 +5695,7 @@ func TestChannelUnilateralCloseHtlcResolution(t *testing.T) { spendDetail, aliceChannel.channelState.RemoteCommitment, aliceChannel.channelState.RemoteCurrentRevocation, + fn.None[AuxLeafStore](), ) require.NoError(t, err, "unable to create alice close summary") @@ -5844,6 +5845,7 @@ func TestChannelUnilateralClosePendingCommit(t *testing.T) { spendDetail, aliceChannel.channelState.RemoteCommitment, aliceChannel.channelState.RemoteCurrentRevocation, + fn.None[AuxLeafStore](), ) require.NoError(t, err, "unable to create alice close summary") @@ -5861,6 +5863,7 @@ func TestChannelUnilateralClosePendingCommit(t *testing.T) { spendDetail, aliceRemoteChainTip.Commitment, aliceChannel.channelState.RemoteNextRevocation, + fn.None[AuxLeafStore](), ) require.NoError(t, err, "unable to create alice close summary") @@ -6741,6 +6744,7 @@ func TestNewBreachRetributionSkipsDustHtlcs(t *testing.T) { breachTx := aliceChannel.channelState.RemoteCommitment.CommitTx breachRet, err := NewBreachRetribution( aliceChannel.channelState, revokedStateNum, 100, breachTx, + fn.None[AuxLeafStore](), ) require.NoError(t, err, "unable to create breach retribution") @@ -10285,6 +10289,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error as there are no past delta state saved as revocation logs yet. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, + fn.None[AuxLeafStore](), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10292,6 +10297,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, + fn.None[AuxLeafStore](), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10337,6 +10343,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // successfully. br, err := NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, + fn.None[AuxLeafStore](), ) require.NoError(t, err) @@ -10348,6 +10355,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // since the necessary info should now be found in the revocation log. br, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, + fn.None[AuxLeafStore](), ) require.NoError(t, err) assertRetribution(br, 1, 0) @@ -10356,6 +10364,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum+1, breachHeight, breachTx, + fn.None[AuxLeafStore](), ) require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound) @@ -10363,6 +10372,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum+1, breachHeight, nil, + fn.None[AuxLeafStore](), ) require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound) } diff --git a/lnwallet/config.go b/lnwallet/config.go index 7eeacb6ea23..24961f38edb 100644 --- a/lnwallet/config.go +++ b/lnwallet/config.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcwallet/wallet" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -62,4 +63,8 @@ type Config struct { // CoinSelectionStrategy is the strategy that is used for selecting // coins when funding a transaction. CoinSelectionStrategy wallet.CoinSelectionStrategy + + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[AuxLeafStore] } diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index 77dabfcba78..c579363542c 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -2505,9 +2505,16 @@ func TapscriptRootToTweak(root chainhash.Hash) input.MuSig2Tweaks { func (l *LightningWallet) ValidateChannel(channelState *channeldb.OpenChannel, fundingTx *wire.MsgTx) error { + var chanOpts []ChannelOpt + l.Cfg.AuxLeafStore.WhenSome(func(s AuxLeafStore) { + chanOpts = append(chanOpts, WithLeafStore(s)) + }) + // First, we'll obtain a fully signed commitment transaction so we can // pass into it on the chanvalidate package for verification. - channel, err := NewLightningChannel(l.Cfg.Signer, channelState, nil) + channel, err := NewLightningChannel( + l.Cfg.Signer, channelState, nil, chanOpts..., + ) if err != nil { return err } diff --git a/peer/brontide.go b/peer/brontide.go index 984a5280987..b9a9f68ca15 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -359,6 +359,10 @@ type Config struct { AddLocalAlias func(alias, base lnwire.ShortChannelID, gossip bool) error + // AuxLeafStore is an optional store that can be used to store auxiliary + // leaves for certain custom channel types. + AuxLeafStore fn.Option[lnwallet.AuxLeafStore] + // PongBuf is a slice we'll reuse instead of allocating memory on the // heap. Since only reads will occur and no writes, there is no need // for any synchronization primitives. As a result, it's safe to share @@ -869,8 +873,12 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( } } + var chanOpts []lnwallet.ChannelOpt + p.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) lnChan, err := lnwallet.NewLightningChannel( - p.cfg.Signer, dbChan, p.cfg.SigPool, + p.cfg.Signer, dbChan, p.cfg.SigPool, chanOpts..., ) if err != nil { return nil, err @@ -3977,6 +3985,10 @@ func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error { chanOpts = append(chanOpts, lnwallet.WithSkipNonceInit()) } + p.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) { + chanOpts = append(chanOpts, lnwallet.WithLeafStore(s)) + }) + // If not already active, we'll add this channel to the set of active // channels, so we can look it up later easily according to its channel // ID. diff --git a/server.go b/server.go index e2c0b831a21..650cdbced40 100644 --- a/server.go +++ b/server.go @@ -157,6 +157,8 @@ type server struct { cfg *Config + implCfg *ImplementationCfg + // identityECDH is an ECDH capable wrapper for the private key used // to authenticate any incoming connections. identityECDH keychain.SingleKeyECDH @@ -480,8 +482,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, nodeKeyDesc *keychain.KeyDescriptor, chansToRestore walletunlocker.ChannelsToRecover, chanPredicate chanacceptor.ChannelAcceptor, - torController *tor.Controller, tlsManager *TLSManager) (*server, - error) { + torController *tor.Controller, tlsManager *TLSManager, + implCfg *ImplementationCfg) (*server, error) { var ( err error @@ -567,6 +569,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s := &server{ cfg: cfg, + implCfg: implCfg, graphDB: dbs.GraphDB.ChannelGraph(), chanStateDB: dbs.ChanStateDB.ChannelStateDB(), addrSource: dbs.ChanStateDB, @@ -1245,6 +1248,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return &pc.Incoming }, + AuxLeafStore: implCfg.AuxLeafStore, }, dbs.ChanStateDB) // Select the configuration and funding parameters for Bitcoin. @@ -1578,6 +1582,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, br, err := lnwallet.NewBreachRetribution( channel, commitHeight, 0, nil, + implCfg.AuxLeafStore, ) if err != nil { return nil, 0, err @@ -3906,6 +3911,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, AddLocalAlias: s.aliasMgr.AddLocalAlias, DisallowRouteBlinding: s.cfg.ProtocolOptions.NoRouteBlinding(), Quit: s.quit, + AuxLeafStore: s.implCfg.AuxLeafStore, } copy(pCfg.PubKeyBytes[:], peerAddr.IdentityKey.SerializeCompressed()) From 72f7b80c28088ad93f0ee5bfd0503454ea8644ea Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Thu, 25 Apr 2024 19:01:37 +0200 Subject: [PATCH 14/14] lnwallet: thread thru input.AuxTapleaf to all relevant areas In this commit, we start to thread thru the new aux tap leaf structures to all relevant areas. This includes: commitment outputs, resolution creation, breach handling, and also HTLC scripts. --- contractcourt/chain_watcher.go | 24 ++- input/size_test.go | 4 +- lnwallet/channel.go | 259 ++++++++++++++++++++++++++++----- lnwallet/channel_test.go | 5 +- lnwallet/commitment.go | 66 ++++++--- lnwallet/transactions.go | 12 +- 6 files changed, 298 insertions(+), 72 deletions(-) diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index aa14ae05106..89280d1bd1c 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -424,15 +424,30 @@ func (c *chainWatcher) handleUnknownLocalState( &c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg, ) + auxLeaves, err := lnwallet.AuxLeavesFromCommit( + c.cfg.chanState, c.cfg.chanState.LocalCommitment, + c.cfg.auxLeafStore, *commitKeyRing, + ) + if err != nil { + return false, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the keys derived, we'll construct the remote script that'll be // present if they have a non-dust balance on the commitment. var leaseExpiry uint32 if c.cfg.chanState.ChanType.HasLeaseExpiration() { leaseExpiry = c.cfg.chanState.ThawHeight } + + remoteAuxLeaf := fn.MapOption( + func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + }, + )(auxLeaves) remoteScript, _, err := lnwallet.CommitScriptToRemote( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, - commitKeyRing.ToRemoteKey, leaseExpiry, input.NoneTapLeaf(), + commitKeyRing.ToRemoteKey, leaseExpiry, + fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return false, err @@ -441,11 +456,16 @@ func (c *chainWatcher) handleUnknownLocalState( // Next, we'll derive our script that includes the revocation base for // the remote party allowing them to claim this output before the CSV // delay if we breach. + localAuxLeaf := fn.MapOption( + func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + }, + )(auxLeaves) localScript, err := lnwallet.CommitScriptToSelf( c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator, commitKeyRing.ToLocalKey, commitKeyRing.RevocationKey, uint32(c.cfg.chanState.LocalChanCfg.CsvDelay), leaseExpiry, - input.NoneTapLeaf(), + fn.FlattenOption(localAuxLeaf), ) if err != nil { return false, err diff --git a/input/size_test.go b/input/size_test.go index c4e6fae377a..4cc1680ba48 100644 --- a/input/size_test.go +++ b/input/size_test.go @@ -1386,7 +1386,7 @@ func genTimeoutTx(t *testing.T, // Create the unsigned timeout tx. timeoutTx, err := lnwallet.CreateHtlcTimeoutTx( chanType, false, testOutPoint, testAmt, testCLTVExpiry, - testCSVDelay, 0, testPubkey, testPubkey, + testCSVDelay, 0, testPubkey, testPubkey, input.NoneTapLeaf(), ) require.NoError(t, err) @@ -1455,7 +1455,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { // Create the unsigned success tx. successTx, err := lnwallet.CreateHtlcSuccessTx( chanType, false, testOutPoint, testAmt, testCSVDelay, 0, - testPubkey, testPubkey, + testPubkey, testPubkey, input.NoneTapLeaf(), ) require.NoError(t, err) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index b01c2cde304..b048a37b43a 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -810,8 +810,8 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { // restart a channel session. func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, commitHeight uint64, htlc *channeldb.HTLC, localCommitKeys, - remoteCommitKeys *CommitmentKeyRing, isLocal bool) (PaymentDescriptor, - error) { + remoteCommitKeys *CommitmentKeyRing, isLocal bool, + auxLeaf input.AuxTapLeaf) (PaymentDescriptor, error) { // The proper pkScripts for this PaymentDescriptor must be // generated so we can easily locate them within the commitment @@ -835,7 +835,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, if !isDustLocal && localCommitKeys != nil { scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, true, htlc.RefundTimeout, - htlc.RHash, localCommitKeys, + htlc.RHash, localCommitKeys, auxLeaf, ) if err != nil { return pd, err @@ -850,7 +850,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, if !isDustRemote && remoteCommitKeys != nil { scriptInfo, err := genHtlcScript( chanType, htlc.Incoming, false, htlc.RefundTimeout, - htlc.RHash, remoteCommitKeys, + htlc.RHash, remoteCommitKeys, auxLeaf, ) if err != nil { return pd, err @@ -907,7 +907,8 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // for each side. func (lc *LightningChannel) extractPayDescs(commitHeight uint64, feeRate chainfee.SatPerKWeight, htlcs []channeldb.HTLC, localCommitKeys, - remoteCommitKeys *CommitmentKeyRing, isLocal bool) ([]PaymentDescriptor, + remoteCommitKeys *CommitmentKeyRing, isLocal bool, + auxLeaves fn.Option[CommitAuxLeaves]) ([]PaymentDescriptor, []PaymentDescriptor, error) { var ( @@ -925,10 +926,21 @@ func (lc *LightningChannel) extractPayDescs(commitHeight uint64, htlc := htlc + auxLeaf := fn.MapOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + if htlc.Incoming { + leaves = l.IncomingHtlcLeaves + } + + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(auxLeaves) + payDesc, err := lc.diskHtlcToPayDesc( feeRate, commitHeight, &htlc, localCommitKeys, remoteCommitKeys, - isLocal, + isLocal, fn.FlattenOption(auxLeaf), ) if err != nil { return incomingHtlcs, outgoingHtlcs, err @@ -972,14 +984,28 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, ) } + auxLeaves, err := AuxLeavesFromCommit( + lc.channelState, *diskCommit, lc.leafStore, + func() CommitmentKeyRing { + if isLocal { + return *localCommitKeys + } + + return *remoteCommitKeys + }(), + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // With the key rings re-created, we'll now convert all the on-disk // HTLC"s into PaymentDescriptor's so we can re-insert them into our // update log. incomingHtlcs, outgoingHtlcs, err := lc.extractPayDescs( diskCommit.CommitHeight, chainfee.SatPerKWeight(diskCommit.FeePerKw), - diskCommit.Htlcs, localCommitKeys, remoteCommitKeys, - isLocal, + diskCommit.Htlcs, localCommitKeys, remoteCommitKeys, isLocal, + auxLeaves, ) if err != nil { return nil, err @@ -1583,7 +1609,8 @@ func (lc *LightningChannel) ResetState() { func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, remoteUpdateLog *updateLog, commitHeight uint64, feeRate chainfee.SatPerKWeight, remoteCommitKeys *CommitmentKeyRing, - remoteDustLimit btcutil.Amount) (*PaymentDescriptor, error) { + remoteDustLimit btcutil.Amount, + auxLeaves fn.Option[CommitAuxLeaves]) (*PaymentDescriptor, error) { // Depending on the type of update message we'll map that to a distinct // PaymentDescriptor instance. @@ -1619,10 +1646,17 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, wireMsg.Amount.ToSatoshis(), remoteDustLimit, ) if !isDustRemote { + auxLeaf := fn.MapOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + return leaves[pd.HtlcIndex].AuxTapLeaf + }, + )(auxLeaves) + scriptInfo, err := genHtlcScript( lc.channelState.ChanType, false, false, wireMsg.Expiry, wireMsg.PaymentHash, - remoteCommitKeys, + remoteCommitKeys, fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -2290,6 +2324,14 @@ func (lc *LightningChannel) restorePendingLocalUpdates( pendingCommit := pendingRemoteCommitDiff.Commitment pendingHeight := pendingCommit.CommitHeight + auxLeaves, err := AuxLeavesFromCommit( + lc.channelState, pendingCommit, lc.leafStore, + *pendingRemoteKeys, + ) + if err != nil { + return fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // If we did have a dangling commit, then we'll examine which updates // we included in that state and re-insert them into our update log. for _, logUpdate := range pendingRemoteCommitDiff.LogUpdates { @@ -2299,7 +2341,7 @@ func (lc *LightningChannel) restorePendingLocalUpdates( &logUpdate, lc.remoteUpdateLog, pendingHeight, chainfee.SatPerKWeight(pendingCommit.FeePerKw), pendingRemoteKeys, - lc.channelState.RemoteChanCfg.DustLimit, + lc.channelState.RemoteChanCfg.DustLimit, auxLeaves, ) if err != nil { return err @@ -2506,24 +2548,35 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, leaseExpiry = chanState.ThawHeight } - // TODO(roasbeef): fetch aux leave + auxLeaves, err := auxLeavesFromRevocation( + chanState, revokedLog, leafStore, *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } // Since it is the remote breach we are reconstructing, the output // going to us will be a to-remote script with our local params. + localAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + })(auxLeaves) isRemoteInitiator := !chanState.IsInitiator ourScript, ourDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, input.NoneTapLeaf(), + leaseExpiry, fn.FlattenOption(localAuxLeaf), ) if err != nil { return nil, err } + remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + })(auxLeaves) theirDelay := uint32(chanState.RemoteChanCfg.CsvDelay) theirScript, err := CommitScriptToSelf( chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey, keyRing.RevocationKey, theirDelay, leaseExpiry, - input.NoneTapLeaf(), + fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return nil, err @@ -2541,7 +2594,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, if revokedLog != nil { br, ourAmt, theirAmt, err = createBreachRetribution( revokedLog, spendTx, chanState, keyRing, - commitmentSecret, leaseExpiry, + commitmentSecret, leaseExpiry, auxLeaves, ) if err != nil { return nil, err @@ -2675,7 +2728,8 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, func createHtlcRetribution(chanState *channeldb.OpenChannel, keyRing *CommitmentKeyRing, commitHash chainhash.Hash, commitmentSecret *btcec.PrivateKey, leaseExpiry uint32, - htlc *channeldb.HTLCEntry) (HtlcRetribution, error) { + htlc *channeldb.HTLCEntry, + auxLeaves fn.Option[CommitAuxLeaves]) (HtlcRetribution, error) { var emptyRetribution HtlcRetribution @@ -2685,10 +2739,21 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // We'll generate the original second level witness script now, as // we'll need it if we're revoking an HTLC output on the remote // commitment transaction, and *they* go to the second level. + secondLevelAuxLeaf := fn.MapOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + idx := input.HtlcIndex(htlc.HtlcIndex.Val) + + if htlc.Incoming.Val { + return l.IncomingHtlcLeaves[idx].SecondLevelLeaf + } + + return l.OutgoingHtlcLeaves[idx].SecondLevelLeaf + }, + )(auxLeaves) secondLevelScript, err := SecondLevelHtlcScript( chanState.ChanType, isRemoteInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, theirDelay, - leaseExpiry, + leaseExpiry, fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return emptyRetribution, err @@ -2699,9 +2764,19 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // HTLC script. Otherwise, is this was an outgoing HTLC that we sent, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. + htlcLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + idx := input.HtlcIndex(htlc.HtlcIndex.Val) + + if htlc.Incoming.Val { + return l.IncomingHtlcLeaves[idx].AuxTapLeaf + } + + return l.OutgoingHtlcLeaves[idx].AuxTapLeaf + })(auxLeaves) scriptInfo, err := genHtlcScript( chanState.ChanType, htlc.Incoming.Val, false, htlc.RefundTimeout.Val, htlc.RHash.Val, keyRing, + fn.FlattenOption(htlcLeaf), ) if err != nil { return emptyRetribution, err @@ -2765,7 +2840,9 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, func createBreachRetribution(revokedLog *channeldb.RevocationLog, spendTx *wire.MsgTx, chanState *channeldb.OpenChannel, keyRing *CommitmentKeyRing, commitmentSecret *btcec.PrivateKey, - leaseExpiry uint32) (*BreachRetribution, int64, int64, error) { + leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*BreachRetribution, int64, int64, + error) { commitHash := revokedLog.CommitTxHash @@ -2774,7 +2851,7 @@ func createBreachRetribution(revokedLog *channeldb.RevocationLog, for i, htlc := range revokedLog.HTLCEntries { hr, err := createHtlcRetribution( chanState, keyRing, commitHash.Val, - commitmentSecret, leaseExpiry, htlc, + commitmentSecret, leaseExpiry, htlc, auxLeaves, ) if err != nil { return nil, 0, 0, err @@ -2922,6 +2999,7 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, hr, err := createHtlcRetribution( chanState, keyRing, commitHash, commitmentSecret, leaseExpiry, entry, + fn.None[CommitAuxLeaves](), ) if err != nil { return nil, 0, 0, err @@ -3400,6 +3478,9 @@ func processAddEntry(htlc *PaymentDescriptor, ourBalance, theirBalance *lnwire.M *ourBalance -= htlc.Amount } + // TODO(roasbef): also have it modify balances here + // * obtain for HTLC as well? + if mutateState { *addHeight = nextHeight } @@ -3520,6 +3601,15 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, var err error cancelChan := make(chan struct{}) + auxLeaves, err := AuxLeavesFromCommit( + chanState, *remoteCommitView.toDiskCommit(false), leafStore, + *keyRing, + ) + if err != nil { + return nil, nil, fmt.Errorf("unable to fetch aux leaves: "+ + "%w", err) + } + // For each outgoing and incoming HTLC, if the HTLC isn't considered a // dust output after taking into account second-level HTLC fees, then a // sigJob will be generated and appended to the current batch. @@ -3546,6 +3636,14 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.MapOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.IncomingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) + // With the fee calculate, we can properly create the HTLC // timeout transaction using the HTLC amount minus the fee. op := wire.OutPoint{ @@ -3556,11 +3654,14 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, chanType, isRemoteInitiator, op, outputAmt, htlc.Timeout, uint32(remoteChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + fn.FlattenOption(auxLeaf), ) if err != nil { return nil, nil, err } + // TODO(roasbeef): hook up signer interface here + // Construct a full hash cache as we may be signing a segwit v1 // sighash. txOut := remoteCommitView.txn.TxOut[htlc.remoteOutputIndex] @@ -3587,7 +3688,8 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // If this is a taproot channel, then we'll need to set the // method type to ensure we generate a valid signature. if chanType.IsTaproot() { - sigJob.SignDesc.SignMethod = input.TaprootScriptSpendSignMethod //nolint:lll + //nolint:lll + sigJob.SignDesc.SignMethod = input.TaprootScriptSpendSignMethod } sigBatch = append(sigBatch, sigJob) @@ -3613,6 +3715,14 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.MapOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.OutgoingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) + // With the proper output amount calculated, we can now // generate the success transaction using the remote party's // CSV delay. @@ -3624,6 +3734,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, chanType, isRemoteInitiator, op, outputAmt, uint32(remoteChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + fn.FlattenOption(auxLeaf), ) if err != nil { return nil, nil, err @@ -4871,6 +4982,9 @@ func (lc *LightningChannel) computeView(view *HtlcView, remoteChain bool, // rate. view.FeePerKw = commitChain.tip().feePerKw + // TODO(roasbeef): also need to pass blob here as well for final + // balances? + // We evaluate the view at this stage, meaning settled and failed HTLCs // will remove their corresponding added HTLCs. The resulting filtered // view will only have Add entries left, making it easy to compare the @@ -4959,6 +5073,15 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, len(localCommitmentView.outgoingHTLCs)) verifyJobs := make([]VerifyJob, 0, numHtlcs) + auxLeaves, err := AuxLeavesFromCommit( + chanState, *localCommitmentView.toDiskCommit(true), leafStore, + *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", + err) + } + // We'll iterate through each output in the commitment transaction, // populating the sigHash closure function if it's detected to be an // HLTC output. Given the sighash, and the signing key, we'll be able @@ -4992,11 +5115,20 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, htlcFee := HtlcSuccessFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.MapOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.IncomingHtlcLeaves + idx := htlc.HtlcIndex + return leaves[idx].SecondLevelLeaf + })(auxLeaves) + successTx, err := CreateHtlcSuccessTx( chanType, isLocalInitiator, op, outputAmt, uint32(localChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -5076,12 +5208,21 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, htlcFee := HtlcTimeoutFee(chanType, feePerKw) outputAmt := htlc.Amount.ToSatoshis() - htlcFee + auxLeaf := fn.MapOption(func( + l CommitAuxLeaves) input.AuxTapLeaf { + + leaves := l.OutgoingHtlcLeaves + idx := htlc.HtlcIndex + return leaves[idx].SecondLevelLeaf + })(auxLeaves) + timeoutTx, err := CreateHtlcTimeoutTx( chanType, isLocalInitiator, op, outputAmt, htlc.Timeout, uint32(localChanCfg.CsvDelay), leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey, + fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -6799,6 +6940,13 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) + auxLeaves, err := AuxLeavesFromCommit( + chanState, remoteCommit, leafStore, *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + // Next, we'll obtain HTLC resolutions for all the outgoing HTLC's we // had on their commitment transaction. var leaseExpiry uint32 @@ -6810,7 +6958,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, chainfee.SatPerKWeight(remoteCommit.FeePerKw), isOurCommit, signer, remoteCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitSpend.SpendingTx, - chanState.ChanType, isRemoteInitiator, leaseExpiry, + chanState.ChanType, isRemoteInitiator, leaseExpiry, auxLeaves, ) if err != nil { return nil, fmt.Errorf("unable to create htlc "+ @@ -6819,14 +6967,17 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, commitTxBroadcast := commitSpend.SpendingTx - // TODO(roasbeef): fetch aux leave - // Before we can generate the proper sign descriptor, we'll need to // locate the output index of our non-delayed output on the commitment // transaction. + // + // TODO(roasbeef): helper func to hide flatten + remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.RemoteAuxLeaf + })(auxLeaves) selfScript, maturityDelay, err := CommitScriptToRemote( chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey, - leaseExpiry, input.NoneTapLeaf(), + leaseExpiry, fn.FlattenOption(remoteAuxLeaf), ) if err != nil { return nil, fmt.Errorf("unable to create self commit "+ @@ -7065,8 +7216,8 @@ func newOutgoingHtlcResolution(signer input.Signer, localChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, - localCommit, isCommitFromInitiator bool, - chanType channeldb.ChannelType) (*OutgoingHtlcResolution, error) { + localCommit, isCommitFromInitiator bool, chanType channeldb.ChannelType, + auxLeaves fn.Option[CommitAuxLeaves]) (*OutgoingHtlcResolution, error) { op := wire.OutPoint{ Hash: commitTx.TxHash(), @@ -7075,9 +7226,12 @@ func newOutgoingHtlcResolution(signer input.Signer, // First, we'll re-generate the script used to send the HTLC to the // remote party within their commitment transaction. + auxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.OutgoingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf + })(auxLeaves) htlcScriptInfo, err := genHtlcScript( chanType, false, localCommit, htlc.RefundTimeout, htlc.RHash, - keyRing, + keyRing, fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -7150,10 +7304,17 @@ func newOutgoingHtlcResolution(signer input.Signer, // With the fee calculated, re-construct the second level timeout // transaction. + secondLevelAuxLeaf := fn.MapOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.OutgoingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) timeoutTx, err := CreateHtlcTimeoutTx( chanType, isCommitFromInitiator, op, secondLevelOutputAmt, - htlc.RefundTimeout, csvDelay, leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + htlc.RefundTimeout, csvDelay, leaseExpiry, + keyRing.RevocationKey, keyRing.ToLocalKey, + fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7236,6 +7397,7 @@ func newOutgoingHtlcResolution(signer input.Signer, htlcSweepScript, err = SecondLevelHtlcScript( chanType, isCommitFromInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, leaseExpiry, + fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7244,7 +7406,7 @@ func newOutgoingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, - fn.None[txscript.TapLeaf](), + fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7318,8 +7480,8 @@ func newIncomingHtlcResolution(signer input.Signer, localChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, - localCommit, isCommitFromInitiator bool, chanType channeldb.ChannelType) ( - *IncomingHtlcResolution, error) { + localCommit, isCommitFromInitiator bool, chanType channeldb.ChannelType, + auxLeaves fn.Option[CommitAuxLeaves]) (*IncomingHtlcResolution, error) { op := wire.OutPoint{ Hash: commitTx.TxHash(), @@ -7328,9 +7490,12 @@ func newIncomingHtlcResolution(signer input.Signer, // First, we'll re-generate the script the remote party used to // send the HTLC to us in their commitment transaction. + auxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.IncomingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf + })(auxLeaves) scriptInfo, err := genHtlcScript( chanType, true, localCommit, htlc.RefundTimeout, htlc.RHash, - keyRing, + keyRing, fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -7390,6 +7555,13 @@ func newIncomingHtlcResolution(signer input.Signer, }, nil } + secondLevelAuxLeaf := fn.MapOption( + func(l CommitAuxLeaves) input.AuxTapLeaf { + leaves := l.IncomingHtlcLeaves + return leaves[htlc.HtlcIndex].SecondLevelLeaf + }, + )(auxLeaves) + // Otherwise, we'll need to go to the second level to sweep this HTLC. // // First, we'll reconstruct the original HTLC success transaction, @@ -7399,7 +7571,7 @@ func newIncomingHtlcResolution(signer input.Signer, successTx, err := CreateHtlcSuccessTx( chanType, isCommitFromInitiator, op, secondLevelOutputAmt, csvDelay, leaseExpiry, keyRing.RevocationKey, - keyRing.ToLocalKey, + keyRing.ToLocalKey, fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7482,6 +7654,7 @@ func newIncomingHtlcResolution(signer input.Signer, htlcSweepScript, err = SecondLevelHtlcScript( chanType, isCommitFromInitiator, keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, leaseExpiry, + fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7490,7 +7663,7 @@ func newIncomingHtlcResolution(signer input.Signer, //nolint:lll secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree( keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay, - fn.None[txscript.TapLeaf](), + fn.FlattenOption(secondLevelAuxLeaf), ) if err != nil { return nil, err @@ -7582,7 +7755,8 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, signer input.Signer, htlcs []channeldb.HTLC, keyRing *CommitmentKeyRing, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, chanType channeldb.ChannelType, - isCommitFromInitiator bool, leaseExpiry uint32) (*HtlcResolutions, error) { + isCommitFromInitiator bool, leaseExpiry uint32, + auxLeaves fn.Option[CommitAuxLeaves]) (*HtlcResolutions, error) { // TODO(roasbeef): don't need to swap csv delay? dustLimit := remoteChanCfg.DustLimit @@ -7617,6 +7791,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, signer, localChanCfg, commitTx, &htlc, keyRing, feePerKw, uint32(csvDelay), leaseExpiry, ourCommit, isCommitFromInitiator, chanType, + auxLeaves, ) if err != nil { return nil, fmt.Errorf("incoming resolution "+ @@ -7630,7 +7805,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, ohr, err := newOutgoingHtlcResolution( signer, localChanCfg, commitTx, &htlc, keyRing, feePerKw, uint32(csvDelay), leaseExpiry, ourCommit, - isCommitFromInitiator, chanType, + isCommitFromInitiator, chanType, auxLeaves, ) if err != nil { return nil, fmt.Errorf("outgoing resolution "+ @@ -7865,11 +8040,19 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, // recovery there is not much we can do with HTLCs, so we'll always // use what we have in our latest state when extracting resolutions. localCommit := chanState.LocalCommitment + + auxLeaves, err := AuxLeavesFromCommit( + chanState, localCommit, leafStore, *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } + htlcResolutions, err := extractHtlcResolutions( chainfee.SatPerKWeight(localCommit.FeePerKw), true, signer, localCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitTx, chanState.ChanType, - chanState.IsInitiator, leaseExpiry, + chanState.IsInitiator, leaseExpiry, auxLeaves, ) if err != nil { return nil, fmt.Errorf("unable to gen htlc resolution: %w", err) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 24caef62d63..4914fa13946 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9965,7 +9965,7 @@ func TestCreateHtlcRetribution(t *testing.T) { // Create the htlc retribution. hr, err := createHtlcRetribution( aliceChannel.channelState, keyRing, commitHash, - dummyPrivate, leaseExpiry, htlc, + dummyPrivate, leaseExpiry, htlc, fn.None[CommitAuxLeaves](), ) // Expect no error. require.NoError(t, err) @@ -10171,6 +10171,7 @@ func TestCreateBreachRetribution(t *testing.T) { tc.revocationLog, tx, aliceChannel.channelState, keyRing, dummyPrivate, leaseExpiry, + fn.None[CommitAuxLeaves](), ) // Check the error if expected. @@ -10410,7 +10411,7 @@ func TestExtractPayDescs(t *testing.T) { // NOTE: we use nil commitment key rings to avoid checking the htlc // scripts(`genHtlcScript`) as it should be tested independently. incomingPDs, outgoingPDs, err := lnChan.extractPayDescs( - 0, 0, htlcs, nil, nil, true, + 0, 0, htlcs, nil, nil, true, fn.None[CommitAuxLeaves](), ) require.NoError(t, err) diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index fa8e8e08330..bdd10555a10 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -420,14 +420,14 @@ func sweepSigHash(chanType channeldb.ChannelType) txscript.SigHashType { // argument should correspond to the owner of the commitment transaction which // we are generating the to_local script for. func SecondLevelHtlcScript(chanType channeldb.ChannelType, initiator bool, - revocationKey, delayKey *btcec.PublicKey, - csvDelay, leaseExpiry uint32) (input.ScriptDescriptor, error) { + revocationKey, delayKey *btcec.PublicKey, csvDelay, leaseExpiry uint32, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { switch { // For taproot channels, the pkScript is a segwit v1 p2tr output. case chanType.IsTaproot(): return input.TaprootSecondLevelScriptTree( - revocationKey, delayKey, csvDelay, input.NoneTapLeaf(), + revocationKey, delayKey, csvDelay, auxLeaf, ) // If we are the initiator of a leased channel, then we have an @@ -902,10 +902,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, theirBalance -= commitFeeMSat } - var ( - commitTx *wire.MsgTx - err error - ) + var commitTx *wire.MsgTx // Before we create the commitment transaction below, we'll try to see // if there're any aux leaves that need to be a part of the tapscript @@ -947,6 +944,19 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, return nil, err } + // Similarly, we'll now attempt to extract the set of aux leaves for + // the set of incoming and outgoing HTLCs. + incomingAuxLeaves := fn.MapOption( + func(leaves CommitAuxLeaves) input.AuxTapLeaves { + return leaves.IncomingHtlcLeaves + }, + )(auxLeaves) + outgoingAuxLeaves := fn.MapOption( + func(leaves CommitAuxLeaves) input.AuxTapLeaves { + return leaves.OutgoingHtlcLeaves + }, + )(auxLeaves) + // We'll now add all the HTLC outputs to the commitment transaction. // Each output includes an off-chain 2-of-2 covenant clause, so we'll // need the objective local/remote keys for this particular commitment @@ -966,9 +976,15 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } + auxLeaf := fn.MapOption( + func(leaves input.AuxTapLeaves) input.AuxTapLeaf { + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(outgoingAuxLeaves) + err := addHTLC( commitTx, isOurs, false, htlc, keyRing, - cb.chanState.ChanType, + cb.chanState.ChanType, fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -984,9 +1000,15 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, continue } + auxLeaf := fn.MapOption( + func(leaves input.AuxTapLeaves) input.AuxTapLeaf { + return leaves[htlc.HtlcIndex].AuxTapLeaf + }, + )(incomingAuxLeaves) + err := addHTLC( commitTx, isOurs, true, htlc, keyRing, - cb.chanState.ChanType, + cb.chanState.ChanType, fn.FlattenOption(auxLeaf), ) if err != nil { return nil, err @@ -1266,8 +1288,8 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // genTaprootHtlcScript generates the HTLC scripts for a taproot+musig2 // channel. func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, - rHash [32]byte, - keyRing *CommitmentKeyRing) (*input.HtlcScriptTree, error) { + rHash [32]byte, keyRing *CommitmentKeyRing, + auxLeaf input.AuxTapLeaf) (*input.HtlcScriptTree, error) { var ( htlcScriptTree *input.HtlcScriptTree @@ -1284,8 +1306,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case isIncoming && ourCommit: htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, - input.NoneTapLeaf(), + keyRing.RevocationKey, rHash[:], ourCommit, auxLeaf, ) // We're being paid via an HTLC by the remote party, and the HTLC is @@ -1294,8 +1315,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case isIncoming && !ourCommit: htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, - input.NoneTapLeaf(), + keyRing.RevocationKey, rHash[:], ourCommit, auxLeaf, ) // We're sending an HTLC which is being added to our commitment @@ -1304,8 +1324,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case !isIncoming && ourCommit: htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, - input.NoneTapLeaf(), + keyRing.RevocationKey, rHash[:], ourCommit, auxLeaf, ) // Finally, we're paying the remote party via an HTLC, which is being @@ -1314,8 +1333,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case !isIncoming && !ourCommit: htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, - input.NoneTapLeaf(), + keyRing.RevocationKey, rHash[:], ourCommit, auxLeaf, ) } @@ -1329,8 +1347,8 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, // we need to sign for the remote party (2nd level HTLCs) is also returned // along side the multiplexer. func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, - timeout uint32, rHash [32]byte, - keyRing *CommitmentKeyRing) (input.ScriptDescriptor, error) { + timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing, + auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) { if !chanType.IsTaproot() { return genSegwitV0HtlcScript( @@ -1340,7 +1358,7 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, } return genTaprootHtlcScript( - isIncoming, ourCommit, timeout, rHash, keyRing, + isIncoming, ourCommit, timeout, rHash, keyRing, auxLeaf, ) } @@ -1353,13 +1371,15 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, // the descriptor itself. func addHTLC(commitTx *wire.MsgTx, ourCommit bool, isIncoming bool, paymentDesc *PaymentDescriptor, - keyRing *CommitmentKeyRing, chanType channeldb.ChannelType) error { + keyRing *CommitmentKeyRing, chanType channeldb.ChannelType, + auxLeaf input.AuxTapLeaf) error { timeout := paymentDesc.Timeout rHash := paymentDesc.RHash scriptInfo, err := genHtlcScript( chanType, isIncoming, ourCommit, timeout, rHash, keyRing, + auxLeaf, ) if err != nil { return err diff --git a/lnwallet/transactions.go b/lnwallet/transactions.go index 1cf954d3cb5..da86650bc63 100644 --- a/lnwallet/transactions.go +++ b/lnwallet/transactions.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/input" ) const ( @@ -50,8 +51,8 @@ var ( // - func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, htlcOutput wire.OutPoint, htlcAmt btcutil.Amount, csvDelay, - leaseExpiry uint32, revocationKey, delayKey *btcec.PublicKey) ( - *wire.MsgTx, error) { + leaseExpiry uint32, revocationKey, delayKey *btcec.PublicKey, + auxLeaf input.AuxTapLeaf) (*wire.MsgTx, error) { // Create a version two transaction (as the success version of this // spends an output with a CSV timeout). @@ -71,7 +72,7 @@ func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, // HTLC outputs. scriptInfo, err := SecondLevelHtlcScript( chanType, initiator, revocationKey, delayKey, csvDelay, - leaseExpiry, + leaseExpiry, auxLeaf, ) if err != nil { return nil, err @@ -110,7 +111,8 @@ func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool, func CreateHtlcTimeoutTx(chanType channeldb.ChannelType, initiator bool, htlcOutput wire.OutPoint, htlcAmt btcutil.Amount, cltvExpiry, csvDelay, leaseExpiry uint32, - revocationKey, delayKey *btcec.PublicKey) (*wire.MsgTx, error) { + revocationKey, delayKey *btcec.PublicKey, + auxLeaf input.AuxTapLeaf) (*wire.MsgTx, error) { // Create a version two transaction (as the success version of this // spends an output with a CSV timeout), and set the lock-time to the @@ -134,7 +136,7 @@ func CreateHtlcTimeoutTx(chanType channeldb.ChannelType, initiator bool, // HTLC outputs. scriptInfo, err := SecondLevelHtlcScript( chanType, initiator, revocationKey, delayKey, csvDelay, - leaseExpiry, + leaseExpiry, auxLeaf, ) if err != nil { return nil, err