diff --git a/middleware/recover/recover.go b/middleware/recover/recover.go index b6df643..75adf85 100644 --- a/middleware/recover/recover.go +++ b/middleware/recover/recover.go @@ -33,7 +33,7 @@ func (o Middleware) PublisherMsgInterceptor(serviceName string, next pubsub.Publ return func(ctx context.Context, topic string, m *pubsub.Msg) (err error) { defer func() { if r := recover(); r != nil { - err = recoverFrom(r, "pubsub: publish error", o.RecoveryHandlerFunc) + err = recoverFrom(r, "pubsub: publisher panic \n", o.RecoveryHandlerFunc) } }() err = next(ctx, topic, m) @@ -43,8 +43,16 @@ func (o Middleware) PublisherMsgInterceptor(serviceName string, next pubsub.Publ func recoverFrom(p interface{}, wrap string, r RecoveryHandlerFunc) error { if r == nil { - return errors.Wrap(p.(error), wrap) + var e error + switch val := p.(type) { + case string: + e = errors.New(val) + case error: + e = val + default: + e = errors.New("unknown error occurred") + } + return errors.Wrap(e, wrap) } - return r(p) } diff --git a/middleware/recover/recover_test.go b/middleware/recover/recover_test.go index 592eecf..5e84d92 100644 --- a/middleware/recover/recover_test.go +++ b/middleware/recover/recover_test.go @@ -17,17 +17,40 @@ type TestSubscriber struct { T *testing.T } -func (ts *TestSubscriber) DoSomething(ctx context.Context, t *test.Account, msg *pubsub.Msg) error { +func (ts *TestSubscriber) PanicWithError(ctx context.Context, t *test.Account, msg *pubsub.Msg) error { assert.True(ts.T, len(msg.Data) > 0) - panic(errors.New("ahhhhhhhh")) - return nil + panic(errors.New("this is an error")) +} + +func (ts *TestSubscriber) PanicWithString(ctx context.Context, t *test.Account, msg *pubsub.Msg) error { + assert.True(ts.T, len(msg.Data) > 0) + panic("this is a panic") +} + +func (ts *TestSubscriber) PanicUnknown(ctx context.Context, t *test.Account, msg *pubsub.Msg) error { + assert.True(ts.T, len(msg.Data) > 0) + panic(struct{}{}) } func (ts *TestSubscriber) Setup(c *pubsub.Client) { c.On(pubsub.HandlerOptions{ - Topic: "test_topic", - Name: "do_something", - Handler: ts.DoSomething, + Topic: "with_error", + Name: "test", + Handler: ts.PanicWithError, + JSON: ts.JSON, + }) + + c.On(pubsub.HandlerOptions{ + Topic: "with_string", + Name: "test", + Handler: ts.PanicWithString, + JSON: ts.JSON, + }) + + c.On(pubsub.HandlerOptions{ + Topic: "with_unknown", + Name: "test", + Handler: ts.PanicUnknown, JSON: ts.JSON, }) } @@ -51,7 +74,13 @@ func TestRecoverMiddleware(t *testing.T) { Name: "smth", } - err := c.Publish(context.Background(), "test_topic", &ps, false) + err := c.Publish(context.Background(), "with_error", &ps, false) + assert.Nil(t, err) + + err = c.Publish(context.Background(), "with_string", &ps, false) + assert.Nil(t, err) + + err = c.Publish(context.Background(), "with_unknown", &ps, false) assert.Nil(t, err) ts := TestSubscriber{T: t} diff --git a/providers/memory/memory.go b/providers/memory/memory.go index f65cb70..f30b867 100644 --- a/providers/memory/memory.go +++ b/providers/memory/memory.go @@ -3,40 +3,79 @@ package memory import ( "context" "fmt" - "github.com/lileio/pubsub/v2" + "sync" ) type MemoryProvider struct { - Msgs map[string][]*pubsub.Msg - ErrorHandler func(err error) + Msgs *sync.Map + Subscribers *sync.Map + Errors chan error } -func (mp *MemoryProvider) Publish(ctx context.Context, topic string, m *pubsub.Msg) error { - if mp.Msgs == nil { - mp.Msgs = make(map[string][]*pubsub.Msg, 0) +func NewMemoryProvider() *MemoryProvider { + mp := &MemoryProvider{ + &sync.Map{}, + &sync.Map{}, + make(chan error, 101), + } + go mp.ProcessErrors() + return mp +} + +func (mp *MemoryProvider) ProcessErrors() { + for err := range mp.Errors { + fmt.Println(err) } +} - mp.Msgs[topic] = append(mp.Msgs[topic], m) +func (mp *MemoryProvider) SetupTopic(topic string) { + if _, ok := mp.Msgs.Load(topic); !ok { + mp.Msgs.Store(topic, make(chan *pubsub.Msg, 100)) + mp.Subscribers.Store(topic, make([]pubsub.MsgHandler, 0, 0)) + go mp.process(topic) + } +} +func (mp *MemoryProvider) Publish(ctx context.Context, topic string, m *pubsub.Msg) error { + if _, ok := mp.Msgs.Load(topic); !ok { + mp.SetupTopic(topic) + } + c, _ := mp.Msgs.Load(topic) + c.(chan *pubsub.Msg) <- m return nil } func (mp *MemoryProvider) Subscribe(opts pubsub.HandlerOptions, h pubsub.MsgHandler) { - for _, v := range mp.Msgs[opts.Topic] { - err := h(context.Background(), *v) - - if err != nil { - if mp.ErrorHandler != nil { - mp.ErrorHandler(err) - } else { - fmt.Print(err.Error()) + topic := opts.Topic + if _, ok := mp.Subscribers.Load(topic); !ok { + mp.SetupTopic(topic) + } + s, _ := mp.Subscribers.Load(topic) + mp.Subscribers.Store(topic, append(s.([]pubsub.MsgHandler), h)) + return +} + +func (mp *MemoryProvider) process(topic string) { + var err error + c, _ := mp.Msgs.Load(topic) + for msg := range c.(chan *pubsub.Msg) { + s, _ := mp.Subscribers.Load(topic) + for _, handler := range s.([]pubsub.MsgHandler) { + err = handler(context.Background(), *msg) + if err != nil { + mp.Errors <- err } } } - return } func (mp *MemoryProvider) Shutdown() { + + mp.Msgs.Range(func(k, v interface{}) bool { + close(v.(chan *pubsub.Msg)) + return true + }) + close(mp.Errors) return }