Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/main/java/me/golemcore/bot/auto/AutoModeScheduler.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,12 @@ public AutoModeScheduler(AutoModeService autoModeService, ScheduleService schedu

@PostConstruct
public void init() {
if (!runtimeConfigService.isAutoModeEnabled()) {
log.info("[AutoScheduler] Auto mode disabled");
return;
}

goalManagementTool.setMilestoneCallback(event -> sendMilestoneNotification(event.message()));

autoModeService.loadState();

if (runtimeConfigService.isAutoStartEnabled() && !autoModeService.isAutoModeEnabled()) {
boolean featureEnabled = runtimeConfigService.isAutoModeEnabled();
if (featureEnabled && runtimeConfigService.isAutoStartEnabled() && !autoModeService.isAutoModeEnabled()) {
autoModeService.enableAutoMode();
log.info("[AutoScheduler] Auto-started auto mode");
}
Expand All @@ -138,6 +134,9 @@ public void init() {
TimeUnit.SECONDS);

log.info("[AutoScheduler] Started with tick interval: {}s", tickIntervalSeconds);
if (!featureEnabled) {
log.info("[AutoScheduler] Auto mode feature disabled in runtime config; scheduler is idle");
}
}

@PreDestroy
Expand Down Expand Up @@ -210,6 +209,10 @@ public void sendMilestoneNotification(String text) {

void tick() {
try {
if (!runtimeConfigService.isAutoModeEnabled()) {
return;
}

if (!autoModeService.isAutoModeEnabled()) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public class TokenBucketRateLimiter implements RateLimitPort {

private final RuntimeConfigService runtimeConfigService;

private final Map<String, TokenBucket> buckets = new ConcurrentHashMap<>();
private final Map<String, ConfiguredBucket> buckets = new ConcurrentHashMap<>();

@Override
public RateLimitResult tryConsume() {
Expand All @@ -69,10 +69,8 @@ public RateLimitResult tryConsume() {
}

String key = "user:global";
TokenBucket bucket = buckets.computeIfAbsent(key, k -> {
int requestsPerMinute = runtimeConfigService.getUserRequestsPerMinute();
return new TokenBucket(requestsPerMinute, Duration.ofMinutes(1));
});
int requestsPerMinute = runtimeConfigService.getUserRequestsPerMinute();
TokenBucket bucket = resolveBucket(key, requestsPerMinute, Duration.ofMinutes(1));

RateLimitResult result = bucket.tryConsume();
if (!result.isAllowed()) {
Expand All @@ -88,10 +86,8 @@ public RateLimitResult tryConsumeChannel(String channelType) {
}

String key = "channel:" + channelType;
TokenBucket bucket = buckets.computeIfAbsent(key, k -> {
int messagesPerSecond = runtimeConfigService.getChannelMessagesPerSecond();
return new TokenBucket(messagesPerSecond, Duration.ofSeconds(1));
});
int messagesPerSecond = runtimeConfigService.getChannelMessagesPerSecond();
TokenBucket bucket = resolveBucket(key, messagesPerSecond, Duration.ofSeconds(1));

return bucket.tryConsume();
}
Expand All @@ -103,20 +99,32 @@ public RateLimitResult tryConsumeLlm(String providerId) {
}

String key = "llm:" + providerId;
TokenBucket bucket = buckets.computeIfAbsent(key, k -> {
int requestsPerMinute = runtimeConfigService.getLlmRequestsPerMinute();
return new TokenBucket(requestsPerMinute, Duration.ofMinutes(1));
});
int requestsPerMinute = runtimeConfigService.getLlmRequestsPerMinute();
TokenBucket bucket = resolveBucket(key, requestsPerMinute, Duration.ofMinutes(1));

return bucket.tryConsume();
}

@Override
public BucketState getBucketState(String key) {
TokenBucket bucket = buckets.get(key);
if (bucket == null) {
ConfiguredBucket configuredBucket = buckets.get(key);
if (configuredBucket == null) {
return null;
}
return bucket.getState(key);
return configuredBucket.bucket().getState(key);
}

private TokenBucket resolveBucket(String key, int capacity, Duration refillPeriod) {
ConfiguredBucket configured = buckets.compute(key, (bucketKey, existing) -> {
if (existing == null || existing.capacity() != capacity
|| !existing.refillPeriod().equals(refillPeriod)) {
return new ConfiguredBucket(new TokenBucket(capacity, refillPeriod), capacity, refillPeriod);
}
return existing;
});
return configured.bucket();
}

private record ConfiguredBucket(TokenBucket bucket, int capacity, Duration refillPeriod) {
}
}
56 changes: 56 additions & 0 deletions src/test/java/me/golemcore/bot/auto/AutoModeSchedulerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -546,4 +546,60 @@ void shouldNotAutoStartWhenAlreadyEnabled() {

newScheduler.shutdown();
}

@Test
void shouldInitializeSchedulerEvenWhenFeatureDisabledAtStartup() {
when(runtimeConfigService.isAutoModeEnabled()).thenReturn(false);

AutoModeScheduler newScheduler = new AutoModeScheduler(
autoModeService, scheduleService, agentLoop, runtimeConfigService,
goalManagementTool, List.of(channelPort));

newScheduler.init();

verify(autoModeService).loadState();
verify(autoModeService, never()).enableAutoMode();

newScheduler.shutdown();
}

@Test
void tickShouldSkipWhenRuntimeFeatureDisabled() {
when(runtimeConfigService.isAutoModeEnabled()).thenReturn(false);
when(autoModeService.isAutoModeEnabled()).thenReturn(true);

scheduler.tick();

verify(scheduleService, never()).getDueSchedules();
verify(agentLoop, never()).processMessage(any(Message.class));
}

@Test
void shouldApplyRuntimeFeatureToggleImmediatelyBetweenTicks() {
when(runtimeConfigService.isAutoModeEnabled()).thenReturn(false, true);
when(autoModeService.isAutoModeEnabled()).thenReturn(true);

Goal goal = Goal.builder()
.id(GOAL_ID)
.title(GOAL_TITLE)
.status(Goal.GoalStatus.ACTIVE)
.tasks(new ArrayList<>())
.createdAt(Instant.now())
.build();
when(autoModeService.getGoal(GOAL_ID)).thenReturn(Optional.of(goal));

ScheduleEntry schedule = ScheduleEntry.builder()
.id("sched-goal-toggle")
.type(ScheduleEntry.ScheduleType.GOAL)
.targetId(GOAL_ID)
.cronExpression(TEST_CRON)
.enabled(true)
.build();
when(scheduleService.getDueSchedules()).thenReturn(List.of(schedule));

scheduler.tick();
scheduler.tick();

verify(agentLoop, org.mockito.Mockito.times(1)).processMessage(any(Message.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -18,14 +21,23 @@ class TokenBucketRateLimiterTest {

private RuntimeConfigService runtimeConfigService;
private TokenBucketRateLimiter rateLimiter;
private AtomicBoolean rateLimitEnabled;
private AtomicInteger userRequestsPerMinute;
private AtomicInteger channelMessagesPerSecond;
private AtomicInteger llmRequestsPerMinute;

@BeforeEach
void setUp() {
runtimeConfigService = mock(RuntimeConfigService.class);
when(runtimeConfigService.isRateLimitEnabled()).thenReturn(true);
when(runtimeConfigService.getUserRequestsPerMinute()).thenReturn(5);
when(runtimeConfigService.getChannelMessagesPerSecond()).thenReturn(10);
when(runtimeConfigService.getLlmRequestsPerMinute()).thenReturn(20);
rateLimitEnabled = new AtomicBoolean(true);
userRequestsPerMinute = new AtomicInteger(5);
channelMessagesPerSecond = new AtomicInteger(10);
llmRequestsPerMinute = new AtomicInteger(20);
when(runtimeConfigService.isRateLimitEnabled()).thenAnswer(invocation -> rateLimitEnabled.get());
when(runtimeConfigService.getUserRequestsPerMinute()).thenAnswer(invocation -> userRequestsPerMinute.get());
when(runtimeConfigService.getChannelMessagesPerSecond())
.thenAnswer(invocation -> channelMessagesPerSecond.get());
when(runtimeConfigService.getLlmRequestsPerMinute()).thenAnswer(invocation -> llmRequestsPerMinute.get());

rateLimiter = new TokenBucketRateLimiter(runtimeConfigService);
}
Expand Down Expand Up @@ -200,4 +212,34 @@ void shouldDecrementRemainingTokens() {

assertTrue(first.getRemainingTokens() > second.getRemainingTokens());
}

@Test
void shouldApplyUpdatedGlobalLimitImmediatelyForExistingBucket() {
assertTrue(rateLimiter.tryConsume().isAllowed());

userRequestsPerMinute.set(1);

assertTrue(rateLimiter.tryConsume().isAllowed());
assertFalse(rateLimiter.tryConsume().isAllowed());
}

@Test
void shouldApplyUpdatedChannelLimitImmediatelyForExistingBucket() {
assertTrue(rateLimiter.tryConsumeChannel(CHANNEL_TELEGRAM).isAllowed());

channelMessagesPerSecond.set(1);

assertTrue(rateLimiter.tryConsumeChannel(CHANNEL_TELEGRAM).isAllowed());
assertFalse(rateLimiter.tryConsumeChannel(CHANNEL_TELEGRAM).isAllowed());
}

@Test
void shouldApplyUpdatedLlmLimitImmediatelyForExistingBucket() {
assertTrue(rateLimiter.tryConsumeLlm(PROVIDER_OPENAI).isAllowed());

llmRequestsPerMinute.set(1);

assertTrue(rateLimiter.tryConsumeLlm(PROVIDER_OPENAI).isAllowed());
assertFalse(rateLimiter.tryConsumeLlm(PROVIDER_OPENAI).isAllowed());
}
}
Loading