diff --git a/channeldb/db.go b/channeldb/db.go index 17275c2922a..edbc880a5b2 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -230,6 +230,8 @@ type DB struct { graph *ChannelGraph clock clock.Clock dryRun bool + + chanCache *kvdb.Cache } // Open opens or creates channeldb. Any necessary schemas migrations due @@ -260,6 +262,16 @@ func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) { return db, err } +func resetChanStateCache(cache *kvdb.Cache) error { + // We don't use cache for Bolt backend. + if cache == nil { + return nil + } + + cache.Wipe() + return cache.Init() +} + // CreateWithBackend creates channeldb instance using the passed kvdb.Backend. // Any necessary schemas migrations due to updates will take place as necessary. func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, error) { @@ -274,18 +286,64 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, chanDB := &DB{ Backend: backend, - channelStateDB: &ChannelStateDB{ - linkNodeDB: &LinkNodeDB{ - backend: backend, - }, - backend: backend, - }, - clock: opts.clock, - dryRun: opts.dryRun, + clock: opts.clock, + dryRun: opts.dryRun, + } + + chanStateBackend := backend + + // Override the chan state backend if we require to cache chan state. + if opts.ChanStateCache { + skipped := [][]byte{ + // Skip the graph buckets. + nodeBucket, + edgeBucket, + edgeIndexBucket, + graphMetaBucket, + + // Skip some non performance critical large buckets. + closedChannelBucket, + closeSummaryBucket, + fwdPackagesKey, + revocationLogBucket, + } + + topLevel := [][]byte{ + // Read through the graph buckets. + nodeBucket, + edgeBucket, + edgeIndexBucket, + graphMetaBucket, + + // Cache important channel state. + openChannelBucket, + outpointBucket, + nodeInfoBucket, + + // Channel state buckets to read through. + closedChannelBucket, + closeSummaryBucket, + fwdPackagesKey, + } + + cache := kvdb.NewCache(backend, topLevel, skipped) + if err := cache.Init(); err != nil { + return nil, err + } + + chanStateBackend = cache + chanDB.chanCache = cache } - // Set the parent pointer (only used in tests). - chanDB.channelStateDB.parent = chanDB + chanDB.channelStateDB = &ChannelStateDB{ + linkNodeDB: &LinkNodeDB{ + backend: chanStateBackend, + }, + backend: chanStateBackend, + + // Set the parent pointer (only used in tests). + parent: chanDB, + } var err error chanDB.graph, err = NewChannelGraph( @@ -343,7 +401,11 @@ func (d *DB) Wipe() error { return err } - return initChannelDB(d.Backend) + if err := initChannelDB(d.Backend); err != nil { + return err + } + + return resetChanStateCache(d.chanCache) } // initChannelDB creates and initializes a fresh version of channeldb. In the @@ -518,7 +580,6 @@ func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) ( "chan_point=%v: %v", outPoint, err) } oChannel.Db = c - channels = append(channels, oChannel) return nil @@ -1440,6 +1501,11 @@ func MakeTestDB(modifiers ...OptionModifier) (*DB, func(), error) { return nil, nil, err } + // Use a channel state cache when testing with remote backends. + if kvdb.TestBackend != kvdb.BoltBackendName { + modifiers = append(modifiers, OptionWithChannelStateCache(true)) + } + cdb, err := CreateWithBackend(backend, modifiers...) if err != nil { backendCleanup() diff --git a/channeldb/db_test.go b/channeldb/db_test.go index 5731c03a8a2..de7950e3f42 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -41,7 +41,14 @@ func TestOpenWithCreate(t *testing.T) { } defer cleanup() - cdb, err := CreateWithBackend(backend) + var modifiers []OptionModifier + // Use a channel state cache when testing with remote backends. + if kvdb.TestBackend != kvdb.BoltBackendName { + modifiers = append(modifiers, OptionWithChannelStateCache(true)) + } + + cdb, err := CreateWithBackend(backend, modifiers...) + if err != nil { t.Fatalf("unable to create channeldb: %v", err) } @@ -87,7 +94,13 @@ func TestWipe(t *testing.T) { } defer cleanup() - fullDB, err := CreateWithBackend(backend) + var modifiers []OptionModifier + // Use a channel state cache when testing with remote backends. + if kvdb.TestBackend != kvdb.BoltBackendName { + modifiers = append(modifiers, OptionWithChannelStateCache(true)) + } + + fullDB, err := CreateWithBackend(backend, modifiers...) if err != nil { t.Fatalf("unable to create channeldb: %v", err) } diff --git a/channeldb/options.go b/channeldb/options.go index ad22fa8ed21..6823bc5c94e 100644 --- a/channeldb/options.go +++ b/channeldb/options.go @@ -29,6 +29,10 @@ const ( type Options struct { kvdb.BoltBackendConfig + // ChanStateCache when true turns of in-memory caching of important + // channel state buckets. + ChanStateCache bool + // RejectCacheSize is the maximum number of rejectCacheEntries to hold // in the rejection cache. RejectCacheSize int @@ -72,6 +76,14 @@ func DefaultOptions() Options { // OptionModifier is a function signature for modifying the default Options. type OptionModifier func(*Options) +// OptionWithChannelStateCache turns on in-memory caching of important channel +// state buckets. +func OptionWithChannelStateCache(cache bool) OptionModifier { + return func(o *Options) { + o.ChanStateCache = cache + } +} + // OptionSetRejectCacheSize sets the RejectCacheSize to n. func OptionSetRejectCacheSize(n int) OptionModifier { return func(o *Options) { diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index ada24e0c262..e70ce055819 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -378,8 +378,11 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, return err } - // Retrieve attempt info for the notification. - payment, err = fetchPayment(bucket) + p.HTLCs = append(p.HTLCs, HTLCAttempt{ + HTLCAttemptInfo: *attempt, + }) + + payment = p return err }) if err != nil { @@ -405,7 +408,9 @@ func (p *PaymentControl) SettleAttempt(hash lntypes.Hash, } settleBytes := b.Bytes() - return p.updateHtlcKey(hash, attemptID, htlcSettleInfoKey, settleBytes) + return p.updateHtlcKey( + hash, attemptID, htlcSettleInfoKey, settleBytes, settleInfo, nil, + ) } // FailAttempt marks the given payment attempt failed. @@ -418,12 +423,15 @@ func (p *PaymentControl) FailAttempt(hash lntypes.Hash, } failBytes := b.Bytes() - return p.updateHtlcKey(hash, attemptID, htlcFailInfoKey, failBytes) + return p.updateHtlcKey( + hash, attemptID, htlcFailInfoKey, failBytes, nil, failInfo, + ) } // updateHtlcKey updates a database key for the specified htlc. func (p *PaymentControl) updateHtlcKey(paymentHash lntypes.Hash, - attemptID uint64, key, value []byte) (*MPPayment, error) { + attemptID uint64, key, value []byte, settleInfo *HTLCSettleInfo, + failInfo *HTLCFailInfo) (*MPPayment, error) { aid := make([]byte, 8) binary.BigEndian.PutUint64(aid, attemptID) @@ -450,33 +458,43 @@ func (p *PaymentControl) updateHtlcKey(paymentHash lntypes.Hash, return err } - htlcsBucket := bucket.NestedReadWriteBucket(paymentHtlcsBucket) - if htlcsBucket == nil { - return fmt.Errorf("htlcs bucket not found") - } + for i := range p.HTLCs { + if p.HTLCs[i].AttemptID != attemptID { + continue + } - if htlcsBucket.Get(htlcBucketKey(htlcAttemptInfoKey, aid)) == nil { - return fmt.Errorf("HTLC with ID %v not registered", - attemptID) - } + if p.HTLCs[i].Failure != nil { + return ErrAttemptAlreadyFailed + } - // Make sure the shard is not already failed or settled. - if htlcsBucket.Get(htlcBucketKey(htlcFailInfoKey, aid)) != nil { - return ErrAttemptAlreadyFailed - } + if p.HTLCs[i].Settle != nil { + return ErrAttemptAlreadySettled + } - if htlcsBucket.Get(htlcBucketKey(htlcSettleInfoKey, aid)) != nil { - return ErrAttemptAlreadySettled - } + // Udate the DB. + htlcsBucket := bucket.NestedReadWriteBucket( + paymentHtlcsBucket, + ) + if htlcsBucket == nil { + return fmt.Errorf("htlcs bucket not found") + } - // Add or update the key for this htlc. - err = htlcsBucket.Put(htlcBucketKey(key, aid), value) - if err != nil { - return err + // Add or update the key for this htlc. + err = htlcsBucket.Put(htlcBucketKey(key, aid), value) + if err != nil { + return err + } + + // Update the fetched payment. + if settleInfo != nil { + p.HTLCs[i].Settle = settleInfo + } else if failInfo != nil { + p.HTLCs[i].Failure = failInfo + } } - // Retrieve attempt info for the notification. - payment, err = fetchPayment(bucket) + updatePaymentStatus(p) + payment = p return err }) if err != nil { diff --git a/channeldb/payments.go b/channeldb/payments.go index 496b7a5fd1e..35c7c155d8d 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -303,43 +303,11 @@ func fetchCreationInfo(bucket kvdb.RBucket) (*PaymentCreationInfo, error) { return deserializePaymentCreationInfo(r) } -func fetchPayment(bucket kvdb.RBucket) (*MPPayment, error) { - seqBytes := bucket.Get(paymentSequenceKey) - if seqBytes == nil { - return nil, fmt.Errorf("sequence number not found") - } - - sequenceNum := binary.BigEndian.Uint64(seqBytes) - - // Get the PaymentCreationInfo. - creationInfo, err := fetchCreationInfo(bucket) - if err != nil { - return nil, err - - } - - var htlcs []HTLCAttempt - htlcsBucket := bucket.NestedReadBucket(paymentHtlcsBucket) - if htlcsBucket != nil { - // Get the payment attempts. This can be empty. - htlcs, err = fetchHtlcAttempts(htlcsBucket) - if err != nil { - return nil, err - } - } - - // Get failure reason if available. - var failureReason *FailureReason - b := bucket.Get(paymentFailInfoKey) - if b != nil { - reason := FailureReason(b[0]) - failureReason = &reason - } - +func updatePaymentStatus(payment *MPPayment) { // Go through all HTLCs for this payment, noting whether we have any // settled HTLC, and any still in-flight. var inflight, settled bool - for _, h := range htlcs { + for _, h := range payment.HTLCs { if h.Failure != nil { continue } @@ -366,7 +334,7 @@ func fetchPayment(bucket kvdb.RBucket) (*MPPayment, error) { // If we have no in-flight HTLCs, and the payment failure is set, the // payment is considered failed. - case !inflight && failureReason != nil: + case !inflight && payment.FailureReason != nil: paymentStatus = StatusFailed // Otherwise it is still in flight. @@ -374,13 +342,51 @@ func fetchPayment(bucket kvdb.RBucket) (*MPPayment, error) { paymentStatus = StatusInFlight } - return &MPPayment{ + payment.Status = paymentStatus +} + +func fetchPayment(bucket kvdb.RBucket) (*MPPayment, error) { + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return nil, fmt.Errorf("sequence number not found") + } + + sequenceNum := binary.BigEndian.Uint64(seqBytes) + + // Get the PaymentCreationInfo. + creationInfo, err := fetchCreationInfo(bucket) + if err != nil { + return nil, err + + } + + var htlcs []HTLCAttempt + htlcsBucket := bucket.NestedReadBucket(paymentHtlcsBucket) + if htlcsBucket != nil { + // Get the payment attempts. This can be empty. + htlcs, err = fetchHtlcAttempts(htlcsBucket) + if err != nil { + return nil, err + } + } + + // Get failure reason if available. + var failureReason *FailureReason + b := bucket.Get(paymentFailInfoKey) + if b != nil { + reason := FailureReason(b[0]) + failureReason = &reason + } + + payment := &MPPayment{ SequenceNum: sequenceNum, Info: creationInfo, HTLCs: htlcs, FailureReason: failureReason, - Status: paymentStatus, - }, nil + } + + updatePaymentStatus(payment) + return payment, nil } // fetchHtlcAttempts retrives all htlc attempts made for the payment found in diff --git a/docs/release-notes/release-notes-0.14.0.md b/docs/release-notes/release-notes-0.14.0.md index ac1619312ff..f3731f6c743 100644 --- a/docs/release-notes/release-notes-0.14.0.md +++ b/docs/release-notes/release-notes-0.14.0.md @@ -373,6 +373,12 @@ you. buffer each time we decrypt an incoming message, as we recycle these buffers in the peer. +* [Cache the channel state](https://github.com/lightningnetwork/lnd/pull/5595) + to achieve better performance when running LND using a remote DB backend. + +* [Do not re-fetch payments if we already have them in memory](https://github.com/lightningnetwork/lnd/pull/5769) + in certain cases. + ## Log system * [Save compressed log files from logrorate during diff --git a/kvdb/bolt_test.go b/kvdb/bolt_test.go index a6cec252b7d..a0d04cd23fd 100644 --- a/kvdb/bolt_test.go +++ b/kvdb/bolt_test.go @@ -1,9 +1,11 @@ package kvdb import ( + "fmt" "testing" "github.com/btcsuite/btcwallet/walletdb" + "github.com/stretchr/testify/require" ) func TestBolt(t *testing.T) { @@ -71,14 +73,54 @@ func TestBolt(t *testing.T) { for _, test := range tests { test := test + cache := []bool{false, true} + for _, useCache := range cache { + name := fmt.Sprintf("%v/Cache=%v", test.name, useCache) + t.Run(name, func(t *testing.T) { + t.Parallel() + + f := NewBoltFixture(t) + defer f.Cleanup() + + backend := f.NewBackend() + if useCache { + cache := NewCache(backend, nil, nil) + require.NoError(t, cache.Init()) + backend = cache + } + + test.test(t, backend) + }) + } + } +} + +func TestCacheBolt(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, walletdb.DB) + }{ + { + name: "cache fill", + test: testCacheFill, + }, + { + name: "cache rollback", + test: testCacheRollback, + }, + } + + for _, test := range tests { + test := test t.Run(test.name, func(t *testing.T) { t.Parallel() f := NewBoltFixture(t) defer f.Cleanup() - test.test(t, f.NewBackend()) + backend := f.NewBackend() + test.test(t, backend) }) } } diff --git a/kvdb/cache.go b/kvdb/cache.go new file mode 100644 index 00000000000..3dcc0769019 --- /dev/null +++ b/kvdb/cache.go @@ -0,0 +1,1116 @@ +package kvdb + +import ( + "container/list" + "fmt" + "io" + "sync" + + "github.com/btcsuite/btcwallet/walletdb" + "github.com/google/btree" +) + +const ( + // treeDeg is the degree of the B-trees we use in our cached buckets. + treeDeg = 3 +) + +// cacheBucket is a bucket we use for storing cached items. +type cacheBucket struct { + seq *uint64 + tree *btree.BTree +} + +// newReadThroughCacheBucket creates a cacheBucket that doesn't hold any data +// and is meant to work as a placeholder for top level buckets that we'll only +// allow on-demand read through (fetching all data from the DB every time). +func newReadThroughCacheBucket() *cacheBucket { + return &cacheBucket{} +} + +// newCacheBucket creates a new cacheBucket. +func newCacheBucket() *cacheBucket { + return &cacheBucket{ + tree: btree.New(treeDeg), + } +} + +// cached returns true if this cacheBucket is actually cached. +func (c *cacheBucket) cached() bool { + return c.tree != nil +} + +// get returns the corresponding cachedItem for the key if there's any or nil +// otherwise. +func (c *cacheBucket) get(key []byte) *cachedItem { + if key == nil { + return nil + } + + keyItem := &cachedItem{ + key: string(key), + } + + if valItem := c.tree.Get(keyItem); valItem != nil { + return valItem.(*cachedItem) + } + + return nil +} + +// put inserts or replaces the passed item. +func (c *cacheBucket) put(item *cachedItem) { + c.tree.ReplaceOrInsert(item) +} + +// del removes the passed item. +func (c *cacheBucket) del(item *cachedItem) { + c.tree.Delete(item) +} + +// cachedItem is holder for key/values or buckets that we store in the cache. +type cachedItem struct { + key string + value string + + // bucket is non nil if this item is a bucket. + bucket *cacheBucket +} + +// Less implements a strict ordering operator in order to insert items into the +// cache's b-tree. +func (c *cachedItem) Less(than btree.Item) bool { + return c.key < than.(*cachedItem).key +} + +// pendingChange is a common interface for all pending changes to the cache. +type pendingChange interface { + // Reverts the pending change. + Revert() +} + +// pendingAdd holds a new key/value/bucket added to the cache. +type pendingAdd struct { + parent *cachedItem + newChild *cachedItem +} + +// Revert reverts the cache to the state before the add. +func (p *pendingAdd) Revert() { + p.parent.bucket.del(p.newChild) +} + +// pendingUpdate holds an updated value. +type pendingUpdate struct { + parent *cachedItem + oldValue string +} + +// Revert reverts the cache to the state before the update. +func (p *pendingUpdate) Revert() { + p.parent.value = p.oldValue +} + +// pendingDelete holds a pending deleted key/value/bucket. +type pendingDelete struct { + parent *cachedItem + oldChild *cachedItem +} + +// Revert reverts back the cache to the state before the delete. +func (p *pendingDelete) Revert() { + p.parent.bucket.put(p.oldChild) +} + +// Cache is a simple write through cache implementing the kvdb.Backend +// interface. It's capable of recursively caching top-level buckets speeding up +// reads by reducing roundtrips to the actual DB backend while remaining +// consistend with DB state as long as the cache is used when mutating those +// buckets. It's also able to skip buckets in the tree structure keeping thos +// read/write-through, which is useful when we want to skip large buckets. +type Cache struct { + mx sync.RWMutex + + // skipped tracks buckets that the cache tracks, but never fetches. + skipped map[string]bool + + // topLevelBuckets stores prefetched top-level buckets. + topLevelBuckets []string + + // currRwTx holds the current RW DB transaction. + currRwTx *cacheReadWriteTx + + // root tracks the cache's top-level buckets and the buckets underneath + // them. + root *cachedItem + + // pending holds any pending changes before the transaction commit. + pending *list.List + + // backend it the underlying DB backend. + backend Backend +} + +// Enforce that Cache implements the ExtendedBackend interface. +var _ walletdb.DB = (*Cache)(nil) + +// NewCache constructs a new cache. Top level buckets are recursively read and +// all content is added to the cache. Skipped keys will be skipped on all levels. +func NewCache(backend Backend, topLevelBuckets [][]byte, + skippedKeys [][]byte) *Cache { + + cache := &Cache{ + skipped: make(map[string]bool), + topLevelBuckets: make([]string, len(topLevelBuckets)), + root: &cachedItem{ + bucket: newCacheBucket(), + }, + pending: list.New(), + backend: backend, + } + + for _, skippedKey := range skippedKeys { + cache.skipped[string(skippedKey)] = true + } + + for i, bucket := range topLevelBuckets { + cache.topLevelBuckets[i] = string(bucket) + } + + return cache +} + +// pendingAdd adds a new item to the parent. +func (c *Cache) pendingAdd(parent *cachedItem, newChild *cachedItem) { + c.pending.PushBack(&pendingAdd{parent, newChild}) +} + +// pendingUpdate updates parent with a new value. +func (c *Cache) pendingUpdate(parent *cachedItem, oldValue string) { + c.pending.PushBack(&pendingUpdate{parent, oldValue}) +} + +// pendingDelete deletes the child form the parent. +func (c *Cache) pendingDelete(parent *cachedItem, oldChild *cachedItem) { + c.pending.PushBack(&pendingDelete{parent, oldChild}) +} + +// traversalHelper is just a struct we use to help the recursive traversal of +// top-level buckets when they're added to the cache. +type traversalHelper struct { + bucket walletdb.ReadBucket + root *cachedItem +} + +// scanBucket scans and adds a bucket and its sub-buckets recursively to the +// cache. +func (c *Cache) scanBucket(bucket walletdb.ReadBucket, root *cachedItem) error { + var queue []*traversalHelper + currRoot := root + + for { + err := bucket.ForEach(func(k, v []byte) error { + item := &cachedItem{ + key: string(k), + } + + // This is a value, fetch it. + if v != nil { + item.value = string(v) + currRoot.bucket.put(item) + return nil + } + + // This is a bucket. + if c.skipped[string(k)] { + // Bucket is read-through, no need to fetch it. + item.bucket = newReadThroughCacheBucket() + } else { + // We cache this bucket and its contents. + item.bucket = newCacheBucket() + + bucket := bucket.NestedReadBucket(k) + queue = append( + queue, + &traversalHelper{ + bucket: bucket, + root: item, + }, + ) + + } + currRoot.bucket.put(item) + + return nil + }) + + if err != nil { + return err + } + + if len(queue) == 0 { + break + } + + bucket = queue[0].bucket + currRoot = queue[0].root + queue[0] = nil + queue = queue[1:] + } + + return nil +} + +// addTopLevelBucket recursively reads the passed top-level bucket and adds +// all content below it to the cache. +func (c *Cache) addTopLevelBucket(key []byte) error { + if c.skipped[string(key)] { + c.root.bucket.put(&cachedItem{ + key: string(key), + bucket: newReadThroughCacheBucket(), + }) + + return nil + } + + var root *cachedItem + + if err := View(c.backend, func(tx RTx) error { + bucket := tx.ReadBucket(key) + if bucket == nil { + return nil + } + + root = &cachedItem{ + key: string(key), + bucket: newCacheBucket(), + } + + return c.scanBucket(bucket, root) + }, func() {}); err != nil { + return err + } + + if root != nil { + c.root.bucket.put(root) + } + + return nil +} + +// Wipe wipes the cache state. +func (c *Cache) Wipe() { + c.mx.Lock() + defer c.mx.Unlock() + + c.root = &cachedItem{ + bucket: newCacheBucket(), + } +} + +// Init refetches the tracked top-level buckets. +func (c *Cache) Init() error { + c.mx.Lock() + defer c.mx.Unlock() + + for _, bucket := range c.topLevelBuckets { + if err := c.addTopLevelBucket([]byte(bucket)); err != nil { + return err + } + } + + return nil +} + +func (c *Cache) BeginReadTx() (walletdb.ReadTx, error) { + c.mx.RLock() + + dbTx, err := c.backend.BeginReadTx() + if err != nil { + c.mx.RUnlock() + return nil, err + } + + return newCacheReadTx(c, dbTx), nil +} + +func (c *Cache) BeginReadWriteTx() (walletdb.ReadWriteTx, error) { + c.mx.Lock() + + dbTx, err := c.backend.BeginReadWriteTx() + if err != nil { + c.mx.Unlock() + return nil, err + } + + c.currRwTx = newCacheReadWriteTx(c, dbTx) + return c.currRwTx, nil +} + +func (c *Cache) Copy(w io.Writer) error { + return fmt.Errorf("unavailable") +} + +func (c *Cache) Close() error { + err := c.backend.Close() + c.Wipe() + return err +} + +func (c *Cache) PrintStats() string { + return "" +} + +func (c *Cache) View(f func(tx walletdb.ReadTx) error, reset func()) error { + tx, err := c.BeginReadTx() + if err != nil { + return err + } + + reset() + + err = f(tx) + rollbackErr := tx.Rollback() + + if err != nil { + return err + } + + return rollbackErr +} + +func (c *Cache) Update(f func(tx walletdb.ReadWriteTx) error, + reset func()) error { + + tx, err := c.BeginReadWriteTx() + if err != nil { + return err + } + + reset() + + // Apply the tx closure, rollback on error. + if err := f(tx); err != nil { + _ = tx.Rollback() + return err + } + + // Attempt to commit, rollback on error. Note that since we have + // exclusive access Commit should only fail with database error and + // never with any error that we'd normally retry on. + if err := tx.Commit(); err != nil { + _ = tx.Rollback() + return err + } + + return nil +} + +func (c *Cache) Batch(f func(tx walletdb.ReadWriteTx) error) error { + return c.Update(f, func() {}) +} + +type cacheReadTx struct { + cache *Cache + dbTx walletdb.ReadTx + active bool +} + +var _ walletdb.ReadTx = (*cacheReadTx)(nil) + +func newCacheReadTx(cache *Cache, dbTx walletdb.ReadTx) *cacheReadTx { + return &cacheReadTx{ + cache: cache, + dbTx: dbTx, + active: true, + } +} + +func topLevelReadBucketImpl(dbTx walletdb.ReadTx, + cache *Cache, key []byte) walletdb.ReadBucket { + + if root := cache.root.bucket.get(key); root != nil { + if root.bucket.cached() { + return newCacheReadBucket(dbTx, nil, cache, root) + } else { + // For "read-through" top level buckets we simply + // return a DB ReadBucket that is independent from the + // cached state. + return dbTx.ReadBucket(key) + } + } + + return nil +} + +func forEachBucketImpl(cache *Cache, f func(key []byte) error) error { + cache.root.bucket.tree.Ascend(func(item btree.Item) bool { + c := item.(*cachedItem) + if f([]byte(c.key)) != nil { + return false + } + + return true + }) + + return nil +} + +func (c *cacheReadTx) ReadBucket(key []byte) walletdb.ReadBucket { + return topLevelReadBucketImpl(c.dbTx, c.cache, key) +} + +func (c *cacheReadTx) ForEachBucket(f func(key []byte) error) error { + return forEachBucketImpl(c.cache, f) +} + +func (c *cacheReadTx) Rollback() error { + if c.active { + defer func() { + c.active = false + c.cache.mx.RUnlock() + }() + + return c.dbTx.Rollback() + } + + return nil +} + +type cacheReadWriteTx struct { + cache *Cache + dbTx walletdb.ReadWriteTx + active bool + onCommit func() +} + +var _ walletdb.ReadWriteTx = (*cacheReadWriteTx)(nil) + +func newCacheReadWriteTx(cache *Cache, + dbTx walletdb.ReadWriteTx) *cacheReadWriteTx { + + return &cacheReadWriteTx{ + cache: cache, + dbTx: dbTx, + active: true, + } +} + +func (c *cacheReadWriteTx) ReadBucket(key []byte) walletdb.ReadBucket { + return topLevelReadBucketImpl(c.dbTx, c.cache, key) +} + +func (c *cacheReadWriteTx) ForEachBucket(f func(key []byte) error) error { + return forEachBucketImpl(c.cache, f) +} + +func (c *cacheReadWriteTx) Rollback() error { + if c.active { + defer func() { + c.active = false + c.cache.mx.Unlock() + }() + + // First revert changes to the cache itself. + for e := c.cache.pending.Back(); e != nil; e = e.Prev() { + e.Value.(pendingChange).Revert() + } + + // Now that we got back the old cache state, we can reset the + // pending change list and revert the DB transaction. + c.cache.pending = list.New() + return c.dbTx.Rollback() + } + + return nil +} + +func (c *cacheReadWriteTx) ReadWriteBucket(key []byte) walletdb.ReadWriteBucket { + root := c.cache.root.bucket.get(key) + + // Bucket is not known. + if root == nil { + return nil + } + + dbBucket := c.dbTx.ReadWriteBucket(key) + if dbBucket == nil { + return nil + } + + // We cache the bucket state. + if root.bucket.cached() { + return newCacheReadWriteBucket(c.cache, root, dbBucket) + } + + // Read-through bucket. + return dbBucket +} + +func (c *cacheReadWriteTx) CreateTopLevelBucket(key []byte) ( + walletdb.ReadWriteBucket, error) { + + // First we need to make sure the DB is able to find/create this top + // level bucket. + dbBucket, err := c.dbTx.CreateTopLevelBucket(key) + if err != nil { + return nil, err + } + + // Now check if we already track this bucket in the cache. + root := c.cache.root.bucket.get(key) + if root != nil { + // Bucket is tracked and contents are cached too. + if root.bucket.cached() { + return newCacheReadWriteBucket( + c.cache, root, dbBucket, + ), nil + } + + // Bucket is tracked but we don't cache the contents. + return dbBucket, nil + } + + // Bucket is not yet tracked, we need to add it to the cache. + root = &cachedItem{ + key: string(key), + bucket: newCacheBucket(), + } + + c.cache.root.bucket.put(root) + c.cache.pendingAdd(c.cache.root, root) + + return newCacheReadWriteBucket(c.cache, root, dbBucket), nil +} + +func (c *cacheReadWriteTx) DeleteTopLevelBucket(key []byte) error { + if err := c.dbTx.DeleteTopLevelBucket(key); err != nil { + return err + } + + return deleteFromCache(c.cache, c.cache.root, key, true) +} + +func (c *cacheReadWriteTx) Commit() error { + if err := c.dbTx.Commit(); err != nil { + return err + } + + defer func() { + c.active = false + c.cache.mx.Unlock() + }() + + c.cache.pending = list.New() + if c.onCommit != nil { + c.onCommit() + } + + return nil +} + +func (c *cacheReadWriteTx) OnCommit(f func()) { + c.onCommit = f +} + +func forEachImpl(root *cachedItem, f func(k, v []byte) error) error { + var err error + + root.bucket.tree.Ascend(func(item btree.Item) bool { + c := item.(*cachedItem) + var val []byte + + if c.bucket == nil { + val = []byte(c.value) + } + + if err = f([]byte(c.key), val); err != nil { + return false + } + + return true + }) + + return err +} + +func getImpl(root *cachedItem, key []byte) []byte { + cacheItem := root.bucket.get(key) + if cacheItem != nil { + if cacheItem.bucket == nil { + return []byte(cacheItem.value) + } + + return nil + } + + return nil +} + +// cacheReadBucket is a walletdb.ReadBucket compatible bucket implementation +// operating on already cached values or reading on demand for skipped +// (read-through) sub buckets. +type cacheReadBucket struct { + parent *cacheReadBucket + cache *Cache + root *cachedItem + dbTx walletdb.ReadTx + + // dbBucket tracks the DB ReadBucket and is intentionally nil and only + // fetched if needed. + dbBucket walletdb.ReadBucket +} + +var _ walletdb.ReadBucket = (*cacheReadBucket)(nil) + +func newCacheReadBucket(dbTx walletdb.ReadTx, parent *cacheReadBucket, + cache *Cache, root *cachedItem) *cacheReadBucket { + + return &cacheReadBucket{ + parent: parent, + cache: cache, + root: root, + dbTx: dbTx, + } +} + +// fetchBucet is a helper function to "fetch" all DB buckets up to the top from +// the current one (if not yet fetched). This is necessary when using +// "read-through" buckets. The return value the DB ReadBucket for this +// cacheReadBucket. +func (c *cacheReadBucket) fetchBucket() walletdb.ReadBucket { + if c.dbBucket != nil { + return c.dbBucket + } + + if c.parent != nil { + c.dbBucket = c.parent.fetchBucket().NestedReadBucket( + []byte(c.root.key), + ) + } else { + // This is a top level ReadBucket. + c.dbBucket = c.dbTx.ReadBucket([]byte(c.root.key)) + } + + return c.dbBucket +} + +func (c *cacheReadBucket) NestedReadBucket(key []byte) walletdb.ReadBucket { + if root := c.root.bucket.get(key); root != nil { + if root.bucket != nil { + if root.bucket.cached() { + return newCacheReadBucket( + c.dbTx, c, c.cache, root, + ) + } else { + return c.fetchBucket().NestedReadBucket(key) + } + } + } + + return nil +} + +func (c *cacheReadBucket) ForEach(f func(k, v []byte) error) error { + return forEachImpl(c.root, f) +} + +func (c *cacheReadBucket) Get(key []byte) []byte { + return getImpl(c.root, key) +} + +func (c *cacheReadBucket) ReadCursor() walletdb.ReadCursor { + return newCacheReadCursor(c) +} + +func deleteFromCache(cache *Cache, root *cachedItem, key []byte, + bucket bool) error { + + if cacheItem := root.bucket.get(key); cacheItem != nil { + // Sanity checks. + if bucket && cacheItem.bucket == nil { + return walletdb.ErrIncompatibleValue + } + + if !bucket && cacheItem.bucket != nil { + return walletdb.ErrIncompatibleValue + } + + cache.pendingDelete(root, cacheItem) + root.bucket.del(cacheItem) + } + + return nil + +} + +// cacheReadWriteBucket is a walletdb.ReadWriteBucket compatible bucket +// implementation operating on already cached values or reading on demand for +// skipped (read-through) sub buckets. Updates on this bucket or sub buckets +// will be added to the cache unless in a skipped (write-through) bucket. +type cacheReadWriteBucket struct { + cache *Cache + root *cachedItem + dbBucket walletdb.ReadWriteBucket +} + +var _ walletdb.ReadWriteBucket = (*cacheReadWriteBucket)(nil) + +func newCacheReadWriteBucket(cache *Cache, root *cachedItem, + dbBucket walletdb.ReadWriteBucket) *cacheReadWriteBucket { + + return &cacheReadWriteBucket{ + cache: cache, + root: root, + dbBucket: dbBucket, + } +} + +func (c *cacheReadWriteBucket) NestedReadBucket(key []byte) walletdb.ReadBucket { + return c.NestedReadWriteBucket(key) +} + +func (c *cacheReadWriteBucket) ForEach(f func(k, v []byte) error) error { + return forEachImpl(c.root, f) +} + +func (c *cacheReadWriteBucket) Get(key []byte) []byte { + return getImpl(c.root, key) +} + +func (c *cacheReadWriteBucket) ReadCursor() walletdb.ReadCursor { + return newCacheReadWriteCursor(c) +} + +func (c *cacheReadWriteBucket) NestedReadWriteBucket( + key []byte) walletdb.ReadWriteBucket { + + if root := c.root.bucket.get(key); root != nil { + if root.bucket == nil { + return nil + } + + dbBucket := c.dbBucket.NestedReadWriteBucket(key) + + // The bucket is cached. + if dbBucket != nil && root.bucket.cached() { + return newCacheReadWriteBucket(c.cache, root, dbBucket) + } + + return dbBucket + } + + return nil +} + +func (c *cacheReadWriteBucket) createBucketImpl(key []byte) ( + walletdb.ReadWriteBucket, error) { + + dbBucket, err := c.dbBucket.CreateBucket(key) + if err != nil { + return nil, err + } + + root := &cachedItem{ + key: string(key), + } + + skipped := c.cache.skipped[string(key)] + if skipped { + // We add the bucket reference even though we'll be reading + // through it. + root.bucket = newReadThroughCacheBucket() + } else { + root.bucket = newCacheBucket() + } + + c.root.bucket.put(root) + c.cache.pendingAdd(c.root, root) + + if !skipped { + return newCacheReadWriteBucket(c.cache, root, dbBucket), nil + } + + return dbBucket, nil +} + +func (c *cacheReadWriteBucket) CreateBucket(key []byte) ( + walletdb.ReadWriteBucket, error) { + + if root := c.root.bucket.get(key); root != nil { + return nil, ErrBucketExists + } + + return c.createBucketImpl(key) +} + +func (c *cacheReadWriteBucket) CreateBucketIfNotExists(key []byte) ( + walletdb.ReadWriteBucket, error) { + + dbBucket, err := c.dbBucket.CreateBucketIfNotExists(key) + if err != nil { + return nil, err + } + + // Return existing bucket if exists. + if root := c.root.bucket.get(key); root != nil { + if root.bucket == nil { + return nil, walletdb.ErrIncompatibleValue + } + + if root.bucket.cached() { + return newCacheReadWriteBucket( + c.cache, root, dbBucket, + ), nil + } + + return dbBucket, nil + } + + // Insert new bucket otherwise. + root := &cachedItem{ + key: string(key), + } + + // We do add this new bucket reference even if though we won't cache + // its contents. + skipped := c.cache.skipped[string(key)] + if skipped { + root.bucket = newReadThroughCacheBucket() + } else { + root.bucket = newCacheBucket() + } + + c.root.bucket.put(root) + c.cache.pendingAdd(c.root, root) + + if !skipped { + return newCacheReadWriteBucket(c.cache, root, dbBucket), nil + } + + return dbBucket, nil +} + +func (c *cacheReadWriteBucket) DeleteNestedBucket(key []byte) error { + if err := c.dbBucket.DeleteNestedBucket(key); err != nil { + return err + } + + return deleteFromCache(c.cache, c.root, key, true) +} + +func (c *cacheReadWriteBucket) Put(key, value []byte) error { + if err := c.dbBucket.Put(key, value); err != nil { + return err + } + + if cacheItem := c.root.bucket.get(key); cacheItem != nil { + if cacheItem.bucket != nil { + return walletdb.ErrIncompatibleValue + } + + c.cache.pendingUpdate(cacheItem, cacheItem.value) + cacheItem.value = string(value) + } else { + newItem := &cachedItem{ + key: string(key), + value: string(value), + } + c.root.bucket.put(newItem) + c.cache.pendingAdd(c.root, newItem) + } + + return nil +} + +func (c *cacheReadWriteBucket) Delete(key []byte) error { + if err := c.dbBucket.Delete(key); err != nil { + return err + } + + return deleteFromCache(c.cache, c.root, key, false) +} + +func (c *cacheReadWriteBucket) ReadWriteCursor() walletdb.ReadWriteCursor { + return newCacheReadWriteCursor(c) +} + +func (c *cacheReadWriteBucket) Tx() walletdb.ReadWriteTx { + return c.cache.currRwTx +} + +func (c *cacheReadWriteBucket) NextSequence() (uint64, error) { + next, err := c.dbBucket.NextSequence() + if err != nil { + return 0, err + } + + c.root.bucket.seq = &next + return *c.root.bucket.seq, nil +} + +func (c *cacheReadWriteBucket) SetSequence(v uint64) error { + if err := c.dbBucket.SetSequence(v); err != nil { + return err + } + + c.root.bucket.seq = &v + return nil +} + +func (c *cacheReadWriteBucket) Sequence() uint64 { + if c.root.bucket.seq == nil { + seq := c.dbBucket.Sequence() + c.root.bucket.seq = &seq + } + + return *c.root.bucket.seq +} + +// cacheCursor implements common functions used in the cacheReadCursor and +// cacheReadWriteCursor, technically implementing the walletdb.ReadWriteCursor +// for cached buckets. +type cacheCursor struct { + root *cachedItem + currKey string +} + +func (c *cacheCursor) First() (key, value []byte) { + valItem := c.root.bucket.tree.Min() + if valItem != nil { + cacheItem := valItem.(*cachedItem) + c.currKey = cacheItem.key + if cacheItem.bucket == nil { + value = []byte(cacheItem.value) + } + + return []byte(cacheItem.key), value + } + + return nil, nil +} + +func (c *cacheCursor) Last() (key, value []byte) { + valItem := c.root.bucket.tree.Max() + if valItem != nil { + cacheItem := valItem.(*cachedItem) + c.currKey = cacheItem.key + + if cacheItem.bucket == nil { + value = []byte(cacheItem.value) + } + + return []byte(cacheItem.key), value + } + + return nil, nil +} + +func (c *cacheCursor) next(seekKey string, includeSeekKey bool) ( + key, value []byte) { + + keyItem := &cachedItem{ + key: seekKey, + } + + c.root.bucket.tree.AscendGreaterOrEqual( + keyItem, + func(nextItem btree.Item) bool { + cacheItem := nextItem.(*cachedItem) + if !includeSeekKey && cacheItem.key == seekKey { + return true + } + + key = []byte(cacheItem.key) + if cacheItem.bucket == nil { + value = []byte(cacheItem.value) + } + + return false + }, + ) + + if key != nil { + c.currKey = string(key) + } + + return key, value +} + +func (c *cacheCursor) Next() (key, value []byte) { + return c.next(c.currKey, false) +} + +func (c *cacheCursor) Seek(seek []byte) (key, value []byte) { + return c.next(string(seek), true) +} + +func (c *cacheCursor) Prev() ([]byte, []byte) { + keyItem := &cachedItem{ + key: c.currKey, + } + + var key, value []byte + c.root.bucket.tree.DescendLessOrEqual( + keyItem, + func(nextItem btree.Item) bool { + cacheItem := nextItem.(*cachedItem) + if cacheItem.key == c.currKey { + return true + } + + key = []byte(cacheItem.key) + if cacheItem.bucket == nil { + value = []byte(cacheItem.value) + } + + return false + }, + ) + + if key != nil { + c.currKey = string(key) + } + + return key, value +} + +// cacheReadCursor is a walletdb.ReadCursor compatible cursor for +// cached buckets. +type cacheReadCursor struct { + cacheCursor +} + +var _ walletdb.ReadCursor = (*cacheReadCursor)(nil) + +func newCacheReadCursor(cacheBucket *cacheReadBucket) *cacheReadCursor { + return &cacheReadCursor{ + cacheCursor: cacheCursor{ + root: cacheBucket.root, + }, + } +} + +// cacheReadWriteCursor is a walletdb.ReadWriteCursor compatible cursor for +// cached buckets. +type cacheReadWriteCursor struct { + cacheCursor + cacheBucket *cacheReadWriteBucket +} + +var _ walletdb.ReadWriteCursor = (*cacheReadWriteCursor)(nil) + +func newCacheReadWriteCursor( + cacheBucket *cacheReadWriteBucket) *cacheReadWriteCursor { + + return &cacheReadWriteCursor{ + cacheCursor: cacheCursor{ + root: cacheBucket.root, + }, + cacheBucket: cacheBucket, + } +} + +func (c *cacheReadWriteCursor) Delete() error { + return c.cacheBucket.Delete([]byte(c.currKey)) +} diff --git a/kvdb/cache_test.go b/kvdb/cache_test.go new file mode 100644 index 00000000000..f6a11a4841b --- /dev/null +++ b/kvdb/cache_test.go @@ -0,0 +1,200 @@ +package kvdb + +import ( + "fmt" + "testing" + + "github.com/btcsuite/btcwallet/walletdb" + "github.com/stretchr/testify/require" +) + +func testCacheFill(t *testing.T, db Backend) { + data := map[string]interface{}{ + "apple": map[string]interface{}{ + "a1": "av1", + "a2": "av2", + "banana": map[string]interface{}{ + "ab1": "abv1", + }, + }, + "banana": map[string]interface{}{ + "b1": "bv1", + }, + "coconut": map[string]interface{}{ + "c1": "cv1", + }, + } + + require.NoError(t, FillDB(db, data)) + + topLevelBuckets := [][]byte{[]byte("apple"), []byte("banana")} + // Skipping bucket with the name banana. + skippedKeys := [][]byte{[]byte("banana")} + + cache := NewCache(db, topLevelBuckets, skippedKeys) + require.NoError(t, cache.Init()) + + // Update the banana buckets so we can ensure the cache will fetch all + // values from the DB. + Update(cache, func(tx walletdb.ReadWriteTx) error { + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + require.NoError(t, err) + + // Create new rw bucket inside a cached bucket. + peach, err := apple.CreateBucketIfNotExists([]byte("peach")) + require.NoError(t, err) + peach.Put([]byte("ap1"), []byte("apv1")) + + banana := apple.NestedReadWriteBucket([]byte("banana")) + require.NotNil(t, banana) + banana.Put([]byte("ab2"), []byte("abv2")) + + banana, err = tx.CreateTopLevelBucket([]byte("banana")) + require.NoError(t, err) + + // Put a new value inisde a skipped bucket. + banana.Put([]byte("b2"), []byte("bv2")) + + // Crate a new rw bucket inside a write through bucket. + pear, err := banana.CreateBucketIfNotExists([]byte("pear")) + require.NoError(t, err) + pear.Put([]byte("bp1"), []byte("bpv1")) + + return nil + }, func() {}) + + // Now read back using the kvdb interface. + View(cache, func(tx walletdb.ReadTx) error { + apple := tx.ReadBucket([]byte("apple")) + require.NotNil(t, apple) + + peach := apple.NestedReadBucket([]byte("peach")) + require.NotNil(t, peach) + require.Equal(t, []byte("apv1"), peach.Get([]byte("ap1"))) + + banana := apple.NestedReadBucket([]byte("banana")) + require.NotNil(t, banana) + require.Equal(t, []byte("abv2"), banana.Get([]byte("ab2"))) + + banana = tx.ReadBucket([]byte("banana")) + require.NotNil(t, banana) + + require.Equal(t, []byte("bv2"), banana.Get([]byte("b2"))) + + pear := banana.NestedReadBucket([]byte("pear")) + require.NotNil(t, pear) + require.Equal(t, []byte("bpv1"), pear.Get([]byte("bp1"))) + + return nil + }, func() {}) + + expected := map[string]interface{}{ + "apple": map[string]interface{}{ + "a1": "av1", + "a2": "av2", + "banana": map[string]interface{}{ + "ab1": "abv1", + "ab2": "abv2", + }, + "peach": map[string]interface{}{ + "ap1": "apv1", + }, + }, + "banana": map[string]interface{}{ + "b1": "bv1", + "b2": "bv2", + "pear": map[string]interface{}{ + "bp1": "bpv1", + }, + }, + } + + // Verify that both the cache and the DB has all data + // we expect. + require.NoError(t, VerifyDB(cache, expected)) + require.NoError(t, VerifyDB(db, expected)) + + // Now wipe all data. + cache.Wipe() + empty := make(map[string]interface{}) + require.NoError(t, VerifyDB(cache, empty)) + + // We still expect everything in the DB. + require.NoError(t, VerifyDB(db, expected)) + require.NoError(t, cache.Close()) +} + +func testCacheRollback(t *testing.T, db Backend) { + data := map[string]interface{}{ + "apple": map[string]interface{}{ + "a1": "av1", + "a2": "av2", + "banana": map[string]interface{}{ + "ab1": "abv1", + }, + }, + } + require.NoError(t, FillDB(db, data)) + + cache := NewCache(db, [][]byte{[]byte("apple")}, nil) + require.NoError(t, cache.Init()) + require.NoError(t, VerifyDB(cache, data)) + + update := func(tx RwTx) error { + coconut, err := tx.CreateTopLevelBucket([]byte("coconut")) + require.NoError(t, err) + coconut.Put([]byte("key"), []byte("val")) + + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + require.NoError(t, err) + + // Add a new key. + apple.Put([]byte("key"), []byte("val")) + + // Delete an existing key. + apple.Delete([]byte("a1")) + + // Update an existing key. + apple.Put([]byte("a2"), []byte("new")) + + banana := apple.NestedReadWriteBucket([]byte("banana")) + require.NotNil(t, banana) + + banana.Delete([]byte("ab1")) + banana.Put([]byte("ab2"), []byte("abv2")) + + ab1, err := banana.CreateBucket([]byte("ab1")) + require.NoError(t, err) + ab1.Put([]byte("key"), []byte("val")) + + nested, err := banana.CreateBucket([]byte("nested")) + require.NoError(t, err) + + nested.Put([]byte("n1"), []byte("nv1")) + + apple.DeleteNestedBucket([]byte("banana")) + tx.DeleteTopLevelBucket([]byte("apple")) + + return nil + } + + // Check rollback with manual txn. + tx, err := cache.BeginReadWriteTx() + require.NoError(t, err) + update(tx) + require.NoError(t, tx.Rollback()) + + require.NoError(t, VerifyDB(cache, data)) + require.NoError(t, VerifyDB(db, data)) + + // Check rollback with closed form txn failing. + require.Error(t, cache.Update(func(tx RwTx) error { + require.NoError(t, update(tx)) + return fmt.Errorf("fail") + }, func() {})) + + require.NoError(t, VerifyDB(cache, data)) + require.NoError(t, VerifyDB(db, data)) + + require.NoError(t, cache.Close()) +} diff --git a/kvdb/etcd_test.go b/kvdb/etcd_test.go index aae1653188a..645a754b552 100644 --- a/kvdb/etcd_test.go +++ b/kvdb/etcd_test.go @@ -150,6 +150,56 @@ func TestEtcd(t *testing.T) { continue } + rwLock := []bool{false, true} + for _, doRwLock := range rwLock { + name := fmt.Sprintf("%v/RWLock=%v", test.name, doRwLock) + + cache := []bool{false, true} + for _, useCache := range cache { + name := fmt.Sprintf("%v/Cache=%v", name, useCache) + t.Run(name, func(t *testing.T) { + t.Parallel() + + f := etcd.NewEtcdTestFixture(t) + defer f.Cleanup() + + backend := f.NewBackend(doRwLock) + if useCache { + cache := NewCache( + backend, nil, nil, + ) + require.NoError(t, cache.Init()) + backend = cache + } + test.test(t, backend) + + if test.expectedDb != nil { + dump := f.Dump() + require.Equal(t, test.expectedDb, dump) + } + }) + } + } + } +} + +func TestCacheEtcd(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, walletdb.DB) + }{ + { + name: "cache fill", + test: testCacheFill, + }, + { + name: "cache rollback", + test: testCacheRollback, + }, + } + + for _, test := range tests { + test := test rwLock := []bool{false, true} for _, doRwLock := range rwLock { name := fmt.Sprintf("%v/RWLock=%v", test.name, doRwLock) @@ -160,12 +210,8 @@ func TestEtcd(t *testing.T) { f := etcd.NewEtcdTestFixture(t) defer f.Cleanup() - test.test(t, f.NewBackend(doRwLock)) - - if test.expectedDb != nil { - dump := f.Dump() - require.Equal(t, test.expectedDb, dump) - } + backend := f.NewBackend(doRwLock) + test.test(t, backend) }) } } diff --git a/kvdb/go.sum b/kvdb/go.sum index 2573f716493..f34d66fe239 100644 --- a/kvdb/go.sum +++ b/kvdb/go.sum @@ -143,6 +143,8 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= diff --git a/kvdb/test.go b/kvdb/test.go index 483862e939c..858e6928baa 100644 --- a/kvdb/test.go +++ b/kvdb/test.go @@ -1,5 +1,10 @@ package kvdb +import ( + "bytes" + "fmt" +) + type KV struct { key string val string @@ -12,3 +17,131 @@ func reverseKVs(a []KV) []KV { return a } + +// FillDB fills the passed db with the passed nested data. If a passed map value +// is string, then it'll inserted as map value, otherwise as a subbucket. +func FillDB(db Backend, data map[string]interface{}) error { + return Update(db, func(tx RwTx) error { + for key, val := range data { + bucket, err := tx.CreateTopLevelBucket([]byte(key)) + if err != nil { + return err + } + + m, ok := val.(map[string]interface{}) + if !ok { + return fmt.Errorf("invalid top level bucket: "+ + "%v", key) + } + + if err := fillBucket(bucket, m); err != nil { + return err + } + } + + return nil + }, func() {}) +} + +func fillBucket(bucket RwBucket, data map[string]interface{}) error { + for k, v := range data { + switch value := v.(type) { + + // Key contains value. + case string: + err := bucket.Put([]byte(k), []byte(value)) + if err != nil { + return err + } + + // Key contains a sub-bucket. + case map[string]interface{}: + subBucket, err := bucket.CreateBucket([]byte(k)) + if err != nil { + return err + } + + if err := fillBucket(subBucket, value); err != nil { + return err + } + + default: + return fmt.Errorf("invalid value type: %T, for key: %v", + k, value) + } + } + + return nil +} + +// VerifyDB verifies the database against the given data set. +func VerifyDB(db Backend, data map[string]interface{}) error { + return View(db, func(tx RTx) error { + for key, val := range data { + bucket := tx.ReadBucket([]byte(key)) + if bucket == nil { + return fmt.Errorf("top level bucket %v not "+ + "found", key) + } + + m, ok := val.(map[string]interface{}) + if !ok { + return fmt.Errorf("invalid top level bucket: "+ + "%v", key) + } + + if err := verifyBucket(bucket, m); err != nil { + return err + } + } + + return nil + }, func() {}) +} + +func verifyBucket(bucket RBucket, data map[string]interface{}) error { + for k, v := range data { + switch value := v.(type) { + + // Key contains value. + case string: + dbVal := bucket.Get([]byte(k)) + if !bytes.Equal(dbVal, []byte(value)) { + return fmt.Errorf("value mismatch. Key: %v, "+ + "val: %v, expected: %v", k, dbVal, value) + } + + // Key contains a sub-bucket. + case map[string]interface{}: + subBucket := bucket.NestedReadBucket([]byte(k)) + if subBucket == nil { + return fmt.Errorf("bucket %v not found", k) + } + + err := verifyBucket(subBucket, value) + if err != nil { + return err + } + + default: + return fmt.Errorf("invalid value type: %T for key: %v", + value, k) + } + } + + keyCount := 0 + err := bucket.ForEach(func(k, v []byte) error { + keyCount++ + return nil + }) + if err != nil { + return err + } + + if keyCount != len(data) { + return fmt.Errorf("unexpected keys in database, got: %v, "+ + "expected: %v", keyCount, len(data)) + } + + return nil +} diff --git a/lnd.go b/lnd.go index 23f9b835ead..93f90ef3a70 100644 --- a/lnd.go +++ b/lnd.go @@ -1685,6 +1685,10 @@ func initializeDatabases(ctx context.Context, channeldb.OptionSetChannelCacheSize(cfg.Caches.ChannelCacheSize), channeldb.OptionSetBatchCommitInterval(cfg.DB.BatchCommitInterval), channeldb.OptionDryRunMigration(cfg.DryRunMigration), + channeldb.OptionWithChannelStateCache( + // Cache channel state when not running on Bolt. + cfg.DB.Backend != lncfg.BoltBackend, + ), } // We want to pre-allocate the channel graph cache according to what we