diff --git a/go.mod b/go.mod index 9eb4dd1..dd24a81 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.15.0 // indirect + github.com/joho/godotenv v1.5.1 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect diff --git a/go.sum b/go.sum index d89f00e..d9946e4 100644 --- a/go.sum +++ b/go.sum @@ -33,6 +33,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/julwrites/BotPlatform v0.0.0-20220206144002-60e1b8060734 h1:U/z8aO/8zMpOzdR7kK9hnHfXber1fHa7FWlXGeuG3Yc= github.com/julwrites/BotPlatform v0.0.0-20220206144002-60e1b8060734/go.mod h1:RAVF1PibRuRYv1Z7VxNapzrikBrjtF48aFPCoCVnLpM= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/pkg/app/api_client.go b/pkg/app/api_client.go index 11726b2..0e69e94 100644 --- a/pkg/app/api_client.go +++ b/pkg/app/api_client.go @@ -7,15 +7,11 @@ import ( "io" "log" "net/http" - "os" "sync" - "github.com/julwrites/ScriptureBot/pkg/utils" + "github.com/julwrites/ScriptureBot/pkg/secrets" ) -// getSecretFunc is a variable to allow mocking in tests -var getSecretFunc = utils.GetSecret - var ( cachedAPIURL string cachedAPIKey string @@ -43,7 +39,7 @@ func SetAPIConfigOverride(url, key string) { configInitialized = true } -func getAPIConfig(projectID string) (string, string) { +func getAPIConfig() (string, string) { configMutex.Lock() defer configMutex.Unlock() @@ -51,42 +47,14 @@ func getAPIConfig(projectID string) (string, string) { return cachedAPIURL, cachedAPIKey } - url := os.Getenv("BIBLE_API_URL") - key := os.Getenv("BIBLE_API_KEY") - - // If env vars are missing, try to fetch from Secret Manager - if url == "" || key == "" { - // Try to fetch project ID if not provided. - if projectID == "" { - projectID = os.Getenv("GCLOUD_PROJECT_ID") - - if projectID == "" { - var err error - projectID, err = getSecretFunc("", "GCLOUD_PROJECT_ID") - if err != nil { - log.Printf("Failed to fetch GCLOUD_PROJECT_ID from Secret Manager: %v", err) - } - } - } + url, err := secrets.Get("BIBLE_API_URL") + if err != nil { + log.Printf("Failed to get BIBLE_API_URL: %v", err) + } - if projectID != "" { - if url == "" { - var err error - url, err = getSecretFunc(projectID, "BIBLE_API_URL") - if err != nil { - log.Printf("Failed to fetch BIBLE_API_URL from Secret Manager: %v", err) - } - } - if key == "" { - var err error - key, err = getSecretFunc(projectID, "BIBLE_API_KEY") - if err != nil { - log.Printf("Failed to fetch BIBLE_API_KEY from Secret Manager: %v", err) - } - } - } else { - log.Println("GCLOUD_PROJECT_ID is not set and no project ID passed, skipping Secret Manager lookup") - } + key, err := secrets.Get("BIBLE_API_KEY") + if err != nil { + log.Printf("Failed to get BIBLE_API_KEY: %v", err) } cachedAPIURL = url @@ -99,7 +67,7 @@ func getAPIConfig(projectID string) (string, string) { // SubmitQuery sends the QueryRequest to the Bible API and unmarshals the response into result. // result should be a pointer to the expected response struct. func SubmitQuery(req QueryRequest, result interface{}, projectID string) error { - apiURL, apiKey := getAPIConfig(projectID) + apiURL, apiKey := getAPIConfig() if apiURL == "" { return fmt.Errorf("BIBLE_API_URL environment variable is not set") } diff --git a/pkg/app/api_client_test.go b/pkg/app/api_client_test.go index 965fe1a..8817207 100644 --- a/pkg/app/api_client_test.go +++ b/pkg/app/api_client_test.go @@ -1,81 +1,60 @@ package app import ( - "encoding/json" - "fmt" "net/http" "net/http/httptest" "testing" - - "github.com/julwrites/ScriptureBot/pkg/utils" ) func TestSubmitQuery(t *testing.T) { - // Mock server - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check headers - if r.Header.Get("Content-Type") != "application/json" { - t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) - } - - // Decode request to verify it - var req QueryRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - // Simple response based on input - if req.Query.Prompt == "error" { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(`{"error": {"code": 500, "message": "simulated error"}}`)) - return - } - - if req.Query.Prompt == "badjson" { - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{invalid json`)) - return - } - - // Success response - resp := VerseResponse{Verse: "Success Verse"} - json.NewEncoder(w).Encode(resp) - })) + handler := newMockApiHandler() + ts := httptest.NewServer(handler) defer ts.Close() - // Set env vars - defer setEnv("BIBLE_API_URL", ts.URL)() - - // Test Case 1: Success t.Run("Success", func(t *testing.T) { + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + req := QueryRequest{Query: QueryObject{Prompt: "hello"}} - var resp VerseResponse + var resp OQueryResponse err := SubmitQuery(req, &resp, "") if err != nil { t.Errorf("Unexpected error: %v", err) } - if resp.Verse != "Success Verse" { - t.Errorf("Expected 'Success Verse', got '%s'", resp.Verse) + if resp.Text != "Answer text" { + t.Errorf("Expected 'Answer text', got '%s'", resp.Text) } }) - // Test Case 2: API Error t.Run("API Error", func(t *testing.T) { + handler.statusCode = http.StatusInternalServerError + handler.rawResponse = `{"error": {"code": 500, "message": "simulated error"}}` + defer func() { // Reset handler + handler.statusCode = http.StatusOK + handler.rawResponse = "" + }() + + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + req := QueryRequest{Query: QueryObject{Prompt: "error"}} var resp VerseResponse err := SubmitQuery(req, &resp, "") if err == nil { t.Error("Expected error, got nil") } - // Expect error message to contain "simulated error" - if err != nil && err.Error() != "api error (500): simulated error" { + if err.Error() != "api error (500): simulated error" { t.Errorf("Expected specific API error, got: %v", err) } }) - // Test Case 3: Bad JSON Response t.Run("Bad JSON", func(t *testing.T) { + handler.rawResponse = `{invalid json` + defer func() { handler.rawResponse = "" }() + + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + req := QueryRequest{Query: QueryObject{Prompt: "badjson"}} var resp VerseResponse err := SubmitQuery(req, &resp, "") @@ -84,13 +63,8 @@ func TestSubmitQuery(t *testing.T) { } }) - // Test Case 4: No URL set t.Run("No URL", func(t *testing.T) { - // Temporarily unset/clear the env var - restore := setEnv("BIBLE_API_URL", "") - defer restore() - // Also unset PROJECT_ID to avoid Secret Manager lookup - defer utils.SetEnv("GCLOUD_PROJECT_ID", "")() + defer setEnv("BIBLE_API_URL", "")() ResetAPIConfigCache() req := QueryRequest{} @@ -101,65 +75,3 @@ func TestSubmitQuery(t *testing.T) { } }) } - -func TestGetAPIConfig_SecretManagerFallback(t *testing.T) { - // Ensure Env Vars are empty - defer utils.SetEnv("BIBLE_API_URL", "")() - defer utils.SetEnv("BIBLE_API_KEY", "")() - defer utils.SetEnv("GCLOUD_PROJECT_ID", "test-project")() - ResetAPIConfigCache() - - // Mock the secret function - oldGetSecret := getSecretFunc - defer func() { getSecretFunc = oldGetSecret }() - - getSecretFunc = func(project, name string) (string, error) { - if project != "test-project" { - return "", fmt.Errorf("unexpected project: %s", project) - } - if name == "BIBLE_API_URL" { - return "http://secret-url.com", nil - } - if name == "BIBLE_API_KEY" { - return "secret-key", nil - } - return "", fmt.Errorf("unexpected secret: %s", name) - } - - url, key := getAPIConfig("") - - if url != "http://secret-url.com" { - t.Errorf("Expected URL 'http://secret-url.com', got '%s'", url) - } - if key != "secret-key" { - t.Errorf("Expected Key 'secret-key', got '%s'", key) - } -} - -func TestGetAPIConfig_PassedProjectID(t *testing.T) { - // Ensure Env Vars are empty, including GCLOUD_PROJECT_ID - defer utils.SetEnv("BIBLE_API_URL", "")() - defer utils.SetEnv("BIBLE_API_KEY", "")() - defer utils.SetEnv("GCLOUD_PROJECT_ID", "")() - ResetAPIConfigCache() - - // Mock the secret function - oldGetSecret := getSecretFunc - defer func() { getSecretFunc = oldGetSecret }() - - getSecretFunc = func(project, name string) (string, error) { - if project != "passed-project" { - return "", fmt.Errorf("unexpected project: %s", project) - } - if name == "BIBLE_API_URL" { - return "http://secret-url-passed.com", nil - } - return "", fmt.Errorf("unexpected secret: %s", name) - } - - url, _ := getAPIConfig("passed-project") - - if url != "http://secret-url-passed.com" { - t.Errorf("Expected URL 'http://secret-url-passed.com', got '%s'", url) - } -} diff --git a/pkg/app/ask_test.go b/pkg/app/ask_test.go index 1bf3930..267b2ad 100644 --- a/pkg/app/ask_test.go +++ b/pkg/app/ask_test.go @@ -1,7 +1,6 @@ package app import ( - "encoding/json" "net/http" "net/http/httptest" "strings" @@ -12,29 +11,14 @@ import ( ) func TestGetBibleAsk(t *testing.T) { - // Mock server - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req QueryRequest - json.NewDecoder(r.Body).Decode(&req) - - if req.Query.Prompt == "error" { - w.WriteHeader(http.StatusInternalServerError) - return - } - - resp := OQueryResponse{ - Text: "Answer text", - References: []SearchResult{ - {Verse: "Ref 1:1", URL: "http://ref1"}, - }, - } - json.NewEncoder(w).Encode(resp) - })) + handler := newMockApiHandler() + ts := httptest.NewServer(handler) defer ts.Close() - defer setEnv("BIBLE_API_URL", ts.URL)() - t.Run("Success", func(t *testing.T) { + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + var env def.SessionData env.Msg.Message = "Question" conf := utils.UserConfig{Version: "NIV"} @@ -51,6 +35,12 @@ func TestGetBibleAsk(t *testing.T) { }) t.Run("Error", func(t *testing.T) { + handler.statusCode = http.StatusInternalServerError + defer func() { handler.statusCode = http.StatusOK }() + + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + var env def.SessionData env.Msg.Message = "error" conf := utils.UserConfig{Version: "NIV"} diff --git a/pkg/app/devo_test.go b/pkg/app/devo_test.go index a7a983a..431ae74 100644 --- a/pkg/app/devo_test.go +++ b/pkg/app/devo_test.go @@ -1,6 +1,7 @@ package app import ( + "net/http/httptest" "testing" "time" @@ -76,56 +77,72 @@ func TestGetUtmostForHisHighestArticles(t *testing.T) { } func TestGetDevotionalData(t *testing.T) { - var env def.SessionData + handler := newMockApiHandler() + ts := httptest.NewServer(handler) + defer ts.Close() - env.ResourcePath = "../../resource" + t.Run("DTMSV", func(t *testing.T) { + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() - env.Res = GetDevotionalData(env, "DTMSV") + var env def.SessionData + env.ResourcePath = "../../resource" + env.Res = GetDevotionalData(env, "DTMSV") - if len(env.Res.Message) == 0 { - t.Errorf("Failed TestGetDevotionalData for DTMSV") - } + if len(env.Res.Message) == 0 { + t.Errorf("Failed TestGetDevotionalData for DTMSV") + } + }) } func TestGetDevo(t *testing.T) { - // Test initial devo command (no specific devo chosen) - var env def.SessionData - env.User.Action = "" - env.Msg.Message = CMD_DEVO // Simulate user typing /devo + handler := newMockApiHandler() + ts := httptest.NewServer(handler) + defer ts.Close() - env = GetDevo(env) - if len(env.Res.Message) == 0 { - t.Errorf("Failed TestGetDevo initial, no message") - } - if len(env.Res.Affordances.Options) == 0 { - t.Errorf("Failed TestGetDevo initial, no affordances") - } + t.Run("Initial Devo", func(t *testing.T) { + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + + var env def.SessionData + env.User.Action = "" + env.Msg.Message = CMD_DEVO + + env = GetDevo(env) + if len(env.Res.Message) == 0 { + t.Error("Failed TestGetDevo initial, no message") + } + if len(env.Res.Affordances.Options) == 0 { + t.Error("Failed TestGetDevo initial, no affordances") + } + }) - // Test each specific devotional option for devoName, devoCode := range DEVOS { + devoName := devoName + devoCode := devoCode t.Run(devoName, func(t *testing.T) { + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + var env def.SessionData env.User.Action = CMD_DEVO env.Msg.Message = devoName - env.ResourcePath = "../../resource" // Needed for some devo types + env.ResourcePath = "../../resource" env = GetDevo(env) - // Check if a message or options are returned if len(env.Res.Message) == 0 && len(env.Res.Affordances.Options) == 0 { - t.Errorf("Failed TestGetDevo for %s: no message or affordances", devoName) + t.Fatalf("Failed TestGetDevo for %s: no message or affordances", devoName) } - // Specific checks based on dispatch method switch GetDevotionalDispatchMethod(devoCode) { case Passage: if len(env.Res.Message) == 0 { - t.Errorf("Failed TestGetDevo for %s (Passage): no message returned", devoName) + t.Errorf("Expected a message for Passage type devo, got none") } case Keyboard: - // For Keyboard types, either options should be present or a message (e.g., for N5XBRP rest day) if len(env.Res.Affordances.Options) == 0 && len(env.Res.Message) == 0 { - t.Errorf("Failed TestGetDevo for %s (Keyboard): no affordances or message returned", devoName) + t.Errorf("Expected affordances or a message for Keyboard type devo, got none") } } }) diff --git a/pkg/app/passage_test.go b/pkg/app/passage_test.go index 6b2eede..2025e38 100644 --- a/pkg/app/passage_test.go +++ b/pkg/app/passage_test.go @@ -1,7 +1,6 @@ package app import ( - "encoding/json" "net/http" "net/http/httptest" "strings" @@ -13,11 +12,6 @@ import ( "github.com/julwrites/ScriptureBot/pkg/utils" ) -func setEnv(key, value string) func() { - ResetAPIConfigCache() - return utils.SetEnv(key, value) -} - func TestGetBiblePassageHtml(t *testing.T) { doc := GetPassageHTMLFunc("gen 8", "NIV") @@ -51,37 +45,15 @@ func TestGetPassage(t *testing.T) { } func TestGetBiblePassage(t *testing.T) { - // Mock server - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req QueryRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - verse := "" - if len(req.Query.Verses) > 0 { - verse = req.Query.Verses[0] - } - - switch verse { - case "gen 1": - resp := VerseResponse{ - Verse: "
In the beginning God created the heavens and the earth.
", - } - json.NewEncoder(w).Encode(resp) - case "empty": - json.NewEncoder(w).Encode(VerseResponse{}) - default: // Any other case will trigger an error, forcing fallback - w.WriteHeader(http.StatusInternalServerError) - } - })) + handler := newMockApiHandler() + ts := httptest.NewServer(handler) defer ts.Close() - defer setEnv("BIBLE_API_URL", ts.URL)() - defer setEnv("BIBLE_API_KEY", "test_key")() - t.Run("Success", func(t *testing.T) { + defer setEnv("BIBLE_API_URL", ts.URL)() + defer setEnv("BIBLE_API_KEY", "test_key")() + ResetAPIConfigCache() + var env def.SessionData env.Msg.Message = "gen 1" var conf utils.UserConfig @@ -94,7 +66,14 @@ func TestGetBiblePassage(t *testing.T) { } }) - t.Run("Fallback on API error", func(t *testing.T) { + t.Run("Error", func(t *testing.T) { + handler.statusCode = http.StatusInternalServerError + defer func() { handler.statusCode = http.StatusOK }() + + defer setEnv("BIBLE_API_URL", ts.URL)() + defer setEnv("BIBLE_API_KEY", "test_key")() + ResetAPIConfigCache() + var env def.SessionData env.Msg.Message = "John 1:1" // Use a valid reference to test the fallback var conf utils.UserConfig @@ -121,6 +100,17 @@ func TestGetBiblePassage(t *testing.T) { }) t.Run("Empty", func(t *testing.T) { + handler.verseResponse = VerseResponse{} + defer func() { + handler.verseResponse = VerseResponse{ + Verse: "In the beginning God created the heavens and the earth.
", + } + }() + + defer setEnv("BIBLE_API_URL", ts.URL)() + defer setEnv("BIBLE_API_KEY", "test_key")() + ResetAPIConfigCache() + var env def.SessionData env.Msg.Message = "empty" env = GetBiblePassage(env) diff --git a/pkg/app/search_test.go b/pkg/app/search_test.go index 99bed93..d8a71e7 100644 --- a/pkg/app/search_test.go +++ b/pkg/app/search_test.go @@ -1,7 +1,6 @@ package app import ( - "encoding/json" "net/http" "net/http/httptest" "strings" @@ -12,33 +11,14 @@ import ( ) func TestGetBibleSearch(t *testing.T) { - // Mock server - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req QueryRequest - json.NewDecoder(r.Body).Decode(&req) - - // Check if words contains "error" - for _, word := range req.Query.Words { - if word == "error" { - http.Error(w, "Error", http.StatusInternalServerError) - return - } - if word == "empty" { - json.NewEncoder(w).Encode(WordSearchResponse{}) - return - } - } - - resp := WordSearchResponse{ - {Verse: "Found 1:1", URL: "http://found1"}, - } - json.NewEncoder(w).Encode(resp) - })) + handler := newMockApiHandler() + ts := httptest.NewServer(handler) defer ts.Close() - defer setEnv("BIBLE_API_URL", ts.URL)() - t.Run("Success", func(t *testing.T) { + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + var env def.SessionData env.Msg.Message = "Found" conf := utils.UserConfig{Version: "NIV"} @@ -55,6 +35,16 @@ func TestGetBibleSearch(t *testing.T) { }) t.Run("Empty", func(t *testing.T) { + handler.wordSearchResponse = WordSearchResponse{} + defer func() { + handler.wordSearchResponse = WordSearchResponse{ + {Verse: "Found 1:1", URL: "http://found1"}, + } + }() + + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + var env def.SessionData env.Msg.Message = "empty" @@ -66,6 +56,12 @@ func TestGetBibleSearch(t *testing.T) { }) t.Run("Error", func(t *testing.T) { + handler.statusCode = http.StatusInternalServerError + defer func() { handler.statusCode = http.StatusOK }() + + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + var env def.SessionData env.Msg.Message = "error" diff --git a/pkg/app/test_utils.go b/pkg/app/test_utils.go new file mode 100644 index 0000000..ae887eb --- /dev/null +++ b/pkg/app/test_utils.go @@ -0,0 +1,76 @@ +package app + +import ( + "encoding/json" + "net/http" + "os" +) + +// setEnv is a helper function to temporarily set an environment variable and return a function to restore it. +func setEnv(key, value string) func() { + originalValue, isSet := os.LookupEnv(key) + os.Setenv(key, value) + return func() { + if isSet { + os.Setenv(key, originalValue) + } else { + os.Unsetenv(key) + } + } +} + +// mockApiHandler is a flexible handler for the mock server. +type mockApiHandler struct { + verseResponse VerseResponse + wordSearchResponse WordSearchResponse + oQueryResponse OQueryResponse + statusCode int + rawResponse string +} + +// newMockApiHandler creates a new mockApiHandler with default success responses. +func newMockApiHandler() *mockApiHandler { + return &mockApiHandler{ + statusCode: http.StatusOK, + verseResponse: VerseResponse{ + Verse: "In the beginning God created the heavens and the earth.
", + }, + wordSearchResponse: WordSearchResponse{ + {Verse: "Found 1:1", URL: "http://found1"}, + }, + oQueryResponse: OQueryResponse{ + Text: "Answer text", + References: []SearchResult{ + {Verse: "Ref 1:1", URL: "http://ref1"}, + }, + }, + } +} + +// ServeHTTP handles the incoming requests and sends the configured response. +func (h *mockApiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(h.statusCode) + + if h.rawResponse != "" { + w.Write([]byte(h.rawResponse)) + return + } + + var req QueryRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + if len(req.Query.Words) > 0 { + json.NewEncoder(w).Encode(h.wordSearchResponse) + return + } + + if req.Query.Prompt != "" { + json.NewEncoder(w).Encode(h.oQueryResponse) + return + } + + json.NewEncoder(w).Encode(h.verseResponse) +} diff --git a/pkg/app/tms_test.go b/pkg/app/tms_test.go index a4044e9..384710f 100644 --- a/pkg/app/tms_test.go +++ b/pkg/app/tms_test.go @@ -1,6 +1,7 @@ package app import ( + "net/http/httptest" "strings" "testing" @@ -155,6 +156,13 @@ func TestIdentifyQuery(t *testing.T) { } func TestGetRandomTMSVerse(t *testing.T) { + handler := newMockApiHandler() + ts := httptest.NewServer(handler) + defer ts.Close() + + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + var env def.SessionData var conf utils.UserConfig conf.Version = "NIV" @@ -171,6 +179,13 @@ func TestGetRandomTMSVerse(t *testing.T) { } func TestGetTMSVerse(t *testing.T) { + handler := newMockApiHandler() + ts := httptest.NewServer(handler) + defer ts.Close() + + defer setEnv("BIBLE_API_URL", ts.URL)() + ResetAPIConfigCache() + var env def.SessionData var conf utils.UserConfig conf.Version = "NIV" diff --git a/pkg/secrets/secrets.go b/pkg/secrets/secrets.go new file mode 100644 index 0000000..417bd5e --- /dev/null +++ b/pkg/secrets/secrets.go @@ -0,0 +1,83 @@ +package secrets + +import ( + "context" + "fmt" + "log" + "os" + + secretmanager "cloud.google.com/go/secretmanager/apiv1" + "cloud.google.com/go/secretmanager/apiv1/secretmanagerpb" + "github.com/joho/godotenv" +) + +func init() { + LoadAndLog() +} + +// LoadAndLog loads environment variables from a .env file (if present) and logs +// the status of the GCLOUD_PROJECT_ID. This function is called automatically on package initialization. +// It is also exported to allow for re-loading in test environments. +func LoadAndLog() { + // godotenv.Overload will read your .env file and set the environment variables. + // It will OVERWRITE any existing environment variables. + err := godotenv.Overload() + if err != nil { + log.Println("No .env file found, continuing with environment variables") + } + + // Log the status of the GCLOUD_PROJECT_ID for debugging purposes. + if projectID, ok := os.LookupEnv("GCLOUD_PROJECT_ID"); ok { + log.Printf("GCLOUD_PROJECT_ID is set: %s", projectID) + } else { + log.Println("GCLOUD_PROJECT_ID is not set. Google Secret Manager will not be used.") + } +} + +// Get retrieves a secret. It follows a specific order of precedence: +// 1. Google Secret Manager (if GCLOUD_PROJECT_ID is set) +// 2. Environment variables (which includes those loaded from a .env file) +// +// If the secret is not found in any of these locations, it returns an error. +func Get(secretName string) (string, error) { + // Attempt to get the secret from Google Secret Manager first. + projectID, isCloudRun := os.LookupEnv("GCLOUD_PROJECT_ID") + if isCloudRun && projectID != "" { + secretValue, err := getFromSecretManager(projectID, secretName) + if err == nil { + log.Printf("Loaded '%s' from Secret Manager", secretName) + return secretValue, nil + } + log.Printf("Could not fetch '%s' from Secret Manager, falling back to environment variables: %v", secretName, err) + } + + // Fallback to environment variables. + if value, ok := os.LookupEnv(secretName); ok { + log.Printf("Loaded '%s' from .env file or environment", secretName) + return value, nil + } + + return "", fmt.Errorf("secret '%s' not found in Secret Manager, .env file, or environment variables", secretName) +} + +// getFromSecretManager fetches a secret from Google Secret Manager. +func getFromSecretManager(projectID, secretName string) (string, error) { + ctx := context.Background() + client, err := secretmanager.NewClient(ctx) + if err != nil { + return "", fmt.Errorf("failed to create secret manager client: %v", err) + } + defer client.Close() + + name := fmt.Sprintf("projects/%s/secrets/%s/versions/latest", projectID, secretName) + req := &secretmanagerpb.AccessSecretVersionRequest{ + Name: name, + } + + result, err := client.AccessSecretVersion(ctx, req) + if err != nil { + return "", fmt.Errorf("failed to access secret version: %v", err) + } + + return string(result.Payload.Data), nil +} diff --git a/pkg/secrets/secrets_test.go b/pkg/secrets/secrets_test.go new file mode 100644 index 0000000..bceffb5 --- /dev/null +++ b/pkg/secrets/secrets_test.go @@ -0,0 +1,109 @@ +package secrets + +import ( + "fmt" + "os" + "testing" +) + +// setupTestEnvFile creates a temporary .env file for testing. +func setupTestEnvFile(t *testing.T, content string) func() { + t.Helper() + tmpfile, err := os.Create(".env") + if err != nil { + t.Fatalf("Failed to create temporary .env file: %v", err) + } + if _, err := tmpfile.Write([]byte(content)); err != nil { + t.Fatalf("Failed to write to temporary .env file: %v", err) + } + if err := tmpfile.Close(); err != nil { + t.Fatalf("Failed to close temporary .env file: %v", err) + } + + // The cleanup function to be returned and deferred by the caller. + return func() { + os.Remove(tmpfile.Name()) + // Reset environment variables changed by godotenv.Overload() + os.Unsetenv("FROM_DOTENV") + os.Unsetenv("OVERLOAD_VAR") + } +} + +func TestGetSecret(t *testing.T) { + // Sub-test 1: Test loading from an environment variable. + t.Run("from environment variable", func(t *testing.T) { + const envVarName = "FROM_ENV" + const expectedValue = "env_value" + t.Setenv(envVarName, expectedValue) + + // Re-run the loading logic to pick up the new env var for the test. + LoadAndLog() + + val, err := Get(envVarName) + if err != nil { + t.Errorf("Get() returned an error: %v", err) + } + if val != expectedValue { + t.Errorf("Expected to get '%s', but got '%s'", expectedValue, val) + } + }) + + // Sub-test 2: Test loading from a .env file. + t.Run("from .env file", func(t *testing.T) { + const secretName = "FROM_DOTENV" + const expectedValue = "dotenv_value" + cleanup := setupTestEnvFile(t, fmt.Sprintf("%s=%s", secretName, expectedValue)) + defer cleanup() + + // Re-run the loading logic to load the .env file. + LoadAndLog() + + val, err := Get(secretName) + if err != nil { + t.Errorf("Get() returned an error: %v", err) + } + if val != expectedValue { + t.Errorf("Expected to get '%s' from .env, but got '%s'", expectedValue, val) + } + }) + + // Sub-test 3: Test that .env file takes precedence over environment variables. + t.Run(".env overloads environment variable", func(t *testing.T) { + const varName = "OVERLOAD_VAR" + const envValue = "from_shell" + const dotenvValue = "from_dotenv_file" + + // Set the environment variable first. + t.Setenv(varName, envValue) + + // Create the .env file with the same variable but a different value. + cleanup := setupTestEnvFile(t, fmt.Sprintf("%s=%s", varName, dotenvValue)) + defer cleanup() + + // Re-run the loading logic. godotenv.Overload should prioritize the .env file. + LoadAndLog() + + val, err := Get(varName) + if err != nil { + t.Errorf("Get() returned an error: %v", err) + } + if val != dotenvValue { + t.Errorf("Expected value from .env ('%s'), but got value from shell ('%s')", dotenvValue, val) + } + }) + + // Sub-test 4: Test error when secret is not found. + t.Run("secret not found", func(t *testing.T) { + const nonExistentSecret = "THIS_SECRET_SHOULD_NOT_EXIST" + // Ensure the variable is not set in the environment. + os.Unsetenv(nonExistentSecret) + + // Re-run the loading logic. + LoadAndLog() + + _, err := Get(nonExistentSecret) + if err == nil { + t.Error("Expected an error when getting a non-existent secret, but got nil") + } + }) +} diff --git a/pkg/utils/database.go b/pkg/utils/database.go index 0a78d3b..c5a23e3 100644 --- a/pkg/utils/database.go +++ b/pkg/utils/database.go @@ -7,11 +7,11 @@ import ( "context" "encoding/json" "log" - "os" "sync" "cloud.google.com/go/datastore" "github.com/julwrites/BotPlatform/pkg/def" + "github.com/julwrites/ScriptureBot/pkg/secrets" ) var ( @@ -19,24 +19,15 @@ var ( cachedDatabaseIDOnce sync.Once ) -func GetDatabaseID(project string) string { +func GetDatabaseID() string { cachedDatabaseIDOnce.Do(func() { - // 1. Env Var - if id := os.Getenv("USER_DATABASE_ID"); id != "" { + id, err := secrets.Get("USER_DATABASE_ID") + if err == nil && id != "" { cachedDatabaseID = id return } - // 2. Secret Manager - if id, err := GetSecret(project, "USER_DATABASE_ID"); err == nil && id != "" { - cachedDatabaseID = id - return - } else if err != nil { - log.Printf("Failed to fetch USER_DATABASE_ID from Secret Manager: %v", err) - } - - // 3. Default - log.Println("Warning: USER_DATABASE_ID not found, defaulting to 'scripturebot-users'") + log.Printf("Warning: USER_DATABASE_ID not found, defaulting to 'scripturebot-users'. Error: %v", err) cachedDatabaseID = "scripturebot-users" }) return cachedDatabaseID @@ -49,7 +40,7 @@ type UserConfig struct { } func OpenClient(ctx *context.Context, project string) *datastore.Client { - dbID := GetDatabaseID(project) + dbID := GetDatabaseID() client, err := datastore.NewClientWithDatabase(*ctx, project, dbID) if err != nil { log.Printf("Failed to create Firestore client: %v", err) diff --git a/pkg/utils/secrets.go b/pkg/utils/secrets.go deleted file mode 100644 index 195a3b1..0000000 --- a/pkg/utils/secrets.go +++ /dev/null @@ -1,38 +0,0 @@ -package utils - -import ( - "context" - "fmt" - - secretmanager "cloud.google.com/go/secretmanager/apiv1" - "cloud.google.com/go/secretmanager/apiv1/secretmanagerpb" -) - -// GetSecret fetches a secret payload from Google Secret Manager. -// secretName should be the name of the secret (e.g., "BIBLE_API_KEY"). -// The function constructs the full resource name: projects/{projectID}/secrets/{secretName}/versions/latest -func GetSecret(projectID, secretName string) (string, error) { - ctx := context.Background() - client, err := secretmanager.NewClient(ctx) - if err != nil { - return "", fmt.Errorf("failed to create secret manager client: %v", err) - } - defer client.Close() - - // Build the request. - // We use "latest" to fetch the most recent version of the secret. - name := fmt.Sprintf("projects/%s/secrets/%s/versions/latest", projectID, secretName) - - req := &secretmanagerpb.AccessSecretVersionRequest{ - Name: name, - } - - // Call the API. - result, err := client.AccessSecretVersion(ctx, req) - if err != nil { - return "", fmt.Errorf("failed to access secret version: %v", err) - } - - // Return the secret payload. - return string(result.Payload.Data), nil -} diff --git a/pkg/utils/test_utils.go b/pkg/utils/test_utils.go deleted file mode 100644 index 67abfe2..0000000 --- a/pkg/utils/test_utils.go +++ /dev/null @@ -1,18 +0,0 @@ -package utils - -import ( - "os" -) - -// SetEnv sets an environment variable and returns a function to restore it. -func SetEnv(key, value string) func() { - originalValue, exists := os.LookupEnv(key) - os.Setenv(key, value) - return func() { - if exists { - os.Setenv(key, originalValue) - } else { - os.Unsetenv(key) - } - } -}