Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions adapters/pgclient/aurora.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"strconv"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
)

Expand Down Expand Up @@ -38,16 +37,17 @@ func (ac *auroraConnector) DSN(ctx context.Context) (string, error) {
}

type CredBuilder struct {
creds aws.CredentialsProvider
//creds aws.CredentialsProvider
configs map[string]*AuroraConfig
region string
//region string
provider AWSProvider
}

func NewCredBuilder(creds aws.CredentialsProvider, region string) *CredBuilder {
// func NewCredBuilder(creds aws.CredentialsProvider, region string) *CredBuilder {
func NewCredBuilder(provider AWSProvider) *CredBuilder {
cb := &CredBuilder{
creds: creds,
region: region,
configs: make(map[string]*AuroraConfig),
provider: provider,
configs: make(map[string]*AuroraConfig),
}

return cb
Expand Down Expand Up @@ -105,8 +105,14 @@ func (cb *CredBuilder) NewToken(ctx context.Context, lookupName string) (string,
}

func (cb *CredBuilder) newToken(ctx context.Context, config *AuroraConfig) (string, error) {
region := cb.provider.Region()
creds, err := cb.provider.Credentials(ctx)
if err != nil {
return "", fmt.Errorf("failed to get aws credentials: %w", err)
}

authenticationToken, err := auth.BuildAuthToken(
ctx, fmt.Sprintf("%s:%d", config.Endpoint, config.Port), cb.region, config.DBUser, cb.creds)
ctx, fmt.Sprintf("%s:%d", config.Endpoint, config.Port), region, config.DBUser, creds)
if err != nil {
return "", fmt.Errorf("failed to create authentication token: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions adapters/pgclient/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func looksLikeJSONString(name string) bool {
var reValidDBEnvName = regexp.MustCompile(`^[A-Z0-9_]+$`)

type AWSProvider interface {
Credentials() aws.CredentialsProvider
Credentials(context.Context) (aws.CredentialsProvider, error)
Region() string
}

Expand Down Expand Up @@ -82,7 +82,7 @@ func (ss *pgConnSet) aurora(name string, config *AuroraConfig) (PGConnector, err
}

if ss.credBuilder == nil {
ss.credBuilder = NewCredBuilder(ss.awsProvider.Credentials(), ss.awsProvider.Region())
ss.credBuilder = NewCredBuilder(ss.awsProvider)
}

if err := ss.credBuilder.AddConfig(name, config); err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package awsmsg
package sqsmsg

import (
"bytes"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package awsmsg
package sqsmsg

import (
"testing"
Expand Down
137 changes: 10 additions & 127 deletions apps/queueworker/sqslink/worker.go → adapters/sqsmsg/worker.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package sqslink
package sqsmsg

import (
"context"
"crypto/rand"
"fmt"
"math/big"
"strconv"

"github.com/aws/aws-sdk-go-v2/service/sqs"
Expand All @@ -13,113 +11,32 @@ import (
"github.com/pentops/log.go/log"
"github.com/pentops/o5-messaging/gen/o5/messaging/v1/messaging_pb"
"github.com/pentops/o5-messaging/gen/o5/messaging/v1/messaging_tpb"
"github.com/pentops/o5-runtime-sidecar/apps/queueworker/awsmsg"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/pentops/o5-runtime-sidecar/apps/queueworker/messaging"
)

const RawMessageName = "/o5.messaging.v1.topic.RawMessageTopic/Raw"
const GenericTopic = "/o5.messaging.v1.topic.GenericMessageTopic/Generic"

type SQSAPI interface {
ReceiveMessage(ctx context.Context, input *sqs.ReceiveMessageInput, opts ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error)
DeleteMessage(ctx context.Context, input *sqs.DeleteMessageInput, opts ...func(*sqs.Options)) (*sqs.DeleteMessageOutput, error)
}

type Handler interface {
HandleMessage(context.Context, *messaging_pb.Message) error
}

type HandlerFunc func(context.Context, *messaging_pb.Message) error

func (hf HandlerFunc) HandleMessage(ctx context.Context, msg *messaging_pb.Message) error {
return hf(ctx, msg)
}

type Worker struct {
router messaging.Handler
SQSClient SQSAPI
QueueURL string
deadLetterHandler DeadLetterHandler
resendChance int

handlers map[string]Handler
fallbackHandler Handler
}

// Is this message is randomly selected based on percent received?
func randomlySelected(ctx context.Context, pct int) bool {
if pct == 0 {
return false
}

if pct == 100 {
return true
}

if pct > 100 || pct < 0 {
log.Infof(ctx, "Received invalid percent for randomly selecting a message: %v", pct)
return false
}

r, err := rand.Int(rand.Reader, big.NewInt(100))
if err != nil {
log.WithError(ctx, err).Error("couldn't generate random number for selecting message")
return false
}

if r.Int64() <= big.NewInt(int64(pct)).Int64() {
log.Infof(ctx, "Message randomly selected: rand of %v and percent of %v", r.Int64(), pct)
return true
}
return false
deadLetterHandler messaging.DeadLetterHandler
}

func NewWorker(sqs SQSAPI, queueURL string, deadLetters DeadLetterHandler, resendChance int) *Worker {
func NewWorker(sqs SQSAPI, queueURL string, deadLetters messaging.DeadLetterHandler, handler messaging.Handler) *Worker {
return &Worker{
SQSClient: sqs,
QueueURL: queueURL,
handlers: make(map[string]Handler),
router: handler,
deadLetterHandler: deadLetters,
resendChance: resendChance,
}
}

func (ww *Worker) RegisterService(ctx context.Context, service protoreflect.ServiceDescriptor, invoker AppLink) error {
methods := service.Methods()
for ii := 0; ii < methods.Len(); ii++ {
method := methods.Get(ii)
if err := ww.registerMethod(ctx, method, invoker); err != nil {
return err
}
}
return nil
}

func (ww *Worker) registerMethod(ctx context.Context, method protoreflect.MethodDescriptor, invoker AppLink) error {
serviceName := method.Parent().(protoreflect.ServiceDescriptor).FullName()
fullName := fmt.Sprintf("/%s/%s", serviceName, method.Name())

if fullName == GenericTopic {
log.WithField(ctx, "service", fullName).Info("Registering Generic Fallback")
ww.fallbackHandler = &genericHandler{
invoker: invoker,
}

} else {
log.WithField(ctx, "service", fullName).Info("Registering Worker Service")
ss := &service{
requestMessage: method.Input(),
fullName: fullName,
invoker: invoker,
}
ww.handlers[ss.fullName] = ss
}
return nil
}

func (ww *Worker) RegisterHandler(fullMethod string, handler Handler) {
ww.handlers[fullMethod] = handler
}

func (ww *Worker) Run(ctx context.Context) error {
for {
if err := ww.FetchOnce(ctx); err != nil {
Expand All @@ -143,7 +60,7 @@ func (ww *Worker) FetchOnce(ctx context.Context) error {
// retrieve requests after being retrieved by a ReceiveMessage request.
VisibilityTimeout: 30,

MessageAttributeNames: awsmsg.SQSMessageAttributes,
MessageAttributeNames: SQSMessageAttributes,

AttributeNames: []types.QueueAttributeName{
// this type conversion is probably a bug in the SDK
Expand All @@ -156,9 +73,6 @@ func (ww *Worker) FetchOnce(ctx context.Context) error {

for _, msg := range out.Messages {
ww.handleMessage(ctx, msg)
if randomlySelected(ctx, ww.resendChance) {
ww.handleMessage(ctx, msg)
}
}
return nil
}
Expand All @@ -177,7 +91,7 @@ func getReceiveCount(msg types.Message) int {
}

func (ww *Worker) handleMessage(ctx context.Context, msg types.Message) {
parsed, err := awsmsg.ParseSQSMessage(msg)
parsed, err := ParseSQSMessage(msg)
if err != nil {
// Leave it for retry unless we keep failing at parsing it
log.WithError(ctx, err).Error("Message Worker: Failed to parse message")
Expand All @@ -196,40 +110,9 @@ func (ww *Worker) handleMessage(ctx context.Context, msg types.Message) {
return
}

ctx = log.WithFields(ctx, map[string]any{
"grpc-service": parsed.GrpcService,
"grpc-method": parsed.GrpcMethod,
"message-id": parsed.MessageId,
"topic": parsed.DestinationTopic,
"sqs-message-id": msg.MessageId,
})
log.Debug(ctx, "Message Handler: Begin")

fullServiceName := fmt.Sprintf("/%s/%s", parsed.GrpcService, parsed.GrpcMethod)
handler, ok := ww.handlers[fullServiceName]
if !ok {
if ww.fallbackHandler != nil {
log.Debug(ctx, "Message Handler: Using fallback handler")
handler = ww.fallbackHandler
} else {
log.Error(ctx, "no handler matched")
if ww.deadLetterHandler == nil && getReceiveCount(msg) <= 3 {
log.Error(ctx, "Error handling message, leaving in queue")
return
}
err := ww.killMessage(ctx, msg, parsed, fmt.Errorf("no handler for %s", fullServiceName))
if err != nil {
log.WithField(ctx, "killError", err.Error()).
Error("Message Worker: Error killing message, leaving in queue")
return
}
log.Debug(ctx, "Message Handler: Killed")
return
}
}

err = handler.HandleMessage(ctx, parsed)
ctx = log.WithField(ctx, "sqs-message-id", msg.MessageId)

err = ww.router.HandleMessage(ctx, parsed)
if err != nil {
ctx = log.WithError(ctx, err)
log.Error(ctx, "Message Handler: Error")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package sqslink
package messaging

import (
"context"
Expand Down
62 changes: 62 additions & 0 deletions apps/queueworker/messaging/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package messaging

import (
"context"
"crypto/rand"
"math/big"

"github.com/pentops/log.go/log"
"github.com/pentops/o5-messaging/gen/o5/messaging/v1/messaging_pb"
)

type ResendHandler struct {
resendChance int
handler Handler
}

func NewResendHandler(handler Handler, resendChance int) *ResendHandler {
return &ResendHandler{
handler: handler,
resendChance: resendChance,
}
}

func (ww *ResendHandler) HandleMessage(ctx context.Context, msg *messaging_pb.Message) error {
if err := ww.handler.HandleMessage(ctx, msg); err != nil {
return err
}
if randomlySelected(ctx, ww.resendChance) {
if err := ww.handler.HandleMessage(ctx, msg); err != nil {
return err
}
}
return nil
}

// Is this message is randomly selected based on percent received?
func randomlySelected(ctx context.Context, pct int) bool {
if pct == 0 {
return false
}

if pct == 100 {
return true
}

if pct > 100 || pct < 0 {
log.Infof(ctx, "Received invalid percent for randomly selecting a message: %v", pct)
return false
}

r, err := rand.Int(rand.Reader, big.NewInt(100))
if err != nil {
log.WithError(ctx, err).Error("couldn't generate random number for selecting message")
return false
}

if r.Int64() <= big.NewInt(int64(pct)).Int64() {
log.Infof(ctx, "Message randomly selected: rand of %v and percent of %v", r.Int64(), pct)
return true
}
return false
}
Loading