diff --git a/src/main/groovy/io/seqera/wave/store/cache/AbstractTieredCache.groovy b/src/main/groovy/io/seqera/wave/store/cache/AbstractTieredCache.groovy index 96fe7ded2b..2d44f9f473 100644 --- a/src/main/groovy/io/seqera/wave/store/cache/AbstractTieredCache.groovy +++ b/src/main/groovy/io/seqera/wave/store/cache/AbstractTieredCache.groovy @@ -32,6 +32,7 @@ import com.github.benmanes.caffeine.cache.RemovalCause import com.github.benmanes.caffeine.cache.RemovalListener import groovy.transform.Canonical import groovy.transform.CompileStatic +import groovy.transform.Memoized import groovy.transform.ToString import groovy.util.logging.Slf4j import io.seqera.serde.encode.StringEncodingStrategy @@ -107,17 +108,45 @@ abstract class AbstractTieredCache implements Ti } } - abstract int getMaxSize() + abstract protected int getMaxSize() abstract protected getName() abstract protected String getPrefix() + /** + * The cache probabilistic revalidation internal. + * + * See https://blog.cloudflare.com/sometimes-i-cache/ + * + * @return + * The cache cache revalidation internal as a {@link Duration} value. + * When {@link Duration#ZERO} probabilistic revalidation is disabled. + */ + protected Duration getCacheRevalidationInterval() { + return Duration.ZERO + } + + /** + * The cache probabilistic revalidation steepness value. + * + * By default is implemented as 1 / {@link #getCacheRevalidationInterval()} (as millis). + * Subclasses can override this method to provide a different value. + * + * See https://blog.cloudflare.com/sometimes-i-cache/ + * + * @return Returns the revalidation steepness value. + */ + @Memoized + protected double getRevalidationSteepness() { + return 1 / getCacheRevalidationInterval().toMillis() + } + private RemovalListener removalListener0() { new RemovalListener() { @Override void onRemoval(@Nullable key, @Nullable value, RemovalCause cause) { - if( log.isTraceEnabled( )) { + if( log.isTraceEnabled() ) { log.trace "Cache '${name}' removing key=$key; value=$value; cause=$cause" } } @@ -186,47 +215,59 @@ abstract class AbstractTieredCache implements Ti private V getOrCompute0(String key, Function> loader) { assert key!=null, "Argument key cannot be null" - if( log.isTraceEnabled() ) log.trace "Cache '${name}' checking key=$key" + final ts = Instant.now() // Try L1 cache first - V value = l1Get(key) - if (value != null) { + Entry entry = l1Get(key) + Boolean needsRevalidation = entry ? shouldRevalidate(entry.expiresAt, ts) : null + if( entry && !needsRevalidation ) { if( log.isTraceEnabled() ) - log.trace "Cache '${name}' L1 hit (a) - key=$key => value=$value" - return value + log.trace "Cache '${name}' L1 hit (a) - key=$key => entry=$entry" + return (V) entry.value } final sync = locks.get(key) sync.lock() try { - value = l1Get(key) - if (value != null) { + // check again L1 cache once in the sync block + if( !entry ) { + entry = l1Get(key) + needsRevalidation = entry ? shouldRevalidate(entry.expiresAt, ts) : null + } + if( entry && !needsRevalidation ) { if( log.isTraceEnabled() ) - log.trace "Cache '${name}' L1 hit (b) - key=$key => value=$value" - return value + log.trace "Cache '${name}' L1 hit (b) - key=$key => entry=$entry" + return (V)entry.value } // Fallback to L2 cache - final entry = l2GetEntry(key) - if (entry != null) { + if( !entry ) { + entry = l2Get(key) + needsRevalidation = entry ? shouldRevalidate(entry.expiresAt, ts) : null + } + if( entry && !needsRevalidation ) { if( log.isTraceEnabled() ) - log.trace "Cache '${name}' L2 hit - key=$key => entry=$entry" + log.trace "Cache '${name}' L2 hit (c) - key=$key => entry=$entry" // Rehydrate L1 cache l1.put(key, entry) return (V) entry.value } - // still not value found, use loader function to fetch the value - if( value==null && loader!=null ) { - if( log.isTraceEnabled() ) + // still not entry found or cache revalidation needed + // use the loader function to fetch the value + V value = null + if( loader!=null ) { + if( entry && needsRevalidation ) + log.debug "Cache '${name}' invoking loader - entry=$entry needs refresh" + else if( log.isTraceEnabled() ) log.trace "Cache '${name}' invoking loader - key=$key" final ret = loader.apply(key) value = ret?.v1 Duration ttl = ret?.v2 if( value!=null && ttl!=null ) { final exp = Instant.now().plus(ttl).toEpochMilli() - final newEntry = new Entry(value,exp) + final newEntry = new Entry(value, exp) l1Put(key, newEntry) l2Put(key, newEntry, ttl) } @@ -256,28 +297,15 @@ abstract class AbstractTieredCache implements Ti protected String key0(String k) { return getPrefix() + ':' + k } - protected V l1Get(String key) { - return (V) l1GetEntry(key)?.value - } - - protected Entry l1GetEntry(String key) { - final entry = l1.getIfPresent(key) - if( entry == null ) - return null - - if( System.currentTimeMillis() > entry.expiresAt ) { - if( log.isTraceEnabled() ) - log.trace "Cache '${name}' L1 expired - key=$key => entry=$entry" - return null - } - return entry + protected Entry l1Get(String key) { + return l1.getIfPresent(key) } protected void l1Put(String key, Entry entry) { l1.put(key, entry) } - protected Entry l2GetEntry(String key) { + protected Entry l2Get(String key) { if( l2 == null ) return null @@ -285,18 +313,9 @@ abstract class AbstractTieredCache implements Ti if( raw == null ) return null - final Entry entry = encoder.decode(raw) - if( System.currentTimeMillis() > entry.expiresAt ) { - if( log.isTraceEnabled() ) - log.trace "Cache '${name}' L2 expired - key=$key => value=${entry}" - return null - } - return entry + return encoder.decode(raw) } - protected V l2Get(String key) { - return (V) l2GetEntry(key)?.value - } protected void l2Put(String key, Entry entry, Duration ttl) { if( l2 != null ) { @@ -309,4 +328,30 @@ abstract class AbstractTieredCache implements Ti l1.invalidateAll() } + protected boolean shouldRevalidate(long expiration, Instant time=Instant.now()) { + // when 'remainingCacheTime' is less than or equals to zero, it means + // the current time is beyond the expiration time, therefore a cache validation is needed + final remainingCacheTime = expiration - time.toEpochMilli() + if (remainingCacheTime <= 0) { + return true + } + + // otherwise, when remaining is greater than the cache revalidation interval + // no revalidation is needed + final cacheRevalidationMills = cacheRevalidationInterval.toMillis() + if( cacheRevalidationMills < remainingCacheTime ) { + return false + } + + // finally the remaining time is shorter the validation interval + // i.e. it's approaching the cache expiration, in this cache the needed + // for cache revalidation is determined in a probabilistic manner + // see https://blog.cloudflare.com/sometimes-i-cache/ + return randomRevalidate(cacheRevalidationMills-remainingCacheTime) + } + + protected boolean randomRevalidate(long remainingTime) { + return Math.random() < Math.exp(-revalidationSteepness * remainingTime) + } + } diff --git a/src/main/groovy/io/seqera/wave/tower/client/cache/ClientCache.groovy b/src/main/groovy/io/seqera/wave/tower/client/cache/ClientCache.groovy index 5a8dbcd95c..9f5cf9d245 100644 --- a/src/main/groovy/io/seqera/wave/tower/client/cache/ClientCache.groovy +++ b/src/main/groovy/io/seqera/wave/tower/client/cache/ClientCache.groovy @@ -66,7 +66,7 @@ class ClientCache extends AbstractTieredCache { } @Override - int getMaxSize() { + protected int getMaxSize() { return maxSize } diff --git a/src/test/groovy/io/seqera/wave/store/cache/AbstractTieredCacheTest.groovy b/src/test/groovy/io/seqera/wave/store/cache/AbstractTieredCacheTest.groovy index 2fa049da2f..53e5a98369 100644 --- a/src/test/groovy/io/seqera/wave/store/cache/AbstractTieredCacheTest.groovy +++ b/src/test/groovy/io/seqera/wave/store/cache/AbstractTieredCacheTest.groovy @@ -18,7 +18,10 @@ package io.seqera.wave.store.cache +import spock.lang.Retry + import java.time.Duration +import java.time.Instant import com.squareup.moshi.JsonAdapter import com.squareup.moshi.adapters.PolymorphicJsonAdapterFactory @@ -99,7 +102,7 @@ class AbstractTieredCacheTest extends Specification implements RedisTestContaine cache1.put(k, value, TTL) then: - def entry1 = cache1.l1GetEntry(k) + def entry1 = cache1.l1Get(k) and: entry1.expiresAt > begin then: @@ -220,4 +223,83 @@ class AbstractTieredCacheTest extends Specification implements RedisTestContaine cache.get(k2) == null } + def 'should validate revalidation logic' () { + given: + def REVALIDATION_INTERVAL_SECS = 10 + def now = Instant.now() + def cache = Spy(MyCache) + cache.getCacheRevalidationInterval() >> Duration.ofSeconds(REVALIDATION_INTERVAL_SECS) + + when: + // when expiration is past, then 'revalidate' should be true + def expiration = now.minusSeconds(1) + def revalidate = cache.shouldRevalidate(expiration.toEpochMilli(), now) + then: + 0 * cache.randomRevalidate(_) >> null + and: + revalidate + + when: + // when expiration is longer than the revalidation internal, then 'revalidate' is false + expiration = now.plusSeconds(REVALIDATION_INTERVAL_SECS +1) + revalidate = cache.shouldRevalidate(expiration.toEpochMilli(), now) + then: + 0 * cache.randomRevalidate(_) >> null + and: + !revalidate + + when: + // when expiration is less than or equal the revalidation internal, then 'revalidate' is computed randomly + expiration = now.plusSeconds(REVALIDATION_INTERVAL_SECS) + revalidate = cache.shouldRevalidate(expiration.toEpochMilli(), now) + then: + 1 * cache.randomRevalidate(_) >> true + and: + revalidate + + when: + // when expiration is less than or equal the revalidation internal, then 'revalidate' is computed randomly + expiration = now.plusSeconds(REVALIDATION_INTERVAL_SECS -1) + revalidate = cache.shouldRevalidate(expiration.toEpochMilli(), now) + then: + 1 * cache.randomRevalidate(_) >> false + and: + !revalidate + } + + def 'should validate random function' () { + given: + def now = Instant.now() + def cache = Spy(MyCache) + cache.getCacheRevalidationInterval() >> Duration.ofSeconds(10) + expect: + cache.randomRevalidate(0) + } + + @Retry(count = 5) + def 'should validate random revalidate with interval 10s' () { + given: + def now = Instant.now() + def cache = Spy(MyCache) + cache.getCacheRevalidationInterval() >> Duration.ofSeconds(10) + expect: + // when remaining time is approaching 0 + // the function should return true + cache.randomRevalidate(10) // 10 millis + cache.randomRevalidate(100) // 100 millis + } + + @Retry(count = 5) + def 'should validate random revalidate with interval 300s' () { + given: + def now = Instant.now() + def cache = Spy(MyCache) + cache.getCacheRevalidationInterval() >> Duration.ofSeconds(300) + expect: + // when remaining time is approaching 0 + // the function should return true + cache.randomRevalidate(10) // 10 millis + cache.randomRevalidate(100) // 100 millis + cache.randomRevalidate(500) // 100 millis + } }