Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ trait NetworkDb {

def updateChannel(u: ChannelUpdate): Unit

def removeChannel(shortChannelId: ShortChannelId) = removeChannels(Set(shortChannelId)): Unit
def removeChannel(shortChannelId: ShortChannelId): Unit = removeChannels(Set(shortChannelId))

def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit

Expand Down
14 changes: 11 additions & 3 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
package fr.acinq.eclair.db.pg

import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, Satoshi}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.RealShortChannelId
import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics
import fr.acinq.eclair.db.Monitoring.Tags.DbBackends
import fr.acinq.eclair.db.NetworkDb
import fr.acinq.eclair.router.Router.PublicChannel
import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec}
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}
import fr.acinq.eclair.{RealShortChannelId, ShortChannelId}
import grizzled.slf4j.Logging
import scodec.bits.BitVector

Expand Down Expand Up @@ -80,7 +79,16 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging {
if (v < 4) {
migration34(statement)
}
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(CURRENT_VERSION) =>
// We clean up channels that contain an invalid channel update (e.g. missing htlc_maximum_msat).
statement.executeQuery("SELECT short_channel_id, channel_update_1, channel_update_2 FROM network.public_channels").map(rs => {
val shortChannelId = rs.getLong("short_channel_id")
val validChannelUpdate1 = rs.getBitVectorOpt("channel_update_1").forall(channelUpdateCodec.decode(_).isSuccessful)
val validChannelUpdate2 = rs.getBitVectorOpt("channel_update_2").forall(channelUpdateCodec.decode(_).isSuccessful)
(shortChannelId, validChannelUpdate1 && validChannelUpdate2)
}).collect {
case (scid, false) => statement.executeUpdate(s"DELETE FROM network.public_channels WHERE short_channel_id=$scid")
}
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
setVersion(statement, DB_NAME, CURRENT_VERSION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
package fr.acinq.eclair.db.sqlite

import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, Satoshi}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.RealShortChannelId
import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics
import fr.acinq.eclair.db.Monitoring.Tags.DbBackends
import fr.acinq.eclair.db.NetworkDb
import fr.acinq.eclair.router.Router.PublicChannel
import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec}
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}
import fr.acinq.eclair.{RealShortChannelId, ShortChannelId}
import grizzled.slf4j.Logging

import java.sql.{Connection, Statement}
Expand Down Expand Up @@ -61,7 +60,16 @@ class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging {
case Some(v@1) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
migration12(statement)
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(CURRENT_VERSION) =>
// We clean up channels that contain an invalid channel update (e.g. missing htlc_maximum_msat).
statement.executeQuery("SELECT short_channel_id, channel_update_1, channel_update_2 FROM channels").map(rs => {
val shortChannelId = rs.getLong("short_channel_id")
val validChannelUpdate1 = rs.getBitVectorOpt("channel_update_1").forall(channelUpdateCodec.decode(_).isSuccessful)
val validChannelUpdate2 = rs.getBitVectorOpt("channel_update_2").forall(channelUpdateCodec.decode(_).isSuccessful)
(shortChannelId, validChannelUpdate1 && validChannelUpdate2)
}).collect {
case (scid, false) => statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id=$scid")
}
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
setVersion(statement, DB_NAME, CURRENT_VERSION)
Expand Down Expand Up @@ -129,12 +137,12 @@ class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging {
using(sqlite.createStatement()) { statement =>
statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels")
.foldLeft(SortedMap.empty[RealShortChannelId, PublicChannel]) { (m, rs) =>
val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
val capacity = rs.getLong("capacity_sat")
val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value)
val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value)
m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None))
val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
val capacity = rs.getLong("capacity_sat")
val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value)
val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value)
m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None))
}
}
}
Expand Down Expand Up @@ -166,7 +174,7 @@ class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging {
}

override def removeFromPruned(shortChannelId: RealShortChannelId): Unit = withMetrics("network/remove-from-pruned", DbBackends.Sqlite) {
using(sqlite.prepareStatement(s"DELETE FROM pruned WHERE short_channel_id=?")) { statement =>
using(sqlite.prepareStatement("DELETE FROM pruned WHERE short_channel_id=?")) { statement =>
statement.setLong(1, shortChannelId.toLong)
statement.executeUpdate()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncement
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshiLong, RealShortChannelId, ShortChannelId, TestDatabases, randomBytes32, randomKey}
import org.scalatest.funsuite.AnyFunSuite
import scodec.bits.HexStringSyntax

import scala.collection.{SortedMap, mutable}
import scala.util.Random
Expand Down Expand Up @@ -63,7 +64,7 @@ class NetworkDbSpec extends AnyFunSuite {
assert(db.listNodes().toSet == Set.empty)
db.addNode(node_1)
db.addNode(node_1) // duplicate is ignored
assert(db.getNode(node_1.nodeId) == Some(node_1))
assert(db.getNode(node_1.nodeId).contains(node_1))
assert(db.listNodes().size == 1)
db.addNode(node_2)
db.addNode(node_3)
Expand Down Expand Up @@ -326,6 +327,40 @@ class NetworkDbSpec extends AnyFunSuite {
)
}

test("remove channel updates without htlc_maximum_msat") {
forAllDbs { dbs =>
val t1 = channelTestCases(0)
val t2 = channelTestCases(1)
val db1 = dbs.network
db1.addChannel(t1.channel, t1.txid, t2.capacity)
db1.addChannel(t2.channel, t2.txid, t2.capacity)
// The DB contains a channel update missing the `htlc_maximum_msat` field.
val channelUpdateWithoutHtlcMax = hex"12540b6a236e21932622d61432f52913d9442cc09a1057c386119a286153f8681c66d2a0f17d32505ba71bb37c8edcfa9c11e151b2b38dae98b825eff1c040b36fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d619000000000008850f00058e00015e6a782e0000009000000000000003e8000003e800000002"
dbs match {
case sqlite: TestSqliteDatabases =>
using(sqlite.connection.prepareStatement("UPDATE channels SET channel_update_1=? WHERE short_channel_id=?")) { statement =>
statement.setBytes(1, channelUpdateWithoutHtlcMax.toArray)
statement.setLong(2, t1.shortChannelId.toLong)
statement.executeUpdate()
}
case pg: TestPgDatabases =>
using(pg.connection.prepareStatement("UPDATE network.public_channels SET channel_update_1=? WHERE short_channel_id=?")) { statement =>
statement.setBytes(1, channelUpdateWithoutHtlcMax.toArray)
statement.setLong(2, t1.shortChannelId.toLong)
statement.executeUpdate()
}
}
assertThrows[IllegalArgumentException](db1.listChannels())
// We restart eclair and automatically clean up invalid entries.
val db2 = dbs match {
case sqlite: TestSqliteDatabases => new SqliteNetworkDb(sqlite.connection)
case pg: TestPgDatabases => new PgNetworkDb()(pg.datasource)
}
val channels = db2.listChannels()
assert(channels.keySet == Set(t2.shortChannelId))
}
}

test("json column reset (postgres)") {
val dbs = TestPgDatabases()
val db = dbs.network
Expand Down