diff --git a/.gitignore b/.gitignore index db9d3e1f..ef5d4ec8 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,4 @@ examples/lib/.classpath examples/lib/.project examples/lib/bin .vscode/ +momento-sdk/bin \ No newline at end of file diff --git a/Makefile b/Makefile index 52c1c1e7..075471f8 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,10 @@ test-leaderboard-service: test-topics-service: @CONSISTENT_READS=1 ./gradlew test-topics-service +## Run the topics subscription initialization tests +test-topics-subscription-initialization: + @CONSISTENT_READS=1 ./gradlew test-topics-subscription-initialization + ## Run the http service tests test-http-service: @echo "No tests for http service." diff --git a/momento-sdk/build.gradle.kts b/momento-sdk/build.gradle.kts index 439a9711..89dc6634 100644 --- a/momento-sdk/build.gradle.kts +++ b/momento-sdk/build.gradle.kts @@ -109,3 +109,8 @@ registerIntegrationTestTask( "test-retries", listOf("momento.sdk.retry.*") ) + +registerIntegrationTestTask( + "test-topics-subscription-initialization", + listOf("momento.sdk.retry.TopicsSubscriptionInitializationTest") +) diff --git a/momento-sdk/src/intTest/java/momento/sdk/retry/BaseMomentoLocalTestClass.java b/momento-sdk/src/intTest/java/momento/sdk/retry/BaseMomentoLocalTestClass.java index 886494b9..ed4159db 100644 --- a/momento-sdk/src/intTest/java/momento/sdk/retry/BaseMomentoLocalTestClass.java +++ b/momento-sdk/src/intTest/java/momento/sdk/retry/BaseMomentoLocalTestClass.java @@ -14,6 +14,8 @@ import momento.sdk.config.Configurations; import momento.sdk.config.TopicConfiguration; import momento.sdk.config.TopicConfigurations; +import momento.sdk.config.transport.GrpcConfiguration; +import momento.sdk.config.transport.StaticTransportStrategy; import momento.sdk.responses.cache.control.CacheCreateResponse; import momento.sdk.retry.utils.MomentoLocalMiddleware; import momento.sdk.retry.utils.MomentoLocalMiddlewareArgs; @@ -126,6 +128,39 @@ public static void withCacheAndTopicClient( } } + public static void withCacheAndTopicClientWithNumStreamChannels( + int numStreamChannels, + MomentoLocalMiddlewareArgs testMetricsMiddlewareArgs, + TopicTestCallback testCallback) + throws Exception { + + final String cacheName = testCacheName(); + final String hostname = + Optional.ofNullable(System.getenv("MOMENTO_HOSTNAME")).orElse("127.0.0.1"); + final int port = + Optional.ofNullable(System.getenv("MOMENTO_PORT")).map(Integer::parseInt).orElse(8080); + final CredentialProvider credentialProvider = new MomentoLocalProvider(hostname, port); + + final GrpcConfiguration grpcConfig = + new GrpcConfiguration(Duration.ofMillis(15000)) + .withNumStreamGrpcChannels(numStreamChannels); + final TopicConfiguration topicConfiguration = + new TopicConfiguration(new StaticTransportStrategy(grpcConfig)) + .withMiddleware(new MomentoLocalMiddleware(testMetricsMiddlewareArgs)); + + try (final CacheClient cacheClient = + CacheClient.builder( + credentialProvider, Configurations.Laptop.latest(), DEFAULT_TTL_SECONDS) + .build(); + final TopicClient topicClient = + TopicClient.builder(credentialProvider, topicConfiguration).build()) { + if (cacheClient.createCache(cacheName).join() instanceof CacheCreateResponse.Error) { + throw new RuntimeException("Failed to create cache: " + cacheName); + } + testCallback.run(topicClient, cacheName); + } + } + @FunctionalInterface public interface CacheTestCallback { void run(CacheClient cc, String cacheName) throws Exception; diff --git a/momento-sdk/src/intTest/java/momento/sdk/retry/TopicsSubscriptionInitializationTest.java b/momento-sdk/src/intTest/java/momento/sdk/retry/TopicsSubscriptionInitializationTest.java new file mode 100644 index 00000000..2527224c --- /dev/null +++ b/momento-sdk/src/intTest/java/momento/sdk/retry/TopicsSubscriptionInitializationTest.java @@ -0,0 +1,461 @@ +package momento.sdk.retry; + +import static momento.sdk.retry.BaseMomentoLocalTestClass.withCacheAndTopicClientWithNumStreamChannels; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.slf4j.LoggerFactory.getLogger; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicInteger; +import momento.sdk.ISubscriptionCallbacks; +import momento.sdk.exceptions.MomentoErrorCode; +import momento.sdk.responses.topic.TopicMessage; +import momento.sdk.responses.topic.TopicSubscribeResponse; +import momento.sdk.retry.utils.MomentoLocalMiddlewareArgs; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.Logger; + +public class TopicsSubscriptionInitializationTest { + private int unsubscribeCounter = 0; + + private ISubscriptionCallbacks callbacks() { + return new ISubscriptionCallbacks() { + @Override + public void onItem(TopicMessage message) {} + + @Override + public void onCompleted() { + unsubscribeCounter++; + } + + @Override + public void onError(Throwable t) {} + }; + } + + private static Logger logger; + + @BeforeAll + static void setup() { + logger = getLogger(TopicsSubscriptionInitializationTest.class); + } + + @Test + @Timeout(30) + public void oneStreamChannel_doesNotSilentlyQueueSubscribeRequestOnFullChannel() + throws Exception { + unsubscribeCounter = 0; + + withCacheAndTopicClientWithNumStreamChannels( + 1, + new MomentoLocalMiddlewareArgs.Builder(logger, UUID.randomUUID().toString()).build(), + (topicClient, cacheName) -> { + // These should all succeed + // Starting 100 subscriptions on 1 channel should be fine + List subscriptions = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + final TopicSubscribeResponse response = + topicClient.subscribe(cacheName, "test-topic", callbacks()).join(); + assertThat(response).isInstanceOf(TopicSubscribeResponse.Subscription.class); + subscriptions.add((TopicSubscribeResponse.Subscription) response); + } + + // Wait a bit for all subscriptions to be fully established + try { + Thread.sleep(500); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Test interrupted while waiting for subscriptions", e); + } + + // Starting one more subscription should produce resource exhausted error + final TopicSubscribeResponse response = + topicClient.subscribe(cacheName, "test-topic", callbacks()).join(); + assertThat(response).isInstanceOf(TopicSubscribeResponse.Error.class); + assertEquals( + MomentoErrorCode.CLIENT_RESOURCE_EXHAUSTED, + ((TopicSubscribeResponse.Error) response).getErrorCode()); + + // Ending a subscription should free up one new stream + subscriptions.get(0).unsubscribe(); + // Wait for the subscription to end + try { + Thread.sleep(500); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Test interrupted while waiting for subscriptions", e); + } + assertEquals(1, unsubscribeCounter); + + final TopicSubscribeResponse response2 = + topicClient.subscribe(cacheName, "test-topic", callbacks()).join(); + assertThat(response2).isInstanceOf(TopicSubscribeResponse.Subscription.class); + subscriptions.add((TopicSubscribeResponse.Subscription) response2); + + // Cleanup + for (TopicSubscribeResponse.Subscription sub : subscriptions) { + if (sub != null) { + sub.unsubscribe(); + } + } + }); + } + + @ParameterizedTest + @ValueSource(ints = {2, 10, 20}) + @Timeout(30) + public void multipleStreamChannels_handlesBurstOfSubscribeAndUnsubscribeRequests( + int numGrpcChannels) throws Exception { + unsubscribeCounter = 0; + final int maxStreamCapacity = 100 * numGrpcChannels; + + withCacheAndTopicClientWithNumStreamChannels( + numGrpcChannels, + new MomentoLocalMiddlewareArgs.Builder(logger, UUID.randomUUID().toString()).build(), + (topicClient, cacheName) -> { + List> subscribeRequests = new ArrayList<>(); + for (int i = 0; i < maxStreamCapacity; i++) { + final CompletableFuture response = + topicClient.subscribe(cacheName, "test-topic", callbacks()); + subscribeRequests.add(response); + } + // Wait for all the subscribe requests to complete + CompletableFuture.allOf(subscribeRequests.toArray(new CompletableFuture[0])).join(); + + // Wait a bit for all subscriptions to be fully established + try { + Thread.sleep(500); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Test interrupted while waiting for subscriptions", e); + } + + // Verify they all succeeded + List subscriptions = new ArrayList<>(); + for (CompletableFuture future : subscribeRequests) { + TopicSubscribeResponse response = future.join(); + assertThat(response).isInstanceOf(TopicSubscribeResponse.Subscription.class); + subscriptions.add((TopicSubscribeResponse.Subscription) response); + } + + // Unsubscribe half of the subscriptions + final int unsubscribeBurstSize = maxStreamCapacity / 2; + for (int i = 0; i < unsubscribeBurstSize; i++) { + subscriptions.get(i).unsubscribe(); + } + // Wait a bit for the subscription to end + try { + Thread.sleep(500); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Test interrupted while waiting for subscriptions", e); + } + assertEquals(unsubscribeBurstSize, unsubscribeCounter); + + // Burst of subscribe requests should succeed + final int subscribeBurstSize = maxStreamCapacity / 2 + 10; + List> subscribeRequests2 = new ArrayList<>(); + for (int i = 0; i < subscribeBurstSize; i++) { + final CompletableFuture subscribePromise = + topicClient.subscribe(cacheName, "test-topic", callbacks()); + subscribeRequests2.add(subscribePromise); + } + CompletableFuture.allOf(subscribeRequests2.toArray(new CompletableFuture[0])).join(); + + // Wait a bit for all subscriptions to be fully established + try { + Thread.sleep(500); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Test interrupted while waiting for subscriptions", e); + } + + List successfulSubscriptions2 = new ArrayList<>(); + int numFailedSubscriptions = 0; + for (CompletableFuture future : subscribeRequests2) { + TopicSubscribeResponse response = future.join(); + if (response instanceof TopicSubscribeResponse.Subscription) { + successfulSubscriptions2.add((TopicSubscribeResponse.Subscription) response); + } else { + numFailedSubscriptions++; + } + } + assertEquals(10, numFailedSubscriptions); + assertEquals(subscribeBurstSize - 10, successfulSubscriptions2.size()); + + // Cleanup + for (TopicSubscribeResponse.Subscription sub : subscriptions) { + sub.unsubscribe(); + } + }); + } + + @ParameterizedTest + @ValueSource(ints = {2, 10, 20}) + @Timeout(30) + public void multipleStreamChannels_handlesBurstOfSubscribeRequestsAtMaxCapacity( + int numGrpcChannels) throws Exception { + final int maxStreamCapacity = 100 * numGrpcChannels; + + withCacheAndTopicClientWithNumStreamChannels( + numGrpcChannels, + new MomentoLocalMiddlewareArgs.Builder(logger, UUID.randomUUID().toString()).build(), + (topicClient, cacheName) -> { + List> subscribeRequests = new ArrayList<>(); + for (int i = 0; i < maxStreamCapacity; i++) { + final CompletableFuture response = + topicClient.subscribe(cacheName, "test-topic", callbacks()); + subscribeRequests.add(response); + } + // Wait for all the subscribe requests to complete + CompletableFuture.allOf(subscribeRequests.toArray(new CompletableFuture[0])).join(); + + // Wait a bit for all subscriptions to be fully established + try { + Thread.sleep(500); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Test interrupted while waiting for subscriptions", e); + } + + // Verify they all succeeded + List subscriptions = new ArrayList<>(); + for (CompletableFuture future : subscribeRequests) { + TopicSubscribeResponse response = future.join(); + assertThat(response).isInstanceOf(TopicSubscribeResponse.Subscription.class); + subscriptions.add((TopicSubscribeResponse.Subscription) response); + } + + // Cleanup + for (TopicSubscribeResponse.Subscription sub : subscriptions) { + sub.unsubscribe(); + } + }); + } + + @ParameterizedTest + @ValueSource(ints = {2, 10, 20}) + @Timeout(30) + public void multipleStreamChannels_handlesBurstOfSubscribeRequestsAtOverMaxCapacity( + int numGrpcChannels) throws Exception { + final int maxStreamCapacity = 100 * numGrpcChannels; + + withCacheAndTopicClientWithNumStreamChannels( + numGrpcChannels, + new MomentoLocalMiddlewareArgs.Builder(logger, UUID.randomUUID().toString()).build(), + (topicClient, cacheName) -> { + List> subscribeRequests = new ArrayList<>(); + for (int i = 0; i < maxStreamCapacity + 10; i++) { + final CompletableFuture response = + topicClient.subscribe(cacheName, "test-topic", callbacks()); + subscribeRequests.add(response); + } + // Wait for all the subscribe requests to complete + CompletableFuture.allOf(subscribeRequests.toArray(new CompletableFuture[0])).join(); + + // Wait a bit for all subscriptions to be fully established + try { + Thread.sleep(500); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Test interrupted while waiting for subscriptions", e); + } + + // Verify they all succeeded + List subscriptions = new ArrayList<>(); + int numFailedSubscriptions = 0; + for (CompletableFuture future : subscribeRequests) { + TopicSubscribeResponse response = future.join(); + if (response instanceof TopicSubscribeResponse.Error) { + numFailedSubscriptions++; + } else { + assertThat(response).isInstanceOf(TopicSubscribeResponse.Subscription.class); + subscriptions.add((TopicSubscribeResponse.Subscription) response); + } + } + assertEquals(10, numFailedSubscriptions); + assertEquals(maxStreamCapacity, subscriptions.size()); + + // Cleanup + for (TopicSubscribeResponse.Subscription sub : subscriptions) { + sub.unsubscribe(); + } + }); + } + + @ParameterizedTest + @ValueSource(ints = {2, 10, 20}) + @Timeout(30) + public void multipleStreamChannels_handlesBurstOfSubscribeRequestsAtHalfOfMaxCapacity( + int numGrpcChannels) throws Exception { + final int maxStreamCapacity = 100 * numGrpcChannels; + + withCacheAndTopicClientWithNumStreamChannels( + numGrpcChannels, + new MomentoLocalMiddlewareArgs.Builder(logger, UUID.randomUUID().toString()).build(), + (topicClient, cacheName) -> { + List> subscribeRequests = new ArrayList<>(); + for (int i = 0; i < maxStreamCapacity / 2; i++) { + final CompletableFuture response = + topicClient.subscribe(cacheName, "test-topic", callbacks()); + subscribeRequests.add(response); + } + // Wait for all the subscribe requests to complete + CompletableFuture.allOf(subscribeRequests.toArray(new CompletableFuture[0])).join(); + + // Wait a bit for all subscriptions to be fully established + try { + Thread.sleep(500); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Test interrupted while waiting for subscriptions", e); + } + + // Verify they all succeeded + List subscriptions = new ArrayList<>(); + for (CompletableFuture future : subscribeRequests) { + TopicSubscribeResponse response = future.join(); + assertThat(response).isInstanceOf(TopicSubscribeResponse.Subscription.class); + subscriptions.add((TopicSubscribeResponse.Subscription) response); + } + + // Cleanup + for (TopicSubscribeResponse.Subscription sub : subscriptions) { + sub.unsubscribe(); + } + }); + } + + @Test + @Timeout(30) + public void shouldDecrementActiveSubscriptionsCountWhenSubscribeRequestsFail() throws Exception { + final int numGrpcChannels = 1; + final int maxStreamCapacity = 100 * numGrpcChannels; + + withCacheAndTopicClientWithNumStreamChannels( + numGrpcChannels, + new MomentoLocalMiddlewareArgs.Builder(logger, UUID.randomUUID().toString()).build(), + (topicClient, cacheName) -> { + final Semaphore errorSemaphore = new Semaphore(0); + final AtomicInteger errorCounter = new AtomicInteger(0); + + final ISubscriptionCallbacks callbacks = + new ISubscriptionCallbacks() { + @Override + public void onItem(TopicMessage message) {} + + @Override + public void onCompleted() {} + + @Override + public void onError(Throwable t) { + errorCounter.incrementAndGet(); + errorSemaphore.release(); + } + }; + + // Should successfully start the maximum number of subscriptions because 10 attempts ran + // into NOT_FOUND_ERROR. The errors should have decremented the active subscriptions + // count. + List successfulSubscriptions = new ArrayList<>(); + for (int i = 0; i < maxStreamCapacity + 10; i++) { + String cacheNameToUse = cacheName; + if (i % 11 == 0) { + cacheNameToUse = "this-cache-does-not-exist"; + } + TopicSubscribeResponse attempt = + topicClient.subscribe(cacheNameToUse, "test-topic", callbacks).join(); + if (attempt instanceof TopicSubscribeResponse.Subscription) { + successfulSubscriptions.add((TopicSubscribeResponse.Subscription) attempt); + } else { + assertThat(attempt).isInstanceOf(TopicSubscribeResponse.Error.class); + assertThat(((TopicSubscribeResponse.Error) attempt).getErrorCode()) + .isEqualTo(MomentoErrorCode.NOT_FOUND_ERROR); + errorCounter.incrementAndGet(); + } + } + + // Assert that we have received maxStreamCapacity number of successful subscriptions + assertThat(successfulSubscriptions.size()).isEqualTo(maxStreamCapacity); + + // Assert that we have received 10 NOT_FOUND_ERRORs + assertThat(errorCounter).hasValue(10); + + // Cleanup + for (TopicSubscribeResponse.Subscription sub : successfulSubscriptions) { + sub.unsubscribe(); + } + }); + } + + @Test + @Timeout(30) + public void oneStreamChannel_properlyDecrementsWhenErrorOccursMidStream() throws Exception { + unsubscribeCounter = 0; + final AtomicInteger unsubscribeOnErrorCounter = new AtomicInteger(0); + final ISubscriptionCallbacks callbacks = + new ISubscriptionCallbacks() { + @Override + public void onItem(TopicMessage message) {} + + @Override + public void onCompleted() { + System.out.println("onCompleted"); + unsubscribeCounter++; + } + + @Override + public void onError(Throwable t) { + System.out.println("onError"); + unsubscribeOnErrorCounter.incrementAndGet(); + } + }; + + final MomentoLocalMiddlewareArgs middlewareArgs = + new MomentoLocalMiddlewareArgs.Builder(logger, UUID.randomUUID().toString()) + .streamError(MomentoErrorCode.NOT_FOUND_ERROR) + .streamErrorRpcList(Collections.singletonList(MomentoRpcMethod.TOPIC_SUBSCRIBE)) + .streamErrorMessageLimit(3) + .build(); + + withCacheAndTopicClientWithNumStreamChannels( + 1, + middlewareArgs, + (topicClient, cacheName) -> { + List subscriptions = new ArrayList<>(); + + // Subscribe but expecting an error after a couple of heartbeats + final TopicSubscribeResponse response = + topicClient.subscribe(cacheName, "topic", callbacks).join(); + assertThat(response).isInstanceOf(TopicSubscribeResponse.Subscription.class); + subscriptions.add((TopicSubscribeResponse.Subscription) response); + + // Wait for the subscription that ran into the error to be closed + try { + Thread.sleep(3000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Test interrupted while waiting for subscriptions", e); + } + + // Cleanup + for (TopicSubscribeResponse.Subscription sub : subscriptions) { + if (sub != null) { + sub.unsubscribe(); + } + } + + assertEquals(0, unsubscribeCounter); + assertEquals(1, unsubscribeOnErrorCounter.get()); + }); + } +} diff --git a/momento-sdk/src/main/java/momento/sdk/ScsTopicClient.java b/momento-sdk/src/main/java/momento/sdk/ScsTopicClient.java index cf32f214..867fb441 100644 --- a/momento-sdk/src/main/java/momento/sdk/ScsTopicClient.java +++ b/momento-sdk/src/main/java/momento/sdk/ScsTopicClient.java @@ -10,6 +10,7 @@ import momento.sdk.auth.CredentialProvider; import momento.sdk.config.TopicConfiguration; import momento.sdk.exceptions.CacheServiceExceptionMapper; +import momento.sdk.exceptions.ClientSdkException; import momento.sdk.internal.SubscriptionState; import momento.sdk.responses.topic.TopicPublishResponse; import momento.sdk.responses.topic.TopicSubscribeResponse; @@ -118,40 +119,55 @@ private CompletableFuture sendSubscribe( String cacheName, String topicName, ISubscriptionCallbacks callbacks) { final SubscriptionState subscriptionState = new SubscriptionState(); - final IScsTopicConnection connection = - (request, subscription) -> - topicGrpcStubsManager.getNextStreamStub().subscribe(request, subscription); - - long configuredTimeoutSeconds = - topicGrpcStubsManager - .getConfiguration() - .getTransportStrategy() - .getGrpcConfiguration() - .getDeadline() - .getSeconds(); - long firstMessageSubscribeTimeoutSeconds = - configuredTimeoutSeconds > 0 ? configuredTimeoutSeconds : DEFAULT_REQUEST_TIMEOUT_SECONDS; - - @SuppressWarnings("resource") // the wrapper closes itself when a subscription ends. - final SubscriptionWrapper subscriptionWrapper = - new SubscriptionWrapper( - cacheName, - topicName, - connection, - callbacks, - subscriptionState, - firstMessageSubscribeTimeoutSeconds, - subscriptionRetryStrategy); - final CompletableFuture subscribeFuture = subscriptionWrapper.subscribeWithRetry(); - return subscribeFuture.handle( - (v, ex) -> { - if (ex != null) { - return new TopicSubscribeResponse.Error(CacheServiceExceptionMapper.convert(ex)); - } else { - subscriptionState.setUnsubscribeFn(subscriptionWrapper::unsubscribe); - return new TopicSubscribeResponse.Subscription(subscriptionState); - } - }); + try { + // Wrap in try-catch because getNextStreamStub() can throw an exception + // if the number of active subscriptions is already at max capacity. + final StreamStubWithCount stubWithCount = topicGrpcStubsManager.getNextStreamStub(); + + final IScsTopicConnection connection = + (request, subscription) -> stubWithCount.getStub().subscribe(request, subscription); + + long configuredTimeoutSeconds = + topicGrpcStubsManager + .getConfiguration() + .getTransportStrategy() + .getGrpcConfiguration() + .getDeadline() + .getSeconds(); + long firstMessageSubscribeTimeoutSeconds = + configuredTimeoutSeconds > 0 ? configuredTimeoutSeconds : DEFAULT_REQUEST_TIMEOUT_SECONDS; + + @SuppressWarnings("resource") // the wrapper closes itself when a subscription ends. + final SubscriptionWrapper subscriptionWrapper = + new SubscriptionWrapper( + cacheName, + topicName, + connection, + callbacks, + subscriptionState, + firstMessageSubscribeTimeoutSeconds, + subscriptionRetryStrategy); + + final CompletableFuture subscribeFuture = subscriptionWrapper.subscribeWithRetry(); + return subscribeFuture.handle( + (v, ex) -> { + if (ex != null) { + stubWithCount.decrementCount(); + return new TopicSubscribeResponse.Error(CacheServiceExceptionMapper.convert(ex)); + } else { + subscriptionState.setDecrementActiveSubscriptionsCountFn( + stubWithCount::decrementCount); + subscriptionState.setUnsubscribeFn(subscriptionWrapper::unsubscribe); + return new TopicSubscribeResponse.Subscription(subscriptionState); + } + }); + } catch (ClientSdkException e) { + // getNextStreamStub() may throw a ClientSdkException + return CompletableFuture.completedFuture(new TopicSubscribeResponse.Error(e)); + } catch (TopicSubscribeResponse.Error e) { + // subscribeWithRetry() may throw a TopicSubscribeResponse.Error + return CompletableFuture.completedFuture(e); + } } @Override diff --git a/momento-sdk/src/main/java/momento/sdk/ScsTopicGrpcStubsManager.java b/momento-sdk/src/main/java/momento/sdk/ScsTopicGrpcStubsManager.java index a4f43626..484df6ef 100644 --- a/momento-sdk/src/main/java/momento/sdk/ScsTopicGrpcStubsManager.java +++ b/momento-sdk/src/main/java/momento/sdk/ScsTopicGrpcStubsManager.java @@ -19,8 +19,47 @@ import momento.sdk.config.TopicConfiguration; import momento.sdk.config.middleware.Middleware; import momento.sdk.config.middleware.MiddlewareRequestHandlerContext; +import momento.sdk.exceptions.ClientSdkException; +import momento.sdk.exceptions.MomentoErrorCode; import momento.sdk.internal.GrpcChannelOptions; +// Helper class for bookkeeping the number of active concurrent subscriptions. +final class StreamStubWithCount { + private final PubsubGrpc.PubsubStub stub; + private final AtomicInteger count = new AtomicInteger(0); + + StreamStubWithCount(PubsubGrpc.PubsubStub stub) { + this.stub = stub; + } + + PubsubGrpc.PubsubStub getStub() { + return stub; + } + + int getCount() { + return count.get(); + } + + int incrementCount() { + return count.incrementAndGet(); + } + + int decrementCount() { + return count.decrementAndGet(); + } + + void acquireStubOrThrow() throws ClientSdkException { + if (count.incrementAndGet() <= 100) { + return; + } else { + count.decrementAndGet(); + throw new ClientSdkException( + MomentoErrorCode.CLIENT_RESOURCE_EXHAUSTED, + "Maximum number of active subscriptions reached"); + } + } +} + /** * Manager responsible for GRPC channels and stubs for the Topics. * @@ -35,7 +74,7 @@ final class ScsTopicGrpcStubsManager implements Closeable { private final AtomicInteger unaryIndex = new AtomicInteger(0); private final List streamChannels; - private final List streamStubs; + private final List streamStubs; private final AtomicInteger streamIndex = new AtomicInteger(0); public static final UUID CONNECTION_ID_KEY = UUID.randomUUID(); @@ -65,7 +104,10 @@ final class ScsTopicGrpcStubsManager implements Closeable { .mapToObj(i -> setupConnection(credentialProvider, configuration)) .collect(Collectors.toList()); this.streamStubs = - streamChannels.stream().map(PubsubGrpc::newStub).collect(Collectors.toList()); + streamChannels.stream() + .map(PubsubGrpc::newStub) + .map(StreamStubWithCount::new) + .collect(Collectors.toList()); } private static ManagedChannel setupConnection( @@ -100,8 +142,27 @@ PubsubGrpc.PubsubStub getNextUnaryStub() { } /** Round-robin subscribe stub. */ - PubsubGrpc.PubsubStub getNextStreamStub() { - return streamStubs.get(streamIndex.getAndIncrement() % this.numStreamGrpcChannels); + StreamStubWithCount getNextStreamStub() { + // Try to get a client with capacity for another subscription + // by round-robining through the stubs. + // Allow up to maximumActiveSubscriptions attempts to account for large bursts of requests. + final int maximumActiveSubscriptions = this.numStreamGrpcChannels * 100; + for (int i = 0; i < maximumActiveSubscriptions; i++) { + final StreamStubWithCount stubWithCount = + streamStubs.get(streamIndex.getAndIncrement() % this.numStreamGrpcChannels); + try { + stubWithCount.acquireStubOrThrow(); + return stubWithCount; + } catch (ClientSdkException e) { + // If the stub is at capacity, continue to the next one. + continue; + } + } + + // Otherwise return an error if no stubs have capacity. + throw new ClientSdkException( + MomentoErrorCode.CLIENT_RESOURCE_EXHAUSTED, + "Maximum number of active subscriptions reached"); } TopicConfiguration getConfiguration() { diff --git a/momento-sdk/src/main/java/momento/sdk/SubscriptionWrapper.java b/momento-sdk/src/main/java/momento/sdk/SubscriptionWrapper.java index 1906c828..d0ab0838 100644 --- a/momento-sdk/src/main/java/momento/sdk/SubscriptionWrapper.java +++ b/momento-sdk/src/main/java/momento/sdk/SubscriptionWrapper.java @@ -92,7 +92,7 @@ CompletableFuture subscribeWithRetry() { logger.warn( "First message timeout exceeded for topic {} on cache {}", topicName, cacheName); - if (subscription != null) { + if (subscription.get() != null) { subscription.get().cancel("Timed out waiting for first message", null); } @@ -241,6 +241,7 @@ public void onCompleted() { } private void completeExceptionally(CompletableFuture future, Throwable t) { + subscriptionState.decrementActiveSubscriptionsCount(); future.completeExceptionally( new TopicSubscribeResponse.Error(CacheServiceExceptionMapper.convert(t))); close(); @@ -251,6 +252,7 @@ private void scheduleRetry(Duration retryDelay, Runnable retryAction) { } private void handleSubscriptionCompleted() { + subscriptionState.decrementActiveSubscriptionsCount(); callbacks.onCompleted(); } @@ -346,6 +348,7 @@ public void unsubscribe() { @Override public void close() { + subscriptionState.decrementActiveSubscriptionsCount(); scheduler.shutdown(); } } diff --git a/momento-sdk/src/main/java/momento/sdk/internal/SubscriptionState.java b/momento-sdk/src/main/java/momento/sdk/internal/SubscriptionState.java index 5ba3702e..1e1a05bf 100644 --- a/momento-sdk/src/main/java/momento/sdk/internal/SubscriptionState.java +++ b/momento-sdk/src/main/java/momento/sdk/internal/SubscriptionState.java @@ -3,6 +3,7 @@ /** Represents the state of a subscription to a topic. */ public class SubscriptionState { + private Runnable decrementActiveSubscriptionsCountFn; private Runnable unsubscribeFn; private Long lastTopicSequenceNumber; private Long lastTopicSequencePage; @@ -10,6 +11,7 @@ public class SubscriptionState { /** Constructs a new SubscriptionState instance with default values. */ public SubscriptionState() { + this.decrementActiveSubscriptionsCountFn = () -> {}; this.unsubscribeFn = () -> {}; this.isSubscribed = false; } @@ -59,8 +61,24 @@ public void setUnsubscribeFn(Runnable unsubscribeFn) { /** Unsubscribes from the topic, executing the unsubscribe function. */ public void unsubscribe() { if (isSubscribed) { + decrementActiveSubscriptionsCountFn.run(); unsubscribeFn.run(); this.isSubscribed = false; } } + + /** + * Sets the function to be decrement the active subscriptions count for a given stub. + * + * @param decrementActiveSubscriptionsCountFn The function to decrement the active subscriptions + * count. + */ + public void setDecrementActiveSubscriptionsCountFn(Runnable decrementActiveSubscriptionsCountFn) { + this.decrementActiveSubscriptionsCountFn = decrementActiveSubscriptionsCountFn; + } + + /** Decrements the active subscriptions count for a given stub. */ + public void decrementActiveSubscriptionsCount() { + decrementActiveSubscriptionsCountFn.run(); + } } diff --git a/momento-sdk/src/test/java/momento/sdk/SubscriptionWrapperTest.java b/momento-sdk/src/test/java/momento/sdk/SubscriptionWrapperTest.java index ceb0e44b..169dca09 100644 --- a/momento-sdk/src/test/java/momento/sdk/SubscriptionWrapperTest.java +++ b/momento-sdk/src/test/java/momento/sdk/SubscriptionWrapperTest.java @@ -121,5 +121,7 @@ public void subscribe( waitingForSubscriptionAttempt.acquire(); assertTrue(gotConnectionRestoredCallback.get()); + + subscriptionWrapper.close(); } }