From 6576f53906e5871c32ce7e54e072fa68d7966e0f Mon Sep 17 00:00:00 2001 From: t-bast Date: Thu, 11 Aug 2022 15:27:35 +0200 Subject: [PATCH 1/2] Remove invalid channel updates from DB at startup Following #2361, we reject channel updates that don't contain the `htlc_maximum_msat` field. However, the network DB may contain such channel updates, that we need to remove when starting up. --- .../scala/fr/acinq/eclair/db/NetworkDb.scala | 2 +- .../fr/acinq/eclair/db/pg/PgNetworkDb.scala | 15 ++++++-- .../eclair/db/sqlite/SqliteNetworkDb.scala | 29 ++++++++++----- .../fr/acinq/eclair/db/NetworkDbSpec.scala | 37 ++++++++++++++++++- 4 files changed, 68 insertions(+), 15 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala index ec3031fd0a..a2536602a2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala @@ -41,7 +41,7 @@ trait NetworkDb { def updateChannel(u: ChannelUpdate): Unit - def removeChannel(shortChannelId: ShortChannelId) = removeChannels(Set(shortChannelId)): Unit + def removeChannel(shortChannelId: ShortChannelId): Unit = removeChannels(Set(shortChannelId)) def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala index c44ca7c892..a16a67c9e7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala @@ -17,8 +17,7 @@ package fr.acinq.eclair.db.pg import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, Satoshi} -import fr.acinq.eclair.ShortChannelId -import fr.acinq.eclair.RealShortChannelId +import fr.acinq.eclair.{RealShortChannelId, ShortChannelId} import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.Monitoring.Tags.DbBackends import fr.acinq.eclair.db.NetworkDb @@ -80,7 +79,17 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { if (v < 4) { migration34(statement) } - case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do + case Some(CURRENT_VERSION) => + // We clean up channels that contain an invalid channel update (e.g. missing htlc_maximum_msat). + val invalidChannels = statement.executeQuery("SELECT short_channel_id, channel_update_1, channel_update_2 FROM network.public_channels").map(rs => { + val shortChannelId = rs.getLong("short_channel_id") + val validChannelUpdate1 = rs.getBitVectorOpt("channel_update_1").forall(channelUpdateCodec.decode(_).isSuccessful) + val validChannelUpdate2 = rs.getBitVectorOpt("channel_update_2").forall(channelUpdateCodec.decode(_).isSuccessful) + (shortChannelId, validChannelUpdate1 && validChannelUpdate2) + }).filter { case (_, isValid) => !isValid }.map { case (scid, _) => scid }.toSeq + invalidChannels.foreach(scid => { + statement.executeUpdate(s"DELETE FROM network.public_channels WHERE short_channel_id=$scid") + }) case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } setVersion(statement, DB_NAME, CURRENT_VERSION) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala index b05e743ffb..b73835b8d8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala @@ -17,8 +17,7 @@ package fr.acinq.eclair.db.sqlite import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, Satoshi} -import fr.acinq.eclair.ShortChannelId -import fr.acinq.eclair.RealShortChannelId +import fr.acinq.eclair.{RealShortChannelId, ShortChannelId} import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.Monitoring.Tags.DbBackends import fr.acinq.eclair.db.NetworkDb @@ -61,7 +60,17 @@ class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging { case Some(v@1) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration12(statement) - case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do + case Some(CURRENT_VERSION) => + // We clean up channels that contain an invalid channel update (e.g. missing htlc_maximum_msat). + val invalidChannels = statement.executeQuery("SELECT short_channel_id, channel_update_1, channel_update_2 FROM channels").map(rs => { + val shortChannelId = rs.getLong("short_channel_id") + val validChannelUpdate1 = rs.getBitVectorOpt("channel_update_1").forall(channelUpdateCodec.decode(_).isSuccessful) + val validChannelUpdate2 = rs.getBitVectorOpt("channel_update_2").forall(channelUpdateCodec.decode(_).isSuccessful) + (shortChannelId, validChannelUpdate1 && validChannelUpdate2) + }).filter { case (_, isValid) => !isValid }.map { case (scid, _) => scid }.toSeq + invalidChannels.foreach(scid => { + statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id=$scid") + }) case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } setVersion(statement, DB_NAME, CURRENT_VERSION) @@ -129,12 +138,12 @@ class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging { using(sqlite.createStatement()) { statement => statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels") .foldLeft(SortedMap.empty[RealShortChannelId, PublicChannel]) { (m, rs) => - val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value - val txId = ByteVector32.fromValidHex(rs.getString("txid")) - val capacity = rs.getLong("capacity_sat") - val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value) - val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value) - m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None)) + val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value + val txId = ByteVector32.fromValidHex(rs.getString("txid")) + val capacity = rs.getLong("capacity_sat") + val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value) + val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value) + m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None)) } } } @@ -166,7 +175,7 @@ class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging { } override def removeFromPruned(shortChannelId: RealShortChannelId): Unit = withMetrics("network/remove-from-pruned", DbBackends.Sqlite) { - using(sqlite.prepareStatement(s"DELETE FROM pruned WHERE short_channel_id=?")) { statement => + using(sqlite.prepareStatement("DELETE FROM pruned WHERE short_channel_id=?")) { statement => statement.setLong(1, shortChannelId.toLong) statement.executeUpdate() } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala index 3d56d767fb..1af4ce6196 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala @@ -31,6 +31,7 @@ import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncement import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshiLong, RealShortChannelId, ShortChannelId, TestDatabases, randomBytes32, randomKey} import org.scalatest.funsuite.AnyFunSuite +import scodec.bits.HexStringSyntax import scala.collection.{SortedMap, mutable} import scala.util.Random @@ -63,7 +64,7 @@ class NetworkDbSpec extends AnyFunSuite { assert(db.listNodes().toSet == Set.empty) db.addNode(node_1) db.addNode(node_1) // duplicate is ignored - assert(db.getNode(node_1.nodeId) == Some(node_1)) + assert(db.getNode(node_1.nodeId).contains(node_1)) assert(db.listNodes().size == 1) db.addNode(node_2) db.addNode(node_3) @@ -326,6 +327,40 @@ class NetworkDbSpec extends AnyFunSuite { ) } + test("remove channel updates without htlc_maximum_msat") { + forAllDbs { dbs => + val t1 = channelTestCases(0) + val t2 = channelTestCases(1) + val db1 = dbs.network + db1.addChannel(t1.channel, t1.txid, t2.capacity) + db1.addChannel(t2.channel, t2.txid, t2.capacity) + // The DB contains a channel update missing the `htlc_maximum_msat` field. + val channelUpdateWithoutHtlcMax = hex"12540b6a236e21932622d61432f52913d9442cc09a1057c386119a286153f8681c66d2a0f17d32505ba71bb37c8edcfa9c11e151b2b38dae98b825eff1c040b36fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d619000000000008850f00058e00015e6a782e0000009000000000000003e8000003e800000002" + dbs match { + case sqlite: TestSqliteDatabases => + using(sqlite.connection.prepareStatement("UPDATE channels SET channel_update_1=? WHERE short_channel_id=?")) { statement => + statement.setBytes(1, channelUpdateWithoutHtlcMax.toArray) + statement.setLong(2, t1.shortChannelId.toLong) + statement.executeUpdate() + } + case pg: TestPgDatabases => + using(pg.connection.prepareStatement("UPDATE network.public_channels SET channel_update_1=? WHERE short_channel_id=?")) { statement => + statement.setBytes(1, channelUpdateWithoutHtlcMax.toArray) + statement.setLong(2, t1.shortChannelId.toLong) + statement.executeUpdate() + } + } + assertThrows[IllegalArgumentException](db1.listChannels()) + // We restart eclair and automatically clean up invalid entries. + val db2 = dbs match { + case sqlite: TestSqliteDatabases => new SqliteNetworkDb(sqlite.connection) + case pg: TestPgDatabases => new PgNetworkDb()(pg.datasource) + } + val channels = db2.listChannels() + assert(channels.keySet == Set(t2.shortChannelId)) + } + } + test("json column reset (postgres)") { val dbs = TestPgDatabases() val db = dbs.network From 55b919f8ae079a559e7200c39b74f26bbc896304 Mon Sep 17 00:00:00 2001 From: t-bast Date: Thu, 11 Aug 2022 16:45:20 +0200 Subject: [PATCH 2/2] fixup! Remove invalid channel updates from DB at startup --- .../scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala | 11 +++++------ .../fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala | 11 +++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala index a16a67c9e7..e744f401c0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala @@ -17,13 +17,13 @@ package fr.acinq.eclair.db.pg import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, Satoshi} -import fr.acinq.eclair.{RealShortChannelId, ShortChannelId} import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.Monitoring.Tags.DbBackends import fr.acinq.eclair.db.NetworkDb import fr.acinq.eclair.router.Router.PublicChannel import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec} import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement} +import fr.acinq.eclair.{RealShortChannelId, ShortChannelId} import grizzled.slf4j.Logging import scodec.bits.BitVector @@ -81,15 +81,14 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { } case Some(CURRENT_VERSION) => // We clean up channels that contain an invalid channel update (e.g. missing htlc_maximum_msat). - val invalidChannels = statement.executeQuery("SELECT short_channel_id, channel_update_1, channel_update_2 FROM network.public_channels").map(rs => { + statement.executeQuery("SELECT short_channel_id, channel_update_1, channel_update_2 FROM network.public_channels").map(rs => { val shortChannelId = rs.getLong("short_channel_id") val validChannelUpdate1 = rs.getBitVectorOpt("channel_update_1").forall(channelUpdateCodec.decode(_).isSuccessful) val validChannelUpdate2 = rs.getBitVectorOpt("channel_update_2").forall(channelUpdateCodec.decode(_).isSuccessful) (shortChannelId, validChannelUpdate1 && validChannelUpdate2) - }).filter { case (_, isValid) => !isValid }.map { case (scid, _) => scid }.toSeq - invalidChannels.foreach(scid => { - statement.executeUpdate(s"DELETE FROM network.public_channels WHERE short_channel_id=$scid") - }) + }).collect { + case (scid, false) => statement.executeUpdate(s"DELETE FROM network.public_channels WHERE short_channel_id=$scid") + } case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } setVersion(statement, DB_NAME, CURRENT_VERSION) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala index b73835b8d8..b6b42ee7fd 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala @@ -17,13 +17,13 @@ package fr.acinq.eclair.db.sqlite import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, Satoshi} -import fr.acinq.eclair.{RealShortChannelId, ShortChannelId} import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.Monitoring.Tags.DbBackends import fr.acinq.eclair.db.NetworkDb import fr.acinq.eclair.router.Router.PublicChannel import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec} import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement} +import fr.acinq.eclair.{RealShortChannelId, ShortChannelId} import grizzled.slf4j.Logging import java.sql.{Connection, Statement} @@ -62,15 +62,14 @@ class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging { migration12(statement) case Some(CURRENT_VERSION) => // We clean up channels that contain an invalid channel update (e.g. missing htlc_maximum_msat). - val invalidChannels = statement.executeQuery("SELECT short_channel_id, channel_update_1, channel_update_2 FROM channels").map(rs => { + statement.executeQuery("SELECT short_channel_id, channel_update_1, channel_update_2 FROM channels").map(rs => { val shortChannelId = rs.getLong("short_channel_id") val validChannelUpdate1 = rs.getBitVectorOpt("channel_update_1").forall(channelUpdateCodec.decode(_).isSuccessful) val validChannelUpdate2 = rs.getBitVectorOpt("channel_update_2").forall(channelUpdateCodec.decode(_).isSuccessful) (shortChannelId, validChannelUpdate1 && validChannelUpdate2) - }).filter { case (_, isValid) => !isValid }.map { case (scid, _) => scid }.toSeq - invalidChannels.foreach(scid => { - statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id=$scid") - }) + }).collect { + case (scid, false) => statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id=$scid") + } case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } setVersion(statement, DB_NAME, CURRENT_VERSION)