diff --git a/go.mod b/go.mod index 0deca6ea..676f4cd5 100644 --- a/go.mod +++ b/go.mod @@ -11,4 +11,11 @@ require ( golang.org/x/sync v0.6.0 ) -require golang.org/x/net v0.20.0 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect + github.com/stretchr/testify v1.8.4 // indirect + golang.org/x/net v0.20.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 3372973e..54dee94a 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,27 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/eclipse/paho.mqtt.golang v1.4.3 h1:2kwcUGn8seMUfWndX0hGbvH8r7crgcJguQNCyp70xik= github.com/eclipse/paho.mqtt.golang v1.4.3/go.mod h1:CSYvoAlsMkhYOXh/oKyxa8EcBci6dVkLCbo5tTC1RIE= github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go new file mode 100644 index 00000000..75912788 --- /dev/null +++ b/pkg/http/http_test.go @@ -0,0 +1,131 @@ +package http_test + +import ( + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/absmach/mproxy/pkg/session/mocks" + + mhttp "github.com/absmach/mproxy/pkg/http" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ( + validUrl = "http://example.com" + validAddess = "localhost:8080" + valid = "valid" + invalid = "invalid" +) + +var ( + han *mocks.Handler + log *slog.Logger +) + +func newProxy(address, url string) (mhttp.Proxy, error) { + han = new(mocks.Handler) + log = new(slog.Logger) + return mhttp.NewProxy(address, url, han, log) +} + +func TestNewProxy(t *testing.T) { + cases := []struct { + desc string + address string + url string + err error + }{ + { + desc: "create proxy with valid", + address: validAddess, + url: validUrl, + err: nil, + }, + { + desc: "create proxy with invalid url", + address: validAddess, + url: "0000", + err: nil, + }, + } + for _, c := range cases { + _, err := newProxy(c.address, c.url) + assert.Equal(t, c.err, err, fmt.Sprintf("%s: expected %s got %s\n", c.desc, c.err, err)) + + } +} + +func TestHandler(t *testing.T) { + proxy, err := newProxy(validAddess, validUrl) + assert.Nil(t, err, fmt.Sprintf("expected nil got %s\n", err)) + request := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + + cases := []struct { + desc string + auth func() + authConnectErr error + authPublishErr error + code int + }{ + { + desc: "successful request with username and password and basic auth", + auth: func() { + request.SetBasicAuth("username", "password") + }, + code: http.StatusOK, + }, + { + desc: "successful request with token", + auth: func() { + request.Header.Set("Authorization", valid) + }, + code: http.StatusOK, + }, + { + desc: "request without authorization token", + auth: func() { + request.Header.Set("Authorization", "") + }, + code: http.StatusBadGateway, + }, + } + for _, tc := range cases { + tc.auth() + responseRecorder := httptest.NewRecorder() + sessionCall := han.On("AuthConnect", mock.Anything).Return(tc.authConnectErr) + sessionCall1 := han.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(tc.authPublishErr) + proxy.Handler(responseRecorder, request) + assert.Equal(t, tc.code, responseRecorder.Code, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.code, responseRecorder.Code)) + sessionCall.Unset() + sessionCall1.Unset() + + } +} + +func TestListen(t *testing.T) { + proxy, err := newProxy(validAddess, validUrl) + assert.Nil(t, err, fmt.Sprintf("expected nil got %s\n", err)) + + t.Run("Listen", func(t *testing.T) { + go func() { + err := proxy.Listen() + assert.Nil(t, err, fmt.Sprintf("expected nil got %s\n", err)) + }() + }) +} + +func TestListenTLS(t *testing.T) { + proxy, err := newProxy(validAddess, validUrl) + assert.Nil(t, err, fmt.Sprintf("expected nil got %s\n", err)) + + t.Run("ListenTLS", func(t *testing.T) { + go func() { + err := proxy.ListenTLS("cert", "key") + assert.Nil(t, err, fmt.Sprintf("expected nil got %s\n", err)) + }() + }) +} diff --git a/pkg/mqtt/mqtt_test.go b/pkg/mqtt/mqtt_test.go new file mode 100644 index 00000000..ac25be11 --- /dev/null +++ b/pkg/mqtt/mqtt_test.go @@ -0,0 +1,79 @@ +package mqtt_test + +import ( + "context" + "crypto/tls" + "fmt" + "testing" + + "github.com/absmach/mproxy/pkg/mqtt" + "github.com/absmach/mproxy/pkg/session/mocks" + "github.com/stretchr/testify/assert" +) + +func newProxy(address, target string) *mqtt.Proxy { + handler := new(mocks.Handler) + interceptor := new(mocks.Interceptor) + return mqtt.New(address, target, handler, interceptor, nil) +} + +var tlsConfig = &tls.Config{} + +func TestListen(t *testing.T) { + cases := []struct { + desc string + address string + target string + err error + }{ + { + desc: "listen with valid address", + address: "localhost:8080", + target: "localhost:8080", + err: nil, + }, + // { + // desc: "listen with invalid address", + // address: "0000", + // target: "localhost:8080", + // err: nil, + // }, + } + for _, c := range cases { + proxy := newProxy(c.address, c.target) + go func() { + err := proxy.Listen(context.Background()) + assert.Nil(t, err, fmt.Sprintf("expected nil got %s\n", err)) + }() + } +} + +func TestListenTLS(t *testing.T) { + cases := []struct { + desc string + address string + target string + err error + }{ + { + desc: "listen with valid address", + address: "localhost:8080", + target: "localhost:8080", + err: nil, + }, + // { + // desc: "listen with invalid address", + // address: "0000", + // target: "localhost:8080", + // err: nil, + // }, + } + for _, c := range cases { + + proxy := newProxy(c.address, c.target) + go func() { + err := proxy.ListenTLS(context.Background(), tlsConfig) + assert.Nil(t, err, fmt.Sprintf("expected nil got %s\n", err)) + }() + } +} diff --git a/pkg/mqtt/websocket/websocket_test.go b/pkg/mqtt/websocket/websocket_test.go new file mode 100644 index 00000000..dec0960c --- /dev/null +++ b/pkg/mqtt/websocket/websocket_test.go @@ -0,0 +1,39 @@ +package websocket_test + +import ( + "log/slog" + "net/http/httptest" + "testing" + + "github.com/absmach/mproxy/pkg/mqtt/websocket" + "github.com/absmach/mproxy/pkg/session/mocks" +) + +func newProxy(target, path, scheme string) *websocket.Proxy { + handler := new(mocks.Handler) + interceptor := new(mocks.Interceptor) + logger := new(slog.Logger) + return websocket.New(target, path, scheme, handler, interceptor, logger) +} + +func TestHandler(t *testing.T) { + cases := []struct { + desc string + target string + path string + scheme string + }{ + { + desc: "handler with valid target", + target: "localhost:8080", + path: "/", + scheme: "ws", + }, + } + for _, c := range cases { + proxy := newProxy(c.target, c.path, c.scheme) + responseRecorder := httptest.NewRecorder() + request := httptest.NewRequest("GET", "http://example.com", nil) + proxy.Handler().ServeHTTP(responseRecorder, request) + } +} diff --git a/pkg/session/handler.go b/pkg/session/handler.go index a58bc6af..d5838fb7 100644 --- a/pkg/session/handler.go +++ b/pkg/session/handler.go @@ -3,6 +3,8 @@ package session import "context" // Handler is an interface for mProxy hooks + +//go:generate mockery --name Handler --filename handler.go --quiet --note "Copyright (c) Abstract Machines" type Handler interface { // Authorization on client `CONNECT` // Each of the params are passed by reference, so that it can be changed diff --git a/pkg/session/interceptor.go b/pkg/session/interceptor.go index e69c10c0..62949651 100644 --- a/pkg/session/interceptor.go +++ b/pkg/session/interceptor.go @@ -7,6 +7,8 @@ import ( ) // Interceptor is an interface for mProxy intercept hook. + +//go:generate mockery --name Interceptor --filename interceptor.go --quiet --note "Copyright (c) Abstract Machines" type Interceptor interface { // Intercept is called on every packet flowing through the Proxy. // Packets can be modified before being sent to the broker or the client. diff --git a/pkg/session/mocks/handler.go b/pkg/session/mocks/handler.go new file mode 100644 index 00000000..3dc37103 --- /dev/null +++ b/pkg/session/mocks/handler.go @@ -0,0 +1,174 @@ +// Code generated by mockery v2.38.0. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// Handler is an autogenerated mock type for the Handler type +type Handler struct { + mock.Mock +} + +// AuthConnect provides a mock function with given fields: ctx +func (_m *Handler) AuthConnect(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for AuthConnect") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// AuthPublish provides a mock function with given fields: ctx, topic, payload +func (_m *Handler) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error { + ret := _m.Called(ctx, topic, payload) + + if len(ret) == 0 { + panic("no return value specified for AuthPublish") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *string, *[]byte) error); ok { + r0 = rf(ctx, topic, payload) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// AuthSubscribe provides a mock function with given fields: ctx, topics +func (_m *Handler) AuthSubscribe(ctx context.Context, topics *[]string) error { + ret := _m.Called(ctx, topics) + + if len(ret) == 0 { + panic("no return value specified for AuthSubscribe") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *[]string) error); ok { + r0 = rf(ctx, topics) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Connect provides a mock function with given fields: ctx +func (_m *Handler) Connect(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Connect") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Disconnect provides a mock function with given fields: ctx +func (_m *Handler) Disconnect(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Disconnect") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Publish provides a mock function with given fields: ctx, topic, payload +func (_m *Handler) Publish(ctx context.Context, topic *string, payload *[]byte) error { + ret := _m.Called(ctx, topic, payload) + + if len(ret) == 0 { + panic("no return value specified for Publish") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *string, *[]byte) error); ok { + r0 = rf(ctx, topic, payload) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Subscribe provides a mock function with given fields: ctx, topics +func (_m *Handler) Subscribe(ctx context.Context, topics *[]string) error { + ret := _m.Called(ctx, topics) + + if len(ret) == 0 { + panic("no return value specified for Subscribe") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *[]string) error); ok { + r0 = rf(ctx, topics) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Unsubscribe provides a mock function with given fields: ctx, topics +func (_m *Handler) Unsubscribe(ctx context.Context, topics *[]string) error { + ret := _m.Called(ctx, topics) + + if len(ret) == 0 { + panic("no return value specified for Unsubscribe") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *[]string) error); ok { + r0 = rf(ctx, topics) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewHandler creates a new instance of Handler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewHandler(t interface { + mock.TestingT + Cleanup(func()) +}) *Handler { + mock := &Handler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/session/mocks/interceptor.go b/pkg/session/mocks/interceptor.go new file mode 100644 index 00000000..63d2c41d --- /dev/null +++ b/pkg/session/mocks/interceptor.go @@ -0,0 +1,63 @@ +// Code generated by mockery v2.38.0. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + packets "github.com/eclipse/paho.mqtt.golang/packets" + mock "github.com/stretchr/testify/mock" + + session "github.com/absmach/mproxy/pkg/session" +) + +// Interceptor is an autogenerated mock type for the Interceptor type +type Interceptor struct { + mock.Mock +} + +// Intercept provides a mock function with given fields: ctx, pkt, dir +func (_m *Interceptor) Intercept(ctx context.Context, pkt packets.ControlPacket, dir session.Direction) (packets.ControlPacket, error) { + ret := _m.Called(ctx, pkt, dir) + + if len(ret) == 0 { + panic("no return value specified for Intercept") + } + + var r0 packets.ControlPacket + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, packets.ControlPacket, session.Direction) (packets.ControlPacket, error)); ok { + return rf(ctx, pkt, dir) + } + if rf, ok := ret.Get(0).(func(context.Context, packets.ControlPacket, session.Direction) packets.ControlPacket); ok { + r0 = rf(ctx, pkt, dir) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(packets.ControlPacket) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, packets.ControlPacket, session.Direction) error); ok { + r1 = rf(ctx, pkt, dir) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewInterceptor creates a new instance of Interceptor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewInterceptor(t interface { + mock.TestingT + Cleanup(func()) +}) *Interceptor { + mock := &Interceptor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/tls/tls_test.go b/pkg/tls/tls_test.go new file mode 100644 index 00000000..b17b7efd --- /dev/null +++ b/pkg/tls/tls_test.go @@ -0,0 +1,101 @@ +package tls_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "testing" + "time" + + mptls "github.com/absmach/mproxy/pkg/tls" +) + +func createTempFile(content []byte, t *testing.T) string { + tmpfile, err := os.CreateTemp("", "test") + if err != nil { + t.Fatalf("Failed to create temp file: %s", err) + } + + if _, err := tmpfile.Write(content); err != nil { + t.Fatalf("Failed to write to temp file: %s", err) + } + + if err := tmpfile.Close(); err != nil { + t.Fatalf("Failed to close temp file: %s", err) + } + + return tmpfile.Name() +} + +func generateDummyCert(t *testing.T) ([]byte, []byte) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate private key: %s", err) + } + + cert := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("Failed to create certificate: %s", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) + keyBytes, err := x509.MarshalECPrivateKey(priv) + if err != nil { + t.Fatalf("Failed to marshal private key: %s", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}) + + return certPEM, keyPEM +} + +func TestLoadTLSCfg(t *testing.T) { + certPEM, keyPEM := generateDummyCert(t) + + caFile := createTempFile(certPEM, t) + defer os.Remove(caFile) + + certFile := createTempFile(certPEM, t) + defer os.Remove(certFile) + + keyFile := createTempFile(keyPEM, t) + defer os.Remove(keyFile) + + tests := []struct { + name string + ca string + crt string + key string + wantErr bool + }{ + {"ValidConfig", caFile, certFile, keyFile, false}, + {"InvalidCAFile", "invalid_ca.pem", certFile, keyFile, true}, + {"InvalidCertFile", caFile, "invalid_cert.pem", keyFile, true}, + {"InvalidKeyFile", caFile, certFile, "invalid_key.pem", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := mptls.LoadTLSCfg(tt.ca, tt.crt, tt.key) + if (err != nil) != tt.wantErr { + t.Errorf("LoadTLSCfg() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/websockets/websockets.go b/pkg/websockets/websockets.go index 47e67cad..18d66a70 100644 --- a/pkg/websockets/websockets.go +++ b/pkg/websockets/websockets.go @@ -22,6 +22,7 @@ type Proxy struct { address string event session.Handler logger *slog.Logger + server *http.Server } func (p *Proxy) Handler(w http.ResponseWriter, r *http.Request) { @@ -40,6 +41,7 @@ func (p *Proxy) Handler(w http.ResponseWriter, r *http.Request) { target := fmt.Sprintf("%s%s", p.target, r.RequestURI) + fmt.Println("target: ", target) targetConn, _, err := websocket.DefaultDialer.Dial(target, headers) if err != nil { http.Error(w, err.Error(), http.StatusBadGateway) @@ -114,10 +116,25 @@ func NewProxy(address, target string, logger *slog.Logger, handler session.Handl // Listen - listen withrout tls. func (p *Proxy) Listen() error { - return http.ListenAndServe(p.address, http.HandlerFunc(p.Handler)) + p.server = &http.Server{ + Addr: p.address, + Handler: http.HandlerFunc(p.Handler), + } + return p.server.ListenAndServe() } // ListenTLS - version of Listen with TLS encryption. func (p Proxy) ListenTLS(crt, key string) error { - return http.ListenAndServeTLS(p.address, crt, key, http.HandlerFunc(p.Handler)) + p.server = &http.Server{ + Addr: p.address, + Handler: http.HandlerFunc(p.Handler), + } + return p.server.ListenAndServeTLS( crt, key) } + +func (p *Proxy) Shutdown(ctx context.Context) error { + if p.server != nil { + return p.server.Shutdown(ctx) + } + return nil +} \ No newline at end of file diff --git a/pkg/websockets/websockets_test.go b/pkg/websockets/websockets_test.go new file mode 100644 index 00000000..e1d9cbcd --- /dev/null +++ b/pkg/websockets/websockets_test.go @@ -0,0 +1,156 @@ +package websockets_test + +import ( + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/absmach/mproxy/pkg/session/mocks" + ws "github.com/absmach/mproxy/pkg/websockets" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + handler = new(mocks.Handler) + logger = new(slog.Logger) +) + +func createMockServer(t *testing.T) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + assert.Nil(t, err, fmt.Sprintf("Unexpected error upgrading connection: %v", err)) + defer conn.Close() + })) + + return server +} + +func TestHandlerSuccess(t *testing.T) { + mockServer := createMockServer(t) + defer mockServer.Close() + mockServerURL := "ws" + strings.TrimPrefix(mockServer.URL, "http") + + proxy, err := ws.NewProxy("ws://example.com", mockServerURL, logger, handler) + assert.Nil(t, err, fmt.Sprintf("Unexpected error creating proxy: %v", err)) + testServer := httptest.NewServer(http.HandlerFunc(proxy.Handler)) + defer testServer.Close() + proxyURL := "ws" + strings.TrimPrefix(testServer.URL, "http") + "/test" + + cases := []struct { + desc string + url string + header http.Header + authConnectErr error + authSubscribeErr error + subscribeErr error + status int + }{ + { + desc: "successfull connection with authorization in query", + url: proxyURL + "?authorization=valid", + header: http.Header{}, + authConnectErr: nil, + authSubscribeErr: nil, + subscribeErr: nil, + status: http.StatusSwitchingProtocols, + }, + { + desc: "successfull connection with authorization in header", + url: proxyURL, + header: http.Header{ + "Authorization": []string{"valid-token"}, + }, + authConnectErr: nil, + authSubscribeErr: nil, + subscribeErr: nil, + status: http.StatusSwitchingProtocols, + }, + { + desc: "unsuccesful connection with no authorization", + url: proxyURL, + header: http.Header{}, + status: http.StatusUnauthorized, + }, + { + desc: "unsuccesful connection with failed session auth connect", + url: proxyURL, + header: http.Header{ + "Authorization": []string{"valid-token"}, + }, + authConnectErr: fmt.Errorf("failed auth connect"), + authSubscribeErr: nil, + subscribeErr: nil, + status: http.StatusUnauthorized, + }, + { + desc: "unsuccesful connection with failed session auth subscribe", + url: proxyURL, + header: http.Header{ + "Authorization": []string{"valid-token"}, + }, + authConnectErr: nil, + authSubscribeErr: fmt.Errorf("failed auth subscribe"), + subscribeErr: nil, + status: http.StatusUnauthorized, + }, + { + desc: "unsuccesful connection with failed session subscribe", + url: proxyURL, + header: http.Header{ + "Authorization": []string{"valid-token"}, + }, + authConnectErr: nil, + authSubscribeErr: nil, + subscribeErr: fmt.Errorf("failed subscribe"), + status: http.StatusBadRequest, + }, + } + for _, tc := range cases { + sessionCall := handler.On("AuthConnect", mock.Anything).Return(tc.authConnectErr) + sessionCall1 := handler.On("AuthSubscribe", mock.Anything, mock.Anything).Return(tc.authSubscribeErr) + sessionCall2 := handler.On("Subscribe", mock.Anything, mock.Anything).Return(tc.subscribeErr) + _, res, _ := websocket.DefaultDialer.Dial(tc.url, tc.header) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s expected status code %d but got %d", tc.desc, tc.status, res.StatusCode)) + sessionCall.Unset() + sessionCall1.Unset() + sessionCall2.Unset() + } +} + +// func TestListen(t *testing.T) { +// proxy, err := ws.NewProxy("localhost:8080", "ws://127.0.0.1", logger, handler) +// assert.NoError(t, err) +// go func() { +// err := proxy.Listen() +// assert.Nil(t, err, fmt.Sprintf("Unexpected error listening: %v", err)) +// }() +// time.Sleep(100 * time.Millisecond) + +// req, err := http.NewRequest("GET", "http://example.com", nil) +// assert.Nil(t, err, fmt.Sprintf("Unexpected error creating request: %v", err)) +// req.Header.Set("Authorization", "valid-token") +// rr := httptest.NewRecorder() +// handler := http.HandlerFunc(proxy.Handler) +// handler.ServeHTTP(rr, req) +// assert.Equal(t, http.StatusBadGateway, rr.Code, "Expected status code 502 but got %d", rr.Code) + +// // err = proxy.Shutdown(context.Background()) +// // assert.Nil(t, err, fmt.Sprintf("Unexpected error shutting down: %v", err)) + +// } + +// func TestListenTLS(t *testing.T) { +// proxy, err := ws.NewProxy("localhost:8080", "wss://127.0.0.1", logger, handler) +// assert.NoError(t, err) +// go func() { +// err := proxy.ListenTLS("cert.pem", "key.pem") +// assert.Nil(t, err, fmt.Sprintf("Unexpected error listening: %v", err)) +// }() + +// }