diff --git a/electrum/commands.py b/electrum/commands.py index a886197a4537..4f4ef36de42f 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -1405,18 +1405,18 @@ async def add_hold_invoice( assert inbound_capacity > satoshis(amount or 0), \ f"Not enough inbound capacity [{inbound_capacity} sat] to receive this payment" + wallet.lnworker.add_payment_info_for_hold_invoice( + bfh(payment_hash), + lightning_amount_sat=satoshis(amount) if amount else None, + min_final_cltv_delta=min_final_cltv_expiry_delta, + exp_delay=expiry, + ) + info = wallet.lnworker.get_payment_info(bfh(payment_hash)) lnaddr, invoice = wallet.lnworker.get_bolt11_invoice( - payment_hash=bfh(payment_hash), - amount_msat=satoshis(amount) * 1000 if amount else None, + payment_info=info, message=memo, - expiry=expiry, - min_final_cltv_expiry_delta=min_final_cltv_expiry_delta, fallback_address=None ) - wallet.lnworker.add_payment_info_for_hold_invoice( - bfh(payment_hash), - satoshis(amount) if amount else None, - ) wallet.lnworker.dont_settle_htlcs[payment_hash] = None wallet.set_label(payment_hash, memo) result = { diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index bad460b103a5..41844bde87c1 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2498,16 +2498,15 @@ def maybe_fulfill_htlc( Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded. Return (preimage, (payment_key, callback)) with at most a single element not None. """ + htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) if not processed_onion.are_we_final: if not self.lnworker.enable_htlc_forwarding: return None, None - # use the htlc key if we are forwarding - payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) callback = lambda: self.maybe_forward_htlc( incoming_chan=chan, htlc=htlc, processed_onion=processed_onion) - return None, (payment_key, callback) + return None, (htlc_key, callback) # use the htlc key if we are forwarding def log_fail_reason(reason: str): self.logger.info( @@ -2544,10 +2543,10 @@ def log_fail_reason(reason: str): ): return None, None - # TODO check against actual min_final_cltv_expiry_delta from invoice (and give 2-3 blocks of leeway?) + blocks_to_expiry = htlc.cltv_abs - local_height # note: payment_bundles might get split here, e.g. one payment is "already forwarded" and the other is not. # In practice, for the swap prepayment use case, this does not matter. - if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs and not already_forwarded: + if blocks_to_expiry < MIN_FINAL_CLTV_DELTA_ACCEPTED and not already_forwarded: log_fail_reason(f"htlc.cltv_abs is unreasonably close") raise exc_incorrect_or_unknown_pd @@ -2581,10 +2580,6 @@ def log_fail_reason(reason: str): fw_payment_key=payment_key) return None, (payment_key, callback) - # TODO don't accept payments twice for same invoice - # note: we don't check invoice expiry (bolt11 'x' field) on the receiver-side. - # - semantics are weird: would make sense for simple-payment-receives, but not - # if htlc is expected to be pending for a while, e.g. for a hold-invoice. info = self.lnworker.get_payment_info(payment_hash) if info is None: log_fail_reason(f"no payment_info found for RHASH {htlc.payment_hash.hex()}") @@ -2605,6 +2600,27 @@ def log_fail_reason(reason: str): log_fail_reason(f"total_msat={total_msat} too different from invoice_msat={invoice_msat}") raise exc_incorrect_or_unknown_pd + if htlc_key not in self.lnworker.verified_pending_htlcs: + # these checks against the PaymentInfo have to be done only once after + # receiving the htlc + valid_expiry = htlc.timestamp < info.expiration_ts + if not valid_expiry and not already_forwarded: + log_fail_reason(f"invoice already expired: {info.expiration_ts=}") + raise exc_incorrect_or_unknown_pd + + valid_cltv = blocks_to_expiry >= info.min_final_cltv_delta + will_settle = preimage is not None and payment_hash.hex() not in self.lnworker.dont_settle_htlcs + if not valid_cltv and not will_settle and not already_forwarded: + # this check only really matters for htlcs which don't get settled right away + log_fail_reason(f"remaining locktime smaller than requested {blocks_to_expiry=} < {info.min_final_cltv_delta=}") + raise exc_incorrect_or_unknown_pd + + if info.status == PR_PAID: + log_fail_reason(f"invoice has already been paid") + raise exc_incorrect_or_unknown_pd + + self.lnworker.verified_pending_htlcs[htlc_key] = None + hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) if hold_invoice_callback and not preimage: callback = lambda: hold_invoice_callback(payment_hash) @@ -3099,6 +3115,8 @@ async def htlc_switch(self): self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) if forwarding_key: self.lnworker.maybe_cleanup_forwarding(forwarding_key) + htlc_key = serialize_htlc_key(chan.short_channel_id, htlc_id) + self.lnworker.verified_pending_htlcs.pop(htlc_key, None) done.add(htlc_id) continue if onion_packet_hex is None: diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 3903f6979886..1db9a6c9b9a4 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -503,9 +503,11 @@ class LNProtocolWarning(Exception): # the minimum cltv_expiry accepted for newly received HTLCs # note: when changing, consider Blockchain.is_tip_stale() MIN_FINAL_CLTV_DELTA_ACCEPTED = 144 -# set it a tiny bit higher for invoices as blocks could get mined -# during forward path of payment -MIN_FINAL_CLTV_DELTA_FOR_INVOICE = MIN_FINAL_CLTV_DELTA_ACCEPTED + 3 +MIN_FINAL_CLTV_DELTA_FOR_INVOICE = MIN_FINAL_CLTV_DELTA_ACCEPTED +# Buffer added to the min final cltv delta of all created bolt11 invoices so that the received htlcs +# locktime is still above the limit requested by the creator of the invoice even if some blocks got +# mined during forwarding +MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE = 3 # the deadline for offered HTLCs: # the deadline after which the channel has to be failed and timed out on-chain diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 0261e3fe0604..590456a17287 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -20,6 +20,8 @@ from concurrent import futures import urllib.parse import itertools +import dataclasses +from dataclasses import dataclass import aiohttp import dns.asyncresolver @@ -67,7 +69,8 @@ LnKeyFamily, LOCAL, REMOTE, MIN_FINAL_CLTV_DELTA_FOR_INVOICE, SENT, RECEIVED, HTLCOwner, UpdateAddHtlc, LnFeatures, ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage, OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget, - NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT + NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT, + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE, ) from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailure, OnionPacket from .lnmsg import decode_msg @@ -106,12 +109,44 @@ class PaymentDirection(IntEnum): FORWARDING = 3 -class PaymentInfo(NamedTuple): - payment_hash: bytes +@stored_in('lightning_payments') +@dataclass(frozen=True) +class PaymentInfo: + """Information required to handle incoming htlcs for a payment request""" + rhash: str amount_msat: Optional[int] + # direction is being used with PaymentDirection and lnutil.Direction? direction: int status: int + min_final_cltv_delta: int + # expiration can be used to clean-up PaymentInfo and fail htlcs coming in too late + expiry_delay: int + creation_ts: int = dataclasses.field(default_factory=lambda: int(time.time())) + @property + def payment_hash(self) -> bytes: + return bytes.fromhex(self.rhash) + + @property + def expiration_ts(self): + return self.creation_ts + self.expiry_delay + + def validate(self): + assert isinstance(self.rhash, str), type(self.rhash) + assert self.amount_msat is None or isinstance(self.amount_msat, int) + assert isinstance(self.direction, int) + assert isinstance(self.status, int) + assert isinstance(self.min_final_cltv_delta, int) + assert isinstance(self.expiry_delay, int) and self.expiry_delay > 0 + assert isinstance(self.creation_ts, int) + + def __post_init__(self): + self.validate() + + def to_json(self): + # required because PaymentInfo doesn't inherit StoredObject so it can be declared frozen + self.validate() + return dataclasses.asdict(self) # Note: these states are persisted in the wallet file. # Do not modify them without performing a wallet db upgrade @@ -869,7 +904,7 @@ def __init__(self, wallet: 'Abstract_Wallet', xprv): LNWorker.__init__(self, self.node_keypair, features, config=self.config) self.lnwatcher = LNWatcher(self) self.lnrater: LNRater = None - self.payment_info = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid + self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, PaymentInfo] self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage self._bolt11_cache = {} # note: this sweep_address is only used as fallback; as it might result in address-reuse @@ -896,6 +931,7 @@ def __init__(self, wallet: 'Abstract_Wallet', xprv): self._paysessions = dict() # type: Dict[bytes, PaySession] self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo] self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs') # type: Dict[str, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus + self.verified_pending_htlcs = self.db.get_dict('verified_pending_htlcs') # type: Dict[str, None] # htlc_key, to keep track of checks that have to be done only once when receiving the htlc # detect inflight payments self.inflight_payments = set() # (not persisted) keys of invoices that are in PR_INFLIGHT state @@ -1567,7 +1603,7 @@ async def pay_invoice( raise PaymentFailure(_("A payment was already initiated for this invoice")) if payment_hash in self.get_payments(status='inflight'): raise PaymentFailure(_("A previous attempt to pay this invoice did not clear")) - info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID) + info = PaymentInfo(key, amount_to_pay, SENT, PR_UNPAID, min_final_cltv_delta, LN_EXPIRY_NEVER) self.save_payment_info(info) self.wallet.set_label(key, lnaddr.get_description()) self.set_invoice_status(key, PR_INFLIGHT) @@ -2238,17 +2274,13 @@ def clear_invoices_cache(self): def get_bolt11_invoice( self, *, - payment_hash: bytes, - amount_msat: Optional[int], + payment_info: PaymentInfo, message: str, - expiry: int, # expiration of invoice (in seconds, relative) fallback_address: Optional[str], channels: Optional[Sequence[Channel]] = None, - min_final_cltv_expiry_delta: Optional[int] = None, ) -> Tuple[LnAddr, str]: - assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}" - - pair = self._bolt11_cache.get(payment_hash) + amount_msat = payment_info.amount_msat + pair = self._bolt11_cache.get(payment_info.payment_hash) if pair: lnaddr, invoice = pair assert lnaddr.get_amount_msat() == amount_msat @@ -2265,19 +2297,16 @@ def get_bolt11_invoice( if needs_jit: # jit only works with single htlcs, mpp will cause LSP to open channels for each htlc invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ - payment_secret = self.get_payment_secret(payment_hash) + payment_secret = self.get_payment_secret(payment_info.payment_hash) amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None - if expiry == 0: - expiry = LN_EXPIRY_NEVER - if min_final_cltv_expiry_delta is None: - min_final_cltv_expiry_delta = MIN_FINAL_CLTV_DELTA_FOR_INVOICE + min_final_cltv_delta_requested = payment_info.min_final_cltv_delta + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE lnaddr = LnAddr( - paymenthash=payment_hash, + paymenthash=payment_info.payment_hash, amount=amount_btc, tags=[ ('d', message), - ('c', min_final_cltv_expiry_delta), - ('x', expiry), + ('c', min_final_cltv_delta_requested), + ('x', payment_info.expiry_delay), ('9', invoice_features), ('f', fallback_address), ] + routing_hints, @@ -2285,7 +2314,7 @@ def get_bolt11_invoice( payment_secret=payment_secret) invoice = lnencode(lnaddr, self.node_keypair.privkey) pair = lnaddr, invoice - self._bolt11_cache[payment_hash] = pair + self._bolt11_cache[payment_info.payment_hash] = pair return pair def get_payment_secret(self, payment_hash): @@ -2299,10 +2328,17 @@ def _get_payment_key(self, payment_hash: bytes) -> bytes: payment_secret = self.get_payment_secret(payment_hash) return payment_hash + payment_secret - def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes: + def create_payment_info( + self, *, + amount_msat: Optional[int], + min_final_cltv_delta: Optional[int] = None, + exp_delay: int = LN_EXPIRY_NEVER, + write_to_disk=True + ) -> bytes: payment_preimage = os.urandom(32) payment_hash = sha256(payment_preimage) - info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID) + min_final_cltv_delta = min_final_cltv_delta if min_final_cltv_delta else MIN_FINAL_CLTV_DELTA_FOR_INVOICE + info = PaymentInfo(payment_hash.hex(), amount_msat, RECEIVED, PR_UNPAID, min_final_cltv_delta, exp_delay) self.save_preimage(payment_hash, payment_preimage, write_to_disk=False) self.save_payment_info(info, write_to_disk=False) if write_to_disk: @@ -2374,14 +2410,17 @@ def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]: """returns None if payment_hash is a payment we are forwarding""" key = payment_hash.hex() with self.lock: - if key in self.payment_info: - amount_msat, direction, status = self.payment_info[key] - return PaymentInfo(payment_hash, amount_msat, direction, status) - return None + return self.payment_info.get(key, None) - def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: Optional[int]): + def add_payment_info_for_hold_invoice( + self, + payment_hash: bytes, *, + lightning_amount_sat: Optional[int], + min_final_cltv_delta: int, + exp_delay: int, + ): amount = lightning_amount_sat * 1000 if lightning_amount_sat else None - info = PaymentInfo(payment_hash, amount, RECEIVED, PR_UNPAID) + info = PaymentInfo(payment_hash.hex(), amount, RECEIVED, PR_UNPAID, min_final_cltv_delta, exp_delay) self.save_payment_info(info, write_to_disk=False) def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]): @@ -2396,11 +2435,13 @@ def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> if old_info := self.get_payment_info(payment_hash=info.payment_hash): if info == old_info: return # already saved - if info != old_info._replace(status=info.status): + if info.direction == SENT: + # allow saving of newer PaymentInfo if it is a sending attempt + old_info = dataclasses.replace(old_info, creation_ts=info.creation_ts) + if info != dataclasses.replace(old_info, status=info.status): # differs more than in status. let's fail - raise Exception("payment_hash already in use") - key = info.payment_hash.hex() - self.payment_info[key] = info.amount_msat, info.direction, info.status + raise Exception(f"payment_hash already in use: {info=} != {old_info=}") + self.payment_info[info.rhash] = info if write_to_disk: self.wallet.save_db() @@ -2577,7 +2618,7 @@ def set_payment_status(self, payment_hash: bytes, status: int) -> None: if info is None: # if we are forwarding return - info = info._replace(status=status) + info = dataclasses.replace(info, status=status) self.save_payment_info(info) def is_forwarded_htlc(self, htlc_key) -> Optional[str]: @@ -3016,12 +3057,14 @@ async def rebalance_channels(self, chan1: Channel, chan2: Channel, *, amount_msa raise Exception('Rebalance requires two different channels') if self.uses_trampoline() and chan1.node_id == chan2.node_id: raise Exception('Rebalance requires channels from different trampolines') - payment_hash = self.create_payment_info(amount_msat=amount_msat) - lnaddr, invoice = self.get_bolt11_invoice( - payment_hash=payment_hash, + payment_hash = self.create_payment_info( amount_msat=amount_msat, + exp_delay=3600, + ) + info = self.get_payment_info(payment_hash) + lnaddr, invoice = self.get_bolt11_invoice( + payment_info=info, message='rebalance', - expiry=3600, fallback_address=None, channels=[chan2], ) diff --git a/electrum/plugins/nwc/nwcserver.py b/electrum/plugins/nwc/nwcserver.py index 64b16a56b36f..c2cd925f59ce 100644 --- a/electrum/plugins/nwc/nwcserver.py +++ b/electrum/plugins/nwc/nwcserver.py @@ -480,12 +480,11 @@ async def handle_make_invoice(self, request_event: nEvent, params: dict): address=None ) req: Request = self.wallet.get_request(key) + info = self.wallet.lnworker.get_payment_info(req.payment_hash) try: lnaddr, b11 = self.wallet.lnworker.get_bolt11_invoice( - payment_hash=req.payment_hash, - amount_msat=amount_msat, + payment_info=info, message=description, - expiry=expiry, fallback_address=None ) except Exception: @@ -538,11 +537,10 @@ async def handle_lookup_invoice(self, request_event: nEvent, params: dict): b11 = invoice.lightning_invoice elif self.wallet.get_request(invoice.rhash): direction = "incoming" + info = self.wallet.lnworker.get_payment_info(invoice.payment_hash) _, b11 = self.wallet.lnworker.get_bolt11_invoice( - payment_hash=bytes.fromhex(invoice.rhash), - amount_msat=invoice.amount_msat, + payment_info=info, message=invoice.message, - expiry=invoice.exp, fallback_address=None ) @@ -749,11 +747,10 @@ def on_event_request_status(self, wallet, key, status): request: Optional[Request] = self.wallet.get_request(key) if not request or not request.is_lightning() or not status == PR_PAID: return + info = self.wallet.lnworker.get_payment_info(request.payment_hash) _, b11 = self.wallet.lnworker.get_bolt11_invoice( - payment_hash=request.payment_hash, - amount_msat=request.get_amount_msat(), + payment_info=info, message=request.message, - expiry=request.exp, fallback_address=None ) diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index bf006ac73058..2c27d1211b18 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -36,7 +36,8 @@ run_sync_function_on_asyncio_thread, trigger_callback, NoDynamicFeeEstimates, UserFacingException, ) from . import lnutil -from .lnutil import hex_to_bytes, REDEEM_AFTER_DOUBLE_SPENT_DELAY, Keypair +from .lnutil import (hex_to_bytes, REDEEM_AFTER_DOUBLE_SPENT_DELAY, Keypair, + MIN_FINAL_CLTV_DELTA_FOR_INVOICE) from .lnaddr import lndecode from .json_db import StoredObject, stored_in from . import constants @@ -644,33 +645,38 @@ def add_normal_swap( else: invoice_amount_sat = lightning_amount_sat + # add payment info to lnworker + self.lnworker.add_payment_info_for_hold_invoice( + payment_hash, + lightning_amount_sat=invoice_amount_sat, + min_final_cltv_delta=min_final_cltv_expiry_delta or MIN_FINAL_CLTV_DELTA_FOR_INVOICE, + exp_delay=300, + ) + info = self.lnworker.get_payment_info(payment_hash) lnaddr1, invoice = self.lnworker.get_bolt11_invoice( - payment_hash=payment_hash, - amount_msat=invoice_amount_sat * 1000, + payment_info=info, message='Submarine swap', - expiry=300, fallback_address=None, channels=channels, - min_final_cltv_expiry_delta=min_final_cltv_expiry_delta, ) margin_to_get_refund_tx_mined = MIN_LOCKTIME_DELTA if not (locktime + margin_to_get_refund_tx_mined < self.network.get_local_height() + lnaddr1.get_min_final_cltv_delta()): raise Exception( f"onchain locktime ({locktime}+{margin_to_get_refund_tx_mined}) " f"too close to LN-htlc-expiry ({self.network.get_local_height()+lnaddr1.get_min_final_cltv_delta()})") - # add payment info to lnworker - self.lnworker.add_payment_info_for_hold_invoice(payment_hash, invoice_amount_sat) if prepay: - prepay_hash = self.lnworker.create_payment_info(amount_msat=prepay_amount_sat*1000) + prepay_hash = self.lnworker.create_payment_info( + amount_msat=prepay_amount_sat*1000, + min_final_cltv_delta=min_final_cltv_expiry_delta or MIN_FINAL_CLTV_DELTA_FOR_INVOICE, + exp_delay=300, + ) + info = self.lnworker.get_payment_info(prepay_hash) lnaddr2, prepay_invoice = self.lnworker.get_bolt11_invoice( - payment_hash=prepay_hash, - amount_msat=prepay_amount_sat * 1000, + payment_info=info, message='Submarine swap prepayment', - expiry=300, fallback_address=None, channels=channels, - min_final_cltv_expiry_delta=min_final_cltv_expiry_delta, ) self.lnworker.bundle_payments([payment_hash, prepay_hash]) self._prepayments[prepay_hash] = payment_hash diff --git a/electrum/wallet.py b/electrum/wallet.py index e28778abddbd..c51ccdee8612 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -2995,11 +2995,11 @@ def get_bolt11_invoice(self, req: Request) -> str: return '' amount_msat = req.get_amount_msat() or None assert (amount_msat is None or amount_msat > 0), amount_msat + info = self.lnworker.get_payment_info(payment_hash) + assert info.amount_msat == amount_msat, f"{info.amount_msat=} != {amount_msat=}" lnaddr, invoice = self.lnworker.get_bolt11_invoice( - payment_hash=payment_hash, - amount_msat=amount_msat, + payment_info=info, message=req.message, - expiry=req.exp, fallback_address=None) return invoice @@ -3015,7 +3015,11 @@ def create_request(self, amount_sat: Optional[int], message: Optional[str], exp_ timestamp = int(Request._get_cur_time()) if address is None: assert self.has_lightning() - payment_hash = self.lnworker.create_payment_info(amount_msat=amount_msat, write_to_disk=False) + payment_hash = self.lnworker.create_payment_info( + amount_msat=amount_msat, + exp_delay=exp_delay, + write_to_disk=False, + ) else: payment_hash = None outputs = [PartialTxOutput.from_address_and_value(address, amount_sat)] if address else [] diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 8f37430dc4b8..fb333b233c51 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -73,7 +73,7 @@ def __init__(self, wallet_db: 'WalletDB'): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 60 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 61 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -236,6 +236,7 @@ def upgrade(self): self._convert_version_58() self._convert_version_59() self._convert_version_60() + self._convert_version_61() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1157,6 +1158,23 @@ def _convert_version_60(self): cb['multisig_funding_privkey'] = None self.data['seed_version'] = 60 + def _convert_version_61(self): + if not self._is_upgrade_method_needed(60, 60): + return + lightning_payments = self.data.get('lightning_payments', {}) + for rhash, (amount_msat, direction, is_paid) in list(lightning_payments.items()): + new_dataclass_type = { + 'rhash': rhash, + 'amount_msat': amount_msat, + 'direction': direction, + 'status': is_paid, + 'min_final_cltv_delta': 144, + 'expiry_delay': 100 * 365 * 24 * 60 * 60, # LN_EXPIRY_NEVER + 'creation_ts': int(time.time()), + } + lightning_payments[rhash] = new_dataclass_type + self.data['seed_version'] = 61 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 01b3074a6288..76f0158c0724 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -11,6 +11,7 @@ from concurrent import futures from unittest import mock from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence +import time from aiorpcx import timeout_after, TaskTimeout from electrum_ecc import ECPrivkey @@ -41,7 +42,7 @@ from electrum.lnonion import OnionFailureCode, OnionRoutingFailure from electrum.lnutil import UpdateAddHtlc from electrum.lnutil import LOCAL, REMOTE -from electrum.invoices import PR_PAID, PR_UNPAID, Invoice +from electrum.invoices import PR_PAID, PR_UNPAID, Invoice, LN_EXPIRY_NEVER from electrum.interface import GracefulDisconnect from electrum.simple_config import SimpleConfig from electrum.fee_policy import FeeTimeEstimates, FEE_ETA_TARGETS @@ -217,6 +218,7 @@ def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_que self.hold_invoice_callbacks = {} self._payment_bundles_pkey_to_canon = {} # type: Dict[bytes, bytes] self._payment_bundles_canon_to_pkeylist = {} # type: Dict[bytes, Sequence[bytes]] + self.verified_pending_htlcs = {} # type: Dict[str, None] self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0 self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}") @@ -553,16 +555,15 @@ def prepare_invoice( payment_hash: bytes = None, invoice_features: LnFeatures = None, min_final_cltv_delta: int = None, + expiry: int = None, ) -> Tuple[LnAddr, Invoice]: amount_btc = amount_msat/Decimal(COIN*1000) if payment_preimage is None and not payment_hash: payment_preimage = os.urandom(32) if payment_hash is None: payment_hash = sha256(payment_preimage) - info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID) if payment_preimage: w2.save_preimage(payment_hash, payment_preimage) - w2.save_payment_info(info) if include_routing_hints: routing_hints = w2.calc_routing_hints_for_invoice(amount_msat) else: @@ -576,6 +577,8 @@ def prepare_invoice( payment_secret = None if min_final_cltv_delta is None: min_final_cltv_delta = lnutil.MIN_FINAL_CLTV_DELTA_FOR_INVOICE + info = PaymentInfo(payment_hash.hex(), amount_msat, RECEIVED, PR_UNPAID, min_final_cltv_delta, expiry or LN_EXPIRY_NEVER) + w2.save_payment_info(info) lnaddr1 = LnAddr( paymenthash=payment_hash, amount=amount_btc, @@ -583,6 +586,7 @@ def prepare_invoice( ('c', min_final_cltv_delta), ('d', 'coffee'), ('9', invoice_features), + ('x', expiry or 3600), ] + routing_hints, payment_secret=payment_secret, ) @@ -941,6 +945,143 @@ async def f(): with self.assertRaises(SuccessfulTest): await f() + async def test_reject_invalid_min_final_cltv_delta(self): + """Tests that htlcs with a final cltv delta < the minimum requested in the invoice get + rejected immediately upon receiving them if we aren't going to settle it""" + async def run_test(test_trampoline): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + async def try_pay_with_too_low_final_cltv_delta(lnaddr, w1=w1, w2=w2): + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) + assert lnaddr.get_min_final_cltv_delta() == 400 # what the receiver expects + lnaddr.tags = [tag for tag in lnaddr.tags if tag[0] != 'c'] + [['c', 144]] + b11 = lnencode(lnaddr, w2.node_keypair.privkey) + pay_req = Invoice.from_bech32(b11) + assert pay_req._lnaddr.get_min_final_cltv_delta() == 144 # what w1 will use to pay + try: + result, log = await asyncio.wait_for(w1.pay_invoice(pay_req), timeout=1) + except asyncio.TimeoutError: + # w2 has no preimage so it will never settle the payment + raise PaymentDone() + if not result: + raise PaymentFailure() + + # create invoice with high min final cltv delta + lnaddr, _pay_req = self.prepare_invoice(w2, min_final_cltv_delta=400) + # the check only applies for htlc which are not being settled right away + w2.dont_settle_htlcs[lnaddr.paymenthash.hex()] = None + del w2._preimages[lnaddr.paymenthash.hex()] + + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(try_pay_with_too_low_final_cltv_delta(lnaddr)) + + with self.assertRaises(PaymentFailure): + await f() + + for _test_trampoline in [False, True]: + await run_test(_test_trampoline) + + async def test_reject_payment_for_expired_invoice(self): + """Tests that new htlcs paying an invoice that has already been expired will get rejected.""" + async def run_test(test_trampoline): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + # create lightning invoice in the past, so it is expired + with mock.patch('time.time', return_value=int(time.time()) - 10000): + lnaddr, _pay_req = self.prepare_invoice(w2, expiry=3600) + b11 = lnencode(lnaddr, w2.node_keypair.privkey) + pay_req = Invoice.from_bech32(b11) + + async def try_pay_expired_invoice(pay_req: Invoice, w1=w1): + assert pay_req.has_expired() + assert lnaddr.is_expired() + with mock.patch.object(w1, "_check_bolt11_invoice", return_value=lnaddr): + result, log = await w1.pay_invoice(pay_req) + if not result: + raise PaymentFailure() + raise PaymentDone() + + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(try_pay_expired_invoice(pay_req)) + + with self.assertRaises(PaymentFailure): + await f() + + for _test_trampoline in [False, True]: + await run_test(_test_trampoline) + + async def test_reject_multiple_payments_of_same_invoice(self): + """Tests that new htlcs paying an invoice that has already been paid will get rejected.""" + async def run_test(test_trampoline): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + lnaddr, _pay_req = self.prepare_invoice(w2) + + async def try_pay_invoice_twice(pay_req: Invoice, w1=w1): + result, log = await w1.pay_invoice(pay_req) + assert result is True + # now pay the same invoice again, the payment should be rejected by w2 + w1.set_payment_status(pay_req._lnaddr.paymenthash, PR_UNPAID) + try: + result, log = await w1.pay_invoice(pay_req) + except PaymentFailure: + # pay_invoice can also raise PaymentFailure, e.g. payment status is already PR_PAID + raise Exception from PaymentFailure + if not result: + raise PaymentFailure() + raise PaymentDone() + + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(try_pay_invoice_twice(_pay_req)) + + with self.assertRaises(PaymentFailure): + await f() + + for _test_trampoline in [False, True]: + await run_test(_test_trampoline) + async def test_payment_race(self): """Alice and Bob pay each other simultaneously. They both send 'update_add_htlc' and receive each other's update