Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/jetstream/api/structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ type PortalConfig struct {
LogLevel string `configName:"LOG_LEVEL"`
UIListMaxSize int64 `configName:"UI_LIST_MAX_SIZE"`
UIListAllowLoadMaxed bool `configName:"UI_LIST_ALLOW_LOAD_MAXED"`
AutoRefreshCNSITokens bool `configName:"AUTOREFRESH_CNSI_TOKENS"`
CFAdminIdentifier string
CloudFoundryInfo *CFInfo
HTTPS bool
Expand Down
1 change: 1 addition & 0 deletions src/jetstream/api/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type TokenRepository interface {
FindAuthToken(userGUID string, encryptionKey []byte) (TokenRecord, error)
SaveAuthToken(userGUID string, tokenRecord TokenRecord, encryptionKey []byte) error

ListAllEnabledConnectedCNSITokens(encryptionKey []byte) ([]BackupTokenRecord, error)
FindCNSIToken(cnsiGUID string, userGUID string, encryptionKey []byte) (TokenRecord, error)
FindCNSITokenIncludeDisconnected(cnsiGUID string, userGUID string, encryptionKey []byte) (TokenRecord, error)
FindAllCNSITokenBackup(cnsiGUID string, encryptionKey []byte) ([]BackupTokenRecord, error)
Expand Down
63 changes: 63 additions & 0 deletions src/jetstream/cnsi.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/url"
"strconv"
"strings"
"time"

"github.com/labstack/echo/v4"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -699,6 +700,63 @@ func (p *portalProxy) updateTokenAuth(userGUID string, t api.TokenRecord) error
return nil
}

func (p *portalProxy) startCNSITokenRefreshRoutines() error {
log.Debug("startCNSITokenRefreshRoutines")

tokenRepo, err := p.GetStoreFactory().TokenStore()
if err != nil {
log.Errorf(dbReferenceError, err)
return fmt.Errorf(dbReferenceError, err)
}

tokens, err := tokenRepo.ListAllEnabledConnectedCNSITokens(p.Config.EncryptionKeyInBytes)
if err != nil {
msg := "unable to list enabled and connected cnsi tokens: %v"
log.Errorf(msg, err)
return fmt.Errorf(msg, err)
}

for _, token := range tokens {
p.refreshRoutines.wg.Add(1)
go p.refreshToken(token)
}

return nil
}

func (p *portalProxy) refreshToken(token api.BackupTokenRecord) {
log.Debug("refreshToken")
defer p.refreshRoutines.wg.Done()
for {
endpoint, err := p.GetCNSIRecord(token.EndpointGUID)
if err != nil {
// Check if the endpoint doesn't exist anymore, if so shut down routine
// Depends on the implementation of EndpointRepository interface from api/cnsis.go,
// but all current implementations pass through to repository/cnsis/pgsql_cnsis.go line 308 eventually
if err.Error() == "No match for that Endpoint" {
log.Infof("endpoint '%v' no longer exists, shutting down token refresher routine", token.EndpointGUID)
return
}
// If any other error occurred, log it and retry
log.Errorf("could not get retrieve endpoint record to refresh cnsi token '%v': %v", token.TokenRecord.TokenGUID, err)
continue
}
expiry := time.Unix(token.TokenRecord.TokenExpiry, 0)
select {
case <-time.After(time.Until(expiry)):
case <-p.refreshRoutines.context.Done():
return
}

updatedTokenRecord, err := p.RefreshOAuthToken(endpoint.SkipSSLValidation, token.EndpointGUID, token.UserGUID, endpoint.ClientId, endpoint.ClientSecret, endpoint.TokenEndpoint)
if err != nil {
log.Errorf("could not refresh cnsi token '%v': %v", token.TokenRecord.TokenGUID, err)
continue
}
token.TokenRecord = updatedTokenRecord
}
}

func (p *portalProxy) setCNSITokenRecord(cnsiGUID string, userGUID string, t api.TokenRecord) error {
log.Debug("setCNSITokenRecord")
tokenRepo, err := p.GetStoreFactory().TokenStore()
Expand All @@ -714,6 +772,11 @@ func (p *portalProxy) setCNSITokenRecord(cnsiGUID string, userGUID string, t api
return fmt.Errorf(msg, err)
}

if p.Config.AutoRefreshCNSITokens {
p.refreshRoutines.wg.Add(1)
go p.refreshToken(api.BackupTokenRecord{TokenRecord: t, UserGUID: userGUID, EndpointGUID: cnsiGUID})
}

return nil
}

Expand Down
158 changes: 158 additions & 0 deletions src/jetstream/cnsi_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package main

import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -251,6 +256,159 @@ func TestGetCFv2InfoWithInvalidEndpoint(t *testing.T) {
}
}

func TestRegisterEndpointStartsRefreshRoutine(t *testing.T) {
t.Parallel()

Convey("Request to register endpoint", t, func() {
// mock StratosAuthService
ctrl := gomock.NewController(t)
mockStratosAuth := mock_api.NewMockStratosAuth(ctrl)
defer ctrl.Finish()

// setup mock DB, PortalProxy and mock StratosAuthService
pp, db, mock := setupPortalProxyWithAuthService(mockStratosAuth)
defer db.Close()

pp.Config.AutoRefreshCNSITokens = true

// mock individual APIEndpoints
mockV2Info := setupMockEndpointServer(t)
defer mockV2Info.Close()

mockUAAResponseModifiedExpiry := mockUAAResponse

splits := strings.Split(mockUAAResponse.AccessToken, ".")

decoded, _ := base64.RawStdEncoding.DecodeString(splits[1])

u := new(api.JWTUserTokenInfo)
json.Unmarshal(decoded, &u)

u.TokenExpiry = time.Now().Add(time.Minute * 5).Unix()

encode, _ := json.Marshal(u)

splits[1] = base64.RawStdEncoding.EncodeToString(encode)

mockUAAResponseModifiedExpiry.AccessToken = strings.Join(splits, ".")

mockUAA := setupMockServer(t,
msRoute("/oauth/token"),
msMethod("POST"),
msStatus(http.StatusOK),
msBody(jsonMust(mockUAAResponseModifiedExpiry)))

// mock different users
mockAdmin := setupMockUser(mockAdminGUID, true, []string{})

pp.GetConfig().UserEndpointsEnabled = config.UserEndpointsConfigEnum.Enabled

// setup
adminEndpoint := setupMockEndpointRegisterRequest(t, mockAdmin.ConnectedUser, mockV2Info, "CF Cluster 1", true, true)

if errSession := pp.setSessionValues(adminEndpoint.EchoContext, mockAdmin.SessionValues); errSession != nil {
t.Error(errors.New("unable to mock/stub user in session object"))
}

Convey("registering a new endpoint and logging in leads to a refresh routine being started", func() {
// mock executions
mockStratosAuth.
EXPECT().
GetUser(gomock.Eq(mockAdmin.ConnectedUser.GUID)).
Return(mockAdmin.ConnectedUser, nil)

mock.
ExpectQuery(selectFromCNSIs).
WillReturnRows(
sqlmock.NewRows(
[]string{"guid", "name", "cnsi_type", "api_endpoint", "auth_endpoint", "token_endpoint", "doppler_logging_endpoint", "skip_ssl_validation", "client_id", "client_secret", "sso_allowed", "sub_type", "meta_data", "creator", "ca_cert"},
),
)
mock.
ExpectExec(insertIntoCNSIs).
WillReturnResult(sqlmock.NewResult(1, 1))

fetchInfo := getCFPlugin(pp, "cf").Info
err := pp.RegisterEndpoint(adminEndpoint.EchoContext, fetchInfo)

So(err, ShouldBeNil)

first := adminEndpoint.QueryArgs[:4]
newRow := append(first, mockUAA.URL)
last := adminEndpoint.QueryArgs[5:]
newRow = append(newRow, last...)

mock.
ExpectQuery(selectAnyFromCNSIs).
WillReturnRows(
sqlmock.NewRows(
[]string{"guid", "name", "cnsi_type", "api_endpoint", "auth_endpoint", "token_endpoint", "doppler_logging_endpoint", "skip_ssl_validation", "client_id", "client_secret", "sso_allowed", "sub_type", "meta_data", "creator", "ca_cert"},
).AddRow(newRow...),
)

mock.
ExpectQuery(selectAnyFromCNSIs).
WillReturnRows(
sqlmock.NewRows(
[]string{"guid", "name", "cnsi_type", "api_endpoint", "auth_endpoint", "token_endpoint", "doppler_logging_endpoint", "skip_ssl_validation", "client_id", "client_secret", "sso_allowed", "sub_type", "meta_data", "creator", "ca_cert"},
).AddRow(newRow...),
)

mock.
ExpectQuery(selectAnyFromTokens).
WithArgs(newRow[0], mockAdmin.ConnectedUser.GUID).
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow(0))

mock.
ExpectExec(insertIntoTokens).
WillReturnResult(sqlmock.NewResult(1, 1))

mock.
ExpectQuery(selectAnyFromCNSIs).
WillReturnRows(
sqlmock.NewRows(
[]string{"guid", "name", "cnsi_type", "api_endpoint", "auth_endpoint", "token_endpoint", "doppler_logging_endpoint", "skip_ssl_validation", "client_id", "client_secret", "sso_allowed", "sub_type", "meta_data", "creator", "ca_cert"},
).AddRow(newRow...),
)

mock.
ExpectQuery(selectAnyFromCNSIs).
WillReturnRows(
sqlmock.NewRows(
[]string{"guid", "name", "cnsi_type", "api_endpoint", "auth_endpoint", "token_endpoint", "doppler_logging_endpoint", "skip_ssl_validation", "client_id", "client_secret", "sso_allowed", "sub_type", "meta_data", "creator", "ca_cert"},
).AddRow(newRow...),
)

mockStratosAuth.
EXPECT().
GetUser(gomock.Eq(mockAdmin.ConnectedUser.GUID)).
Return(mockAdmin.ConnectedUser, nil)

// value are irrelevant, since we mock the reponse from the uaa regardless but the login won't work without them
formDataForApiLogin := url.Values{}
formDataForApiLogin.Set("username", "test")
formDataForApiLogin.Set("password", "test")
newReq, _ := http.NewRequest(http.MethodPost, "localhost:9999/some/fake/url", bytes.NewBufferString(formDataForApiLogin.Encode()))
newReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")

newContext := adminEndpoint.EchoContext.Echo().NewContext(newReq, adminEndpoint.EchoContext.Response())
_, err = pp.DoLoginToCNSI(newContext, adminEndpoint.InsertArgs[0].(string), true)

// Asynchronosly wait 5 seconds, then cancel the refresh routines
go func() {
time.Sleep(time.Second * 5)
pp.refreshRoutines.cancel()
}()

// Wait until all refresh routines have terminated (portalProxy does the same on graceful shutdown)
pp.refreshRoutines.wg.Wait()

So(err, ShouldBeNil)
So(mock.ExpectationsWereMet(), ShouldBeNil)
})
})
}

func TestRegisterWithUserEndpointsEnabled(t *testing.T) {
// execute this in parallel
t.Parallel()
Expand Down
20 changes: 20 additions & 0 deletions src/jetstream/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"crypto/sha1"
"database/sql"
"encoding/gob"
Expand Down Expand Up @@ -282,6 +283,8 @@ func main() {

log.Info("Initialization complete.")

ctx, cancel := context.WithCancel(context.Background())
portalProxy.SetRefreshRoutineContext(ctx, cancel)
c := make(chan os.Signal, 2)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
Expand All @@ -290,6 +293,9 @@ func main() {
fmt.Println()
log.Info("Attempting to shut down gracefully...")

// Cancel portal proxy context
cancel()

// Database connection pool
log.Info(`... Closing database connection pool`)
databaseConnectionPool.Close()
Expand All @@ -310,6 +316,8 @@ func main() {
pCleanup.Destroy()
}
}
// wait for any goroutines to shut down
portalProxy.refreshRoutines.wg.Wait()

log.Info("Graceful shut down complete")
os.Exit(1)
Expand Down Expand Up @@ -804,6 +812,12 @@ func start(config api.PortalConfig, p *portalProxy, needSetupMiddleware bool, is
go stopEchoWhenUpgraded(e, p.Env())
}

if p.Config.AutoRefreshCNSITokens {
if err := p.startCNSITokenRefreshRoutines(); err != nil {
return err
}
}

var engineErr error
address := config.TLSAddress
if config.HTTPS {
Expand Down Expand Up @@ -1194,3 +1208,9 @@ func (portalProxy *portalProxy) SetStoreFactory(f api.StoreFactory) api.StoreFac
portalProxy.StoreFactory = f
return old
}

// SetContext sets the context
func (portalProxy *portalProxy) SetRefreshRoutineContext(ctx context.Context, cancel context.CancelFunc) {
portalProxy.refreshRoutines.context = ctx
portalProxy.refreshRoutines.cancel = cancel
}
3 changes: 3 additions & 0 deletions src/jetstream/mock_server_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"crypto/sha1"
"database/sql"
"database/sql/driver"
Expand Down Expand Up @@ -182,6 +183,8 @@ func setupPortalProxy(db *sql.DB) *portalProxy {
store := factory.NewDefaultStoreFactory(db)
pp.SetStoreFactory(store)

pp.SetRefreshRoutineContext(context.WithCancel(context.Background()))

return pp
}

Expand Down
4 changes: 4 additions & 0 deletions src/jetstream/plugins/desktop/helm/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ func (d *TokenStore) SaveAuthToken(userGUID string, tokenRecord api.TokenRecord,
return d.store.SaveAuthToken(userGUID, tokenRecord, encryptionKey)
}

func (d *TokenStore) ListAllEnabledConnectedCNSITokens(encryptionKey []byte) ([]api.BackupTokenRecord, error) {
return d.store.ListAllEnabledConnectedCNSITokens(encryptionKey)
}

func (d *TokenStore) FindCNSIToken(cnsiGUID string, userGUID string, encryptionKey []byte) (api.TokenRecord, error) {
return d.store.FindCNSIToken(cnsiGUID, userGUID, encryptionKey)
}
Expand Down
4 changes: 4 additions & 0 deletions src/jetstream/plugins/desktop/kubernetes/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ func (d *TokenStore) SaveAuthToken(userGUID string, tokenRecord api.TokenRecord,
return d.store.SaveAuthToken(userGUID, tokenRecord, encryptionKey)
}

func (d *TokenStore) ListAllEnabledConnectedCNSITokens(encryptionKey []byte) ([]api.BackupTokenRecord, error) {
return d.store.ListAllEnabledConnectedCNSITokens(encryptionKey)
}

func (d *TokenStore) FindCNSIToken(cnsiGUID string, userGUID string, encryptionKey []byte) (api.TokenRecord, error) {

local, cfg, err := ListKubernetes()
Expand Down
4 changes: 4 additions & 0 deletions src/jetstream/plugins/desktop/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ func (d *TokenStore) SaveAuthToken(userGUID string, tokenRecord api.TokenRecord,
return d.store.SaveAuthToken(userGUID, tokenRecord, encryptionKey)
}

func (d *TokenStore) ListAllEnabledConnectedCNSITokens(encryptionKey []byte) ([]api.BackupTokenRecord, error) {
return d.store.ListAllEnabledConnectedCNSITokens(encryptionKey)
}

func (d *TokenStore) FindCNSIToken(cnsiGUID string, userGUID string, encryptionKey []byte) (api.TokenRecord, error) {

// Main method that we need to override to get the token for the given endpoint
Expand Down
Loading