diff --git a/db/migrations/000003_add_subscriptions_topic_idx.down.sql b/db/migrations/000003_add_subscriptions_topic_idx.down.sql new file mode 100644 index 0000000..2b8a318 --- /dev/null +++ b/db/migrations/000003_add_subscriptions_topic_idx.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS subscriptions_topic_id_idx; diff --git a/db/migrations/000003_add_subscriptions_topic_idx.up.sql b/db/migrations/000003_add_subscriptions_topic_idx.up.sql new file mode 100644 index 0000000..81e41b2 --- /dev/null +++ b/db/migrations/000003_add_subscriptions_topic_idx.up.sql @@ -0,0 +1 @@ +CREATE INDEX IF NOT EXISTS subscriptions_topic_id_idx ON subscriptions (topic_id); diff --git a/domain/queue.go b/domain/queue.go index 2e4eeeb..3417e2b 100644 --- a/domain/queue.go +++ b/domain/queue.go @@ -41,6 +41,7 @@ type QueueRepository interface { Create(ctx context.Context, queue *Queue) error Update(ctx context.Context, queue *Queue) error Get(ctx context.Context, id string) (*Queue, error) + GetMany(ctx context.Context, ids []string) (map[string]*Queue, error) List(ctx context.Context, offset, limit uint) ([]*Queue, error) Delete(ctx context.Context, id string) error Stats(ctx context.Context, id string) (*QueueStats, error) diff --git a/mocks/QueueRepository.go b/mocks/QueueRepository.go index a889c7a..e7c7ae1 100644 --- a/mocks/QueueRepository.go +++ b/mocks/QueueRepository.go @@ -98,6 +98,36 @@ func (_m *QueueRepository) Get(ctx context.Context, id string) (*domain.Queue, e return r0, r1 } +// GetMany provides a mock function with given fields: ctx, ids +func (_m *QueueRepository) GetMany(ctx context.Context, ids []string) (map[string]*domain.Queue, error) { + ret := _m.Called(ctx, ids) + + if len(ret) == 0 { + panic("no return value specified for GetMany") + } + + var r0 map[string]*domain.Queue + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []string) (map[string]*domain.Queue, error)); ok { + return rf(ctx, ids) + } + if rf, ok := ret.Get(0).(func(context.Context, []string) map[string]*domain.Queue); ok { + r0 = rf(ctx, ids) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]*domain.Queue) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { + r1 = rf(ctx, ids) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // List provides a mock function with given fields: ctx, offset, limit func (_m *QueueRepository) List(ctx context.Context, offset uint, limit uint) ([]*domain.Queue, error) { ret := _m.Called(ctx, offset, limit) diff --git a/repository/message.go b/repository/message.go index 071925c..7695d31 100644 --- a/repository/message.go +++ b/repository/message.go @@ -68,14 +68,26 @@ func (m *Message) List(ctx context.Context, queue *domain.Queue, label *string, return nil, err } + if len(messages) == 0 { + return messages, tx.Commit(ctx) + } + + // Update messages in memory and collect IDs for batch database update + messageIDs := make([]string, len(messages)) for i := range messages { message := messages[i] - + // Update message object in memory so it reflects the correct state when returned message.DeliverySetup(queue, now) - if err := pgxutil.Update(ctx, tx, "", m.tableName, message.ID, &message); err != nil { - executeRollback(ctx, tx) - return nil, err - } + messageIDs[i] = message.ID + } + + // Batch update all messages in the database with a single query + // This updates the same fields that DeliverySetup modifies in memory + newScheduledAt := now.Add(time.Duration(queue.AckDeadlineSeconds) * time.Second) + sqlQuery := `UPDATE messages SET delivery_attempts = delivery_attempts + 1, scheduled_at = $1, updated_at = $2 WHERE id = ANY($3)` + if _, err := tx.Exec(ctx, sqlQuery, newScheduledAt, now, messageIDs); err != nil { + executeRollback(ctx, tx) + return nil, err } return messages, tx.Commit(ctx) diff --git a/repository/queue.go b/repository/queue.go index 9a94d1e..4c376c2 100644 --- a/repository/queue.go +++ b/repository/queue.go @@ -33,6 +33,27 @@ func (q *Queue) Get(ctx context.Context, id string) (*domain.Queue, error) { return &queue, parseError(err, domain.ErrQueueNotFound, domain.ErrQueueAlreadyExists) } +func (q *Queue) GetMany(ctx context.Context, ids []string) (map[string]*domain.Queue, error) { + if len(ids) == 0 { + return make(map[string]*domain.Queue), nil + } + + queues := []*domain.Queue{} + options := pgxutil.NewFindAllOptions().WithFilter("id.in", ids) + err := pgxutil.Select(ctx, q.pool, q.tableName, options, &queues) + if err != nil { + return nil, parseError(err, domain.ErrQueueNotFound, domain.ErrQueueAlreadyExists) + } + + // Convert slice to map for easy lookup + queueMap := make(map[string]*domain.Queue, len(queues)) + for _, queue := range queues { + queueMap[queue.ID] = queue + } + + return queueMap, nil +} + func (q *Queue) List(ctx context.Context, offset, limit uint) ([]*domain.Queue, error) { queues := []*domain.Queue{} options := pgxutil.NewFindAllOptions().WithOffset(int(offset)).WithLimit(int(limit)).WithOrderBy("id asc") diff --git a/repository/queue_test.go b/repository/queue_test.go index 4e7b21c..69c0e7c 100644 --- a/repository/queue_test.go +++ b/repository/queue_test.go @@ -89,6 +89,31 @@ func TestQueue(t *testing.T) { assert.ErrorIs(t, err, domain.ErrQueueNotFound) }) + t.Run("GetMany", func(t *testing.T) { + defer clearDatabase(t, ctx, pool) + + queueRepo := NewQueue(pool) + + err := queueRepo.Create(ctx, makeQueue("my-queue-1")) + assert.Nil(t, err) + err = queueRepo.Create(ctx, makeQueue("my-queue-2")) + assert.Nil(t, err) + err = queueRepo.Create(ctx, makeQueue("my-queue-3")) + assert.Nil(t, err) + + queues, err := queueRepo.GetMany(ctx, []string{"my-queue-1", "my-queue-3"}) + assert.Nil(t, err) + assert.Len(t, queues, 2) + assert.NotNil(t, queues["my-queue-1"]) + assert.NotNil(t, queues["my-queue-3"]) + assert.Nil(t, queues["my-queue-2"]) + + // Test with empty slice + queues, err = queueRepo.GetMany(ctx, []string{}) + assert.Nil(t, err) + assert.Len(t, queues, 0) + }) + t.Run("List", func(t *testing.T) { defer clearDatabase(t, ctx, pool) diff --git a/service/queue.go b/service/queue.go index 245c252..ed7c0f0 100644 --- a/service/queue.go +++ b/service/queue.go @@ -49,41 +49,19 @@ func (q *Queue) List(ctx context.Context, offset, limit uint) ([]*domain.Queue, } func (q *Queue) Delete(ctx context.Context, id string) error { - queue, err := q.queueRepository.Get(ctx, id) - if err != nil { - return err - } - - return q.queueRepository.Delete(ctx, queue.ID) - + return q.queueRepository.Delete(ctx, id) } func (q *Queue) Stats(ctx context.Context, id string) (*domain.QueueStats, error) { - queue, err := q.queueRepository.Get(ctx, id) - if err != nil { - return nil, err - } - - return q.queueRepository.Stats(ctx, queue.ID) - + return q.queueRepository.Stats(ctx, id) } func (q *Queue) Purge(ctx context.Context, id string) error { - queue, err := q.queueRepository.Get(ctx, id) - if err != nil { - return err - } - - return q.queueRepository.Purge(ctx, queue.ID) + return q.queueRepository.Purge(ctx, id) } func (q *Queue) Cleanup(ctx context.Context, id string) error { - queue, err := q.queueRepository.Get(ctx, id) - if err != nil { - return err - } - - return q.queueRepository.Cleanup(ctx, queue.ID) + return q.queueRepository.Cleanup(ctx, id) } // NewQueue returns an implementation of domain.QueueService. diff --git a/service/queue_test.go b/service/queue_test.go index 077fd34..6fd54f2 100644 --- a/service/queue_test.go +++ b/service/queue_test.go @@ -91,7 +91,6 @@ func TestQueue(t *testing.T) { queueService := NewQueue(queueRepository) queue := makeQueue("my-queue") - queueRepository.On("Get", ctx, queue.ID).Return(queue, nil) queueRepository.On("Delete", ctx, queue.ID).Return(nil) err := queueService.Delete(ctx, queue.ID) @@ -103,7 +102,6 @@ func TestQueue(t *testing.T) { queueService := NewQueue(queueRepository) queue := makeQueue("my-queue") - queueRepository.On("Get", ctx, queue.ID).Return(queue, nil) queueRepository.On("Stats", ctx, queue.ID).Return(&domain.QueueStats{}, nil) _, err := queueService.Stats(ctx, queue.ID) @@ -115,7 +113,6 @@ func TestQueue(t *testing.T) { queueService := NewQueue(queueRepository) queue := makeQueue("my-queue") - queueRepository.On("Get", ctx, queue.ID).Return(queue, nil) queueRepository.On("Purge", ctx, queue.ID).Return(nil) err := queueService.Purge(ctx, queue.ID) diff --git a/service/subscription.go b/service/subscription.go index 3bbbc64..cd7fd0c 100644 --- a/service/subscription.go +++ b/service/subscription.go @@ -31,12 +31,7 @@ func (s *Subscription) List(ctx context.Context, offset, limit uint) ([]*domain. } func (s *Subscription) Delete(ctx context.Context, id string) error { - subscription, err := s.subscriptionRepository.Get(ctx, id) - if err != nil { - return err - } - - return s.subscriptionRepository.Delete(ctx, subscription.ID) + return s.subscriptionRepository.Delete(ctx, id) } // NewSubscription returns an implementation of domain.SubscriptionService. diff --git a/service/subscription_test.go b/service/subscription_test.go index 7b89696..cf0897d 100644 --- a/service/subscription_test.go +++ b/service/subscription_test.go @@ -77,7 +77,6 @@ func TestSubscription(t *testing.T) { subscriptionService := NewSubscription(subscriptionRepository) subscription := makeSubscription("my-subscription", "my-topic", "my-queue") - subscriptionRepository.On("Get", ctx, subscription.ID).Return(subscription, nil) subscriptionRepository.On("Delete", ctx, subscription.ID).Return(nil) err := subscriptionService.Delete(ctx, subscription.ID) diff --git a/service/topic.go b/service/topic.go index 4630182..6ece344 100644 --- a/service/topic.go +++ b/service/topic.go @@ -34,12 +34,7 @@ func (t *Topic) List(ctx context.Context, offset, limit uint) ([]*domain.Topic, } func (t *Topic) Delete(ctx context.Context, id string) error { - topic, err := t.topicRepository.Get(ctx, id) - if err != nil { - return err - } - - return t.topicRepository.Delete(ctx, topic.ID) + return t.topicRepository.Delete(ctx, id) } func (t *Topic) CreateMessage(ctx context.Context, topicID string, message *domain.Message) error { @@ -67,15 +62,35 @@ func (t *Topic) CreateMessage(ctx context.Context, topicID string, message *doma break } + // Collect unique queue IDs to fetch in batch + queueIDs := make([]string, 0, len(subscriptions)) + seenQueues := make(map[string]bool) + for i := range subscriptions { + subscription := subscriptions[i] + if !subscription.ShouldCreateMessage(message) { + continue + } + if !seenQueues[subscription.QueueID] { + queueIDs = append(queueIDs, subscription.QueueID) + seenQueues[subscription.QueueID] = true + } + } + + // Fetch all queues in a single query + queues, err := t.queueRepository.GetMany(ctx, queueIDs) + if err != nil { + return err + } + for i := range subscriptions { subscription := subscriptions[i] if !subscription.ShouldCreateMessage(message) { continue } - queue, err := t.queueRepository.Get(ctx, subscription.QueueID) - if err != nil { - return err + queue, ok := queues[subscription.QueueID] + if !ok { + return domain.ErrQueueNotFound } newMessage := &domain.Message{ diff --git a/service/topic_test.go b/service/topic_test.go index 10d6b02..e51c9f3 100644 --- a/service/topic_test.go +++ b/service/topic_test.go @@ -91,7 +91,6 @@ func TestTopic(t *testing.T) { topicService := NewTopic(topicRepository, subscriptionRepository, queueRepository, messageRepository) topic := makeTopic("my-topic") - topicRepository.On("Get", ctx, topic.ID).Return(topic, nil) topicRepository.On("Delete", ctx, topic.ID).Return(nil) err := topicService.Delete(ctx, topic.ID) @@ -112,7 +111,7 @@ func TestTopic(t *testing.T) { topicRepository.On("Get", ctx, topic.ID).Return(topic, nil) subscriptionRepository.On("ListByTopic", ctx, topic.ID, uint(0), uint(50)).Return([]*domain.Subscription{subscription}, nil) subscriptionRepository.On("ListByTopic", ctx, topic.ID, uint(50), uint(50)).Return([]*domain.Subscription{}, nil) - queueRepository.On("Get", ctx, queue.ID).Return(queue, nil) + queueRepository.On("GetMany", ctx, []string{queue.ID}).Return(map[string]*domain.Queue{queue.ID: queue}, nil) messageRepository.On("CreateMany", ctx, mock.Anything).Return(nil) err := topicService.CreateMessage(ctx, topic.ID, message)