From 7feb29d3b13ee7b81110a91eb4525afabff88229 Mon Sep 17 00:00:00 2001 From: Oyewole Samuel A Date: Mon, 1 Nov 2021 17:20:21 +0100 Subject: [PATCH 1/2] upgrade to official mongo go driver #12 --- .gitignore | 2 ++ mgostore_test.go | 21 ++++++++++--- mongostore.go | 78 ++++++++++++++++++++++++++++-------------------- 3 files changed, 64 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 0026861..d2ded5a 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,5 @@ _cgo_export.* _testmain.go *.exe +go.mod +go.sum \ No newline at end of file diff --git a/mgostore_test.go b/mgostore_test.go index 4fdf41d..244b8af 100644 --- a/mgostore_test.go +++ b/mgostore_test.go @@ -1,5 +1,6 @@ // Copyright (c) 2013 Gregor Robinson. // Copyright (c) 2013 Brian Jones. +// Copyright (c) 2021 Oyewol Samuel. // All rights reserved. // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. @@ -7,13 +8,16 @@ package mongostore import ( + "context" "encoding/gob" "net/http" "net/http/httptest" + "os" "testing" - "github.com/globalsign/mgo" "github.com/gorilla/sessions" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) type FlashMessage struct { @@ -36,13 +40,22 @@ func TestMongoStore(t *testing.T) { // license that can be found in the LICENSE file. // Round 1 ---------------------------------------------------------------- - dbsess, err := mgo.Dial("localhost") + mongosrv := os.Getenv("") + client, err := mongo.NewClient(options.Client().ApplyURI(mongosrv)) if err != nil { panic(err) } - defer dbsess.Close() - store := NewMongoStore(dbsess.DB("test").C("test_session"), 3600, true, + if err := client.Connect(context.Background()); err != nil { + panic(err) + } + + defer client.Disconnect(context.Background()) + + store := NewMongoStore( + client.Database("test").Collection("test_session"), + 3600, + true, []byte("secret-key")) req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) diff --git a/mongostore.go b/mongostore.go index 7087bff..a2ca6a4 100644 --- a/mongostore.go +++ b/mongostore.go @@ -5,23 +5,26 @@ package mongostore import ( + "context" "errors" "net/http" "time" - "github.com/globalsign/mgo" "github.com/globalsign/mgo/bson" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) var ( - ErrInvalidId = errors.New("mgostore: invalid session id") + ErrInvalidId = errors.New("mongostore: invalid session id") ) // Session object store in MongoDB type Session struct { - Id bson.ObjectId `bson:"_id,omitempty"` + Id primitive.ObjectID `bson:"_id,omitempty"` Data string Modified time.Time } @@ -31,13 +34,12 @@ type MongoStore struct { Codecs []securecookie.Codec Options *sessions.Options Token TokenGetSeter - coll *mgo.Collection + coll *mongo.Collection } // NewMongoStore returns a new MongoStore. // Set ensureTTL to true let the database auto-remove expired object by maxAge. -func NewMongoStore(c *mgo.Collection, maxAge int, ensureTTL bool, - keyPairs ...[]byte) *MongoStore { +func NewMongoStore(c *mongo.Collection, maxAge int, ensureTTL bool, keyPairs ...[]byte) *MongoStore { store := &MongoStore{ Codecs: securecookie.CodecsFromPairs(keyPairs...), Options: &sessions.Options{ @@ -51,12 +53,17 @@ func NewMongoStore(c *mgo.Collection, maxAge int, ensureTTL bool, store.MaxAge(maxAge) if ensureTTL { - c.EnsureIndex(mgo.Index{ - Key: []string{"modified"}, - Background: true, - Sparse: true, - ExpireAfter: time.Duration(maxAge) * time.Second, - }) + expireAfter := time.Duration(maxAge) * time.Second + + indexModel := mongo.IndexModel{ + Keys: bson.M{"modified": 1}, + Options: options.Index().SetExpireAfterSeconds(int32(expireAfter.Seconds())), + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + c.Indexes().CreateOne(ctx, indexModel) } return store @@ -64,14 +71,12 @@ func NewMongoStore(c *mgo.Collection, maxAge int, ensureTTL bool, // Get registers and returns a session for the given name and session store. // It returns a new session if there are no sessions registered for the name. -func (m *MongoStore) Get(r *http.Request, name string) ( - *sessions.Session, error) { +func (m *MongoStore) Get(r *http.Request, name string) (*sessions.Session, error) { return sessions.GetRegistry(r).Get(m, name) } // New returns a session for the given name without adding it to the registry. -func (m *MongoStore) New(r *http.Request, name string) ( - *sessions.Session, error) { +func (m *MongoStore) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(m, name) session.Options = &sessions.Options{ Path: m.Options.Path, @@ -97,8 +102,7 @@ func (m *MongoStore) New(r *http.Request, name string) ( } // Save saves all sessions registered for the current request. -func (m *MongoStore) Save(r *http.Request, w http.ResponseWriter, - session *sessions.Session) error { +func (m *MongoStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { if session.Options.MaxAge < 0 { if err := m.delete(session); err != nil { return err @@ -115,8 +119,7 @@ func (m *MongoStore) Save(r *http.Request, w http.ResponseWriter, return err } - encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, - m.Codecs...) + encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, m.Codecs...) if err != nil { return err } @@ -139,19 +142,18 @@ func (m *MongoStore) MaxAge(age int) { } } -func (m *MongoStore) load(session *sessions.Session) error { - if !bson.IsObjectIdHex(session.ID) { +func (m *MongoStore) load(session *sessions.Session) error { + objID, err := primitive.ObjectIDFromHex(session.ID) + if err != nil { return ErrInvalidId } s := Session{} - err := m.coll.FindId(bson.ObjectIdHex(session.ID)).One(&s) - if err != nil { + if err := m.coll.FindOne(context.Background(), bson.M{"_id": objID}).Decode(&s); err != nil { return err } - if err := securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, - m.Codecs...); err != nil { + if err := securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, m.Codecs...); err != nil { return err } @@ -159,10 +161,12 @@ func (m *MongoStore) load(session *sessions.Session) error { } func (m *MongoStore) upsert(session *sessions.Session) error { - if !bson.IsObjectIdHex(session.ID) { + objID, err := primitive.ObjectIDFromHex(session.ID) + if err != nil { return ErrInvalidId } + var modified time.Time if val, ok := session.Values["modified"]; ok { modified, ok = val.(time.Time) @@ -173,19 +177,22 @@ func (m *MongoStore) upsert(session *sessions.Session) error { modified = time.Now() } - encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, - m.Codecs...) + encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, m.Codecs...) if err != nil { return err } s := Session{ - Id: bson.ObjectIdHex(session.ID), + Id: objID, Data: encoded, Modified: modified, } - _, err = m.coll.UpsertId(s.Id, &s) + opts := options.Update().SetUpsert(true) + filter := bson.M{"_id": s.Id} + updateData := bson.M{"$set": s} + + _, err = m.coll.UpdateOne(context.Background(), filter, updateData, opts) if err != nil { return err } @@ -194,9 +201,14 @@ func (m *MongoStore) upsert(session *sessions.Session) error { } func (m *MongoStore) delete(session *sessions.Session) error { - if !bson.IsObjectIdHex(session.ID) { + objID, err := primitive.ObjectIDFromHex(session.ID) + if err != nil { return ErrInvalidId } - return m.coll.RemoveId(bson.ObjectIdHex(session.ID)) + _, deleteErr := m.coll.DeleteOne(context.Background(), bson.M{"_id": objID}) + if err != nil { + return deleteErr + } + return nil } From 4db440b50d495a45d92c964b1796d6c81cb72227 Mon Sep 17 00:00:00 2001 From: Oyewole Samuel A Date: Mon, 1 Nov 2021 17:30:39 +0100 Subject: [PATCH 2/2] revamp delete/update response --- mongostore.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mongostore.go b/mongostore.go index a2ca6a4..8259be5 100644 --- a/mongostore.go +++ b/mongostore.go @@ -192,8 +192,7 @@ func (m *MongoStore) upsert(session *sessions.Session) error { filter := bson.M{"_id": s.Id} updateData := bson.M{"$set": s} - _, err = m.coll.UpdateOne(context.Background(), filter, updateData, opts) - if err != nil { + if _, err = m.coll.UpdateOne(context.Background(), filter, updateData, opts); err != nil { return err } @@ -206,9 +205,8 @@ func (m *MongoStore) delete(session *sessions.Session) error { return ErrInvalidId } - _, deleteErr := m.coll.DeleteOne(context.Background(), bson.M{"_id": objID}) - if err != nil { - return deleteErr + if _, err = m.coll.DeleteOne(context.Background(), bson.M{"_id": objID}); err != nil { + return err } return nil }