diff --git a/src/main/java/me/golemcore/bot/auto/AutoModeScheduler.java b/src/main/java/me/golemcore/bot/auto/AutoModeScheduler.java index b0479ddc..381282de 100644 --- a/src/main/java/me/golemcore/bot/auto/AutoModeScheduler.java +++ b/src/main/java/me/golemcore/bot/auto/AutoModeScheduler.java @@ -110,16 +110,12 @@ public AutoModeScheduler(AutoModeService autoModeService, ScheduleService schedu @PostConstruct public void init() { - if (!runtimeConfigService.isAutoModeEnabled()) { - log.info("[AutoScheduler] Auto mode disabled"); - return; - } - goalManagementTool.setMilestoneCallback(event -> sendMilestoneNotification(event.message())); autoModeService.loadState(); - if (runtimeConfigService.isAutoStartEnabled() && !autoModeService.isAutoModeEnabled()) { + boolean featureEnabled = runtimeConfigService.isAutoModeEnabled(); + if (featureEnabled && runtimeConfigService.isAutoStartEnabled() && !autoModeService.isAutoModeEnabled()) { autoModeService.enableAutoMode(); log.info("[AutoScheduler] Auto-started auto mode"); } @@ -138,6 +134,9 @@ public void init() { TimeUnit.SECONDS); log.info("[AutoScheduler] Started with tick interval: {}s", tickIntervalSeconds); + if (!featureEnabled) { + log.info("[AutoScheduler] Auto mode feature disabled in runtime config; scheduler is idle"); + } } @PreDestroy @@ -210,6 +209,10 @@ public void sendMilestoneNotification(String text) { void tick() { try { + if (!runtimeConfigService.isAutoModeEnabled()) { + return; + } + if (!autoModeService.isAutoModeEnabled()) { return; } diff --git a/src/main/java/me/golemcore/bot/ratelimit/TokenBucketRateLimiter.java b/src/main/java/me/golemcore/bot/ratelimit/TokenBucketRateLimiter.java index 7de42ffb..c2199be0 100644 --- a/src/main/java/me/golemcore/bot/ratelimit/TokenBucketRateLimiter.java +++ b/src/main/java/me/golemcore/bot/ratelimit/TokenBucketRateLimiter.java @@ -60,7 +60,7 @@ public class TokenBucketRateLimiter implements RateLimitPort { private final RuntimeConfigService runtimeConfigService; - private final Map buckets = new ConcurrentHashMap<>(); + private final Map buckets = new ConcurrentHashMap<>(); @Override public RateLimitResult tryConsume() { @@ -69,10 +69,8 @@ public RateLimitResult tryConsume() { } String key = "user:global"; - TokenBucket bucket = buckets.computeIfAbsent(key, k -> { - int requestsPerMinute = runtimeConfigService.getUserRequestsPerMinute(); - return new TokenBucket(requestsPerMinute, Duration.ofMinutes(1)); - }); + int requestsPerMinute = runtimeConfigService.getUserRequestsPerMinute(); + TokenBucket bucket = resolveBucket(key, requestsPerMinute, Duration.ofMinutes(1)); RateLimitResult result = bucket.tryConsume(); if (!result.isAllowed()) { @@ -88,10 +86,8 @@ public RateLimitResult tryConsumeChannel(String channelType) { } String key = "channel:" + channelType; - TokenBucket bucket = buckets.computeIfAbsent(key, k -> { - int messagesPerSecond = runtimeConfigService.getChannelMessagesPerSecond(); - return new TokenBucket(messagesPerSecond, Duration.ofSeconds(1)); - }); + int messagesPerSecond = runtimeConfigService.getChannelMessagesPerSecond(); + TokenBucket bucket = resolveBucket(key, messagesPerSecond, Duration.ofSeconds(1)); return bucket.tryConsume(); } @@ -103,20 +99,32 @@ public RateLimitResult tryConsumeLlm(String providerId) { } String key = "llm:" + providerId; - TokenBucket bucket = buckets.computeIfAbsent(key, k -> { - int requestsPerMinute = runtimeConfigService.getLlmRequestsPerMinute(); - return new TokenBucket(requestsPerMinute, Duration.ofMinutes(1)); - }); + int requestsPerMinute = runtimeConfigService.getLlmRequestsPerMinute(); + TokenBucket bucket = resolveBucket(key, requestsPerMinute, Duration.ofMinutes(1)); return bucket.tryConsume(); } @Override public BucketState getBucketState(String key) { - TokenBucket bucket = buckets.get(key); - if (bucket == null) { + ConfiguredBucket configuredBucket = buckets.get(key); + if (configuredBucket == null) { return null; } - return bucket.getState(key); + return configuredBucket.bucket().getState(key); + } + + private TokenBucket resolveBucket(String key, int capacity, Duration refillPeriod) { + ConfiguredBucket configured = buckets.compute(key, (bucketKey, existing) -> { + if (existing == null || existing.capacity() != capacity + || !existing.refillPeriod().equals(refillPeriod)) { + return new ConfiguredBucket(new TokenBucket(capacity, refillPeriod), capacity, refillPeriod); + } + return existing; + }); + return configured.bucket(); + } + + private record ConfiguredBucket(TokenBucket bucket, int capacity, Duration refillPeriod) { } } diff --git a/src/test/java/me/golemcore/bot/auto/AutoModeSchedulerTest.java b/src/test/java/me/golemcore/bot/auto/AutoModeSchedulerTest.java index 0a6d6e07..21e10205 100644 --- a/src/test/java/me/golemcore/bot/auto/AutoModeSchedulerTest.java +++ b/src/test/java/me/golemcore/bot/auto/AutoModeSchedulerTest.java @@ -546,4 +546,60 @@ void shouldNotAutoStartWhenAlreadyEnabled() { newScheduler.shutdown(); } + + @Test + void shouldInitializeSchedulerEvenWhenFeatureDisabledAtStartup() { + when(runtimeConfigService.isAutoModeEnabled()).thenReturn(false); + + AutoModeScheduler newScheduler = new AutoModeScheduler( + autoModeService, scheduleService, agentLoop, runtimeConfigService, + goalManagementTool, List.of(channelPort)); + + newScheduler.init(); + + verify(autoModeService).loadState(); + verify(autoModeService, never()).enableAutoMode(); + + newScheduler.shutdown(); + } + + @Test + void tickShouldSkipWhenRuntimeFeatureDisabled() { + when(runtimeConfigService.isAutoModeEnabled()).thenReturn(false); + when(autoModeService.isAutoModeEnabled()).thenReturn(true); + + scheduler.tick(); + + verify(scheduleService, never()).getDueSchedules(); + verify(agentLoop, never()).processMessage(any(Message.class)); + } + + @Test + void shouldApplyRuntimeFeatureToggleImmediatelyBetweenTicks() { + when(runtimeConfigService.isAutoModeEnabled()).thenReturn(false, true); + when(autoModeService.isAutoModeEnabled()).thenReturn(true); + + Goal goal = Goal.builder() + .id(GOAL_ID) + .title(GOAL_TITLE) + .status(Goal.GoalStatus.ACTIVE) + .tasks(new ArrayList<>()) + .createdAt(Instant.now()) + .build(); + when(autoModeService.getGoal(GOAL_ID)).thenReturn(Optional.of(goal)); + + ScheduleEntry schedule = ScheduleEntry.builder() + .id("sched-goal-toggle") + .type(ScheduleEntry.ScheduleType.GOAL) + .targetId(GOAL_ID) + .cronExpression(TEST_CRON) + .enabled(true) + .build(); + when(scheduleService.getDueSchedules()).thenReturn(List.of(schedule)); + + scheduler.tick(); + scheduler.tick(); + + verify(agentLoop, org.mockito.Mockito.times(1)).processMessage(any(Message.class)); + } } diff --git a/src/test/java/me/golemcore/bot/ratelimit/TokenBucketRateLimiterTest.java b/src/test/java/me/golemcore/bot/ratelimit/TokenBucketRateLimiterTest.java index aa84d7c4..4c46f6e9 100644 --- a/src/test/java/me/golemcore/bot/ratelimit/TokenBucketRateLimiterTest.java +++ b/src/test/java/me/golemcore/bot/ratelimit/TokenBucketRateLimiterTest.java @@ -6,6 +6,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -18,14 +21,23 @@ class TokenBucketRateLimiterTest { private RuntimeConfigService runtimeConfigService; private TokenBucketRateLimiter rateLimiter; + private AtomicBoolean rateLimitEnabled; + private AtomicInteger userRequestsPerMinute; + private AtomicInteger channelMessagesPerSecond; + private AtomicInteger llmRequestsPerMinute; @BeforeEach void setUp() { runtimeConfigService = mock(RuntimeConfigService.class); - when(runtimeConfigService.isRateLimitEnabled()).thenReturn(true); - when(runtimeConfigService.getUserRequestsPerMinute()).thenReturn(5); - when(runtimeConfigService.getChannelMessagesPerSecond()).thenReturn(10); - when(runtimeConfigService.getLlmRequestsPerMinute()).thenReturn(20); + rateLimitEnabled = new AtomicBoolean(true); + userRequestsPerMinute = new AtomicInteger(5); + channelMessagesPerSecond = new AtomicInteger(10); + llmRequestsPerMinute = new AtomicInteger(20); + when(runtimeConfigService.isRateLimitEnabled()).thenAnswer(invocation -> rateLimitEnabled.get()); + when(runtimeConfigService.getUserRequestsPerMinute()).thenAnswer(invocation -> userRequestsPerMinute.get()); + when(runtimeConfigService.getChannelMessagesPerSecond()) + .thenAnswer(invocation -> channelMessagesPerSecond.get()); + when(runtimeConfigService.getLlmRequestsPerMinute()).thenAnswer(invocation -> llmRequestsPerMinute.get()); rateLimiter = new TokenBucketRateLimiter(runtimeConfigService); } @@ -200,4 +212,34 @@ void shouldDecrementRemainingTokens() { assertTrue(first.getRemainingTokens() > second.getRemainingTokens()); } + + @Test + void shouldApplyUpdatedGlobalLimitImmediatelyForExistingBucket() { + assertTrue(rateLimiter.tryConsume().isAllowed()); + + userRequestsPerMinute.set(1); + + assertTrue(rateLimiter.tryConsume().isAllowed()); + assertFalse(rateLimiter.tryConsume().isAllowed()); + } + + @Test + void shouldApplyUpdatedChannelLimitImmediatelyForExistingBucket() { + assertTrue(rateLimiter.tryConsumeChannel(CHANNEL_TELEGRAM).isAllowed()); + + channelMessagesPerSecond.set(1); + + assertTrue(rateLimiter.tryConsumeChannel(CHANNEL_TELEGRAM).isAllowed()); + assertFalse(rateLimiter.tryConsumeChannel(CHANNEL_TELEGRAM).isAllowed()); + } + + @Test + void shouldApplyUpdatedLlmLimitImmediatelyForExistingBucket() { + assertTrue(rateLimiter.tryConsumeLlm(PROVIDER_OPENAI).isAllowed()); + + llmRequestsPerMinute.set(1); + + assertTrue(rateLimiter.tryConsumeLlm(PROVIDER_OPENAI).isAllowed()); + assertFalse(rateLimiter.tryConsumeLlm(PROVIDER_OPENAI).isAllowed()); + } }