diff --git a/README.md b/README.md index 3584171..473400e 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,18 @@ Command line utility to interact with your local and remote docker compose sites. +## Why sitectl vs Docker Context? + +While [Docker's native context feature](https://docs.docker.com/engine/manage-resources/contexts/) handles basic daemon connections, `sitectl` is purpose-built for Docker Compose projects and adds: + +- **Enhanced remote operations**: SFTP file operations (read env files, upload/download), sudo support, and helpful SSH error messages +- **Container utilities**: Resolve service names to containers, extract secrets/env vars to better support `exec` operations inside containers, get container IPs within Docker networks +- **Plugin architecture**: Extend `sitectl` for project-specific needs (e.g. islandora, drupal, etc.) +- **Service management**: Enable/disable services in docker-compose.yml with automatic cleanup of orphaned resources and Drupal configuration +- **Compose-first design**: Set the equivalent of `DOCKER_HOST`, `COMPOSE_PROJECT_NAME`, `COMPOSE_FILE`, `COMPOSE_ENV_FILES` automatically based on sitectl context settings + - See [Docker's documentation](https://docs.docker.com/compose/how-tos/environment-variables/envvars/#configuration-details) for what these environment variables do + + ## Attribution - The `config` commands for setting contexts were heavily inspired by `kubectl` diff --git a/cmd/root.go b/cmd/root.go index a351bc5..95dbc05 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -64,14 +64,8 @@ func init() { ll = "INFO" } - apiURL := os.Getenv("LIBOPS_API_URL") - if apiURL == "" { - apiURL = "https://api.libops.io" - } - RootCmd.PersistentFlags().String("context", c, "The sitectl context to use. See sitectl config --help for more info") RootCmd.PersistentFlags().String("log-level", ll, "The logging level for the command") - RootCmd.PersistentFlags().String("api-url", apiURL, "Base URL of the libops API") RootCmd.PersistentFlags().String("format", "table", `Format output using a custom template: 'table': Print output in table format with column headers (default) 'table TEMPLATE': Print output in table format using the given Go template diff --git a/cmd/sequelace.go b/cmd/sequelace.go new file mode 100644 index 0000000..d821440 --- /dev/null +++ b/cmd/sequelace.go @@ -0,0 +1,58 @@ +package cmd + +import ( + "fmt" + "log/slog" + "os/exec" + "runtime" + + "github.com/libops/sitectl/pkg/config" + "github.com/libops/sitectl/pkg/docker" + + "github.com/spf13/cobra" +) + +var sequelAceCmd = &cobra.Command{ + Use: "sequelace", + Short: "Connect to your MySQL/Mariadb database using Sequel Ace (Mac OS only)", + RunE: func(cmd *cobra.Command, args []string) error { + if runtime.GOOS != "darwin" { + return fmt.Errorf("sequelace is only supported on mac OS") + } + + f := cmd.Flags() + context, err := config.CurrentContext(f) + if err != nil { + return err + } + + sequelAcePath, err := f.GetString("sequel-ace-path") + if err != nil { + return err + } + + mysql, ssh, err := docker.GetDatabaseUris(context) + if err != nil { + return err + } + slog.Debug("uris", "mysql", mysql, "ssh", ssh) + cmdArgs := []string{ + fmt.Sprintf("%s?%s", mysql, ssh), + "-a", + sequelAcePath, + } + openCmd := exec.Command("open", cmdArgs...) + if err := openCmd.Run(); err != nil { + slog.Error("Could not open sequelace.") + return err + } + + return nil + }, +} + +func init() { + RootCmd.AddCommand(sequelAceCmd) + + sequelAceCmd.Flags().String("sequel-ace-path", "/Applications/Sequel Ace.app/Contents/MacOS/Sequel Ace", "Full path to your Sequel Ace app") +} diff --git a/go.mod b/go.mod index 05e255b..9997ee1 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,9 @@ module github.com/libops/sitectl go 1.25.3 require ( - connectrpc.com/connect v1.19.1 github.com/docker/docker v28.5.2+incompatible github.com/joho/godotenv v1.5.1 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 - github.com/libops/api/proto v0.0.1 github.com/pkg/sftp v1.13.10 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 @@ -30,8 +28,6 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/google/gnostic v0.7.1 // indirect - github.com/google/gnostic-models v0.7.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kr/fs v0.1.0 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect @@ -48,10 +44,8 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0 // indirect go.opentelemetry.io/otel/metric v1.39.0 // indirect go.opentelemetry.io/otel/trace v1.39.0 // indirect - go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/net v0.48.0 // indirect golang.org/x/sys v0.39.0 // indirect - golang.org/x/text v0.32.0 // indirect golang.org/x/time v0.14.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 // indirect diff --git a/go.sum b/go.sum index 7e328ca..45a7373 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= -connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= @@ -32,12 +30,6 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/gnostic v0.7.1 h1:t5Kc7j/8kYr8t2u11rykRrPPovlEMG4+xdc/SpekATs= -github.com/google/gnostic v0.7.1/go.mod h1:KSw6sxnxEBFM8jLPfJd46xZP+yQcfE8XkiqfZx5zR28= -github.com/google/gnostic-models v0.7.1 h1:SisTfuFKJSKM5CPZkffwi6coztzzeYUhc3v4yxLWH8c= -github.com/google/gnostic-models v0.7.1/go.mod h1:whL5G0m6dmc5cPxKc5bdKdEN3UjI7OUGxBlw57miDrQ= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -56,8 +48,6 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/libops/api/proto v0.0.1 h1:DCpWMOJK2vUsXk9r+VhXivy2SVdsj8XNL46MwAplW54= -github.com/libops/api/proto v0.0.1/go.mod h1:0acmYutF3M5smZ3Za8LHSWbC+ccOxfWmTtzwslJ1fMk= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw= @@ -110,7 +100,6 @@ go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6 go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4= go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4= -go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= @@ -124,8 +113,6 @@ golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 h1:7LRqPCEdE4TP4/9psdaB7F2nhZFfBiGJomA5sojLWdU= google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 h1:2I6GHUeJ/4shcDpoUlLs/2WPnhg7yJwvXtqcMJt9liA= diff --git a/internal/utils/helper.go b/internal/utils/helper.go deleted file mode 100644 index fb34f24..0000000 --- a/internal/utils/helper.go +++ /dev/null @@ -1,69 +0,0 @@ -package utils - -import ( - "fmt" - "log/slog" - "os" - "os/exec" - "runtime" - "strings" - - "github.com/spf13/cobra" -) - -func ExitOnError(err error) { - slog.Error(err.Error()) - os.Exit(1) -} - -// open a URL from the terminal -func OpenURL(url string) error { - var cmd *exec.Cmd - switch runtime.GOOS { - case "windows": - cmd = exec.Command("cmd", "/c", "start", url) - case "darwin": - cmd = exec.Command("open", url) - case "linux": - cmd = exec.Command("xdg-open", url) - default: - return fmt.Errorf("unknown runtime command to open URL") - } - - return cmd.Start() -} - -// for cobra commands that allow arbitrary args to facilitate passing flags to other commands -// strip out sitectl's context flag from the args if it was passed -func GetContextFromArgs(cmd *cobra.Command, args []string) ([]string, string, error) { - siteCtx, err := cmd.Root().PersistentFlags().GetString("context") - if err != nil { - return nil, "", err - } - - // remove --context flag from the args if it exists - // and set it as the default context if it was passed as a flag - filteredArgs := []string{} - skipNext := false - for _, arg := range args { - if arg == "--context" { - skipNext = true - continue - } - if strings.HasPrefix(arg, "--context=") { - components := strings.Split(arg, "=") - siteCtx = components[1] - continue - } - if skipNext { - siteCtx = arg - skipNext = false - continue - } - filteredArgs = append(filteredArgs, arg) - } - - siteCtx = strings.Trim(siteCtx, `" `) - - return filteredArgs, siteCtx, nil -} diff --git a/pkg/api/client.go b/pkg/api/client.go deleted file mode 100644 index 728dc7a..0000000 --- a/pkg/api/client.go +++ /dev/null @@ -1,165 +0,0 @@ -package api - -import ( - "context" - "fmt" - "net/http" - "os" - "path/filepath" - "strings" - - "github.com/libops/api/proto/libops/v1/libopsv1connect" - "github.com/libops/sitectl/pkg/auth" -) - -// LibopsAPIClient holds all the service clients -type LibopsAPIClient struct { - OrganizationService libopsv1connect.OrganizationServiceClient - ProjectService libopsv1connect.ProjectServiceClient - SiteService libopsv1connect.SiteServiceClient - AccountService libopsv1connect.AccountServiceClient - - // Members - MemberService libopsv1connect.MemberServiceClient - ProjectMemberService libopsv1connect.ProjectMemberServiceClient - SiteMemberService libopsv1connect.SiteMemberServiceClient - - // Firewall - FirewallService libopsv1connect.FirewallServiceClient - ProjectFirewallService libopsv1connect.ProjectFirewallServiceClient - SiteFirewallService libopsv1connect.SiteFirewallServiceClient - - // Secrets - OrganizationSecretService libopsv1connect.OrganizationSecretServiceClient - ProjectSecretService libopsv1connect.ProjectSecretServiceClient - SiteSecretService libopsv1connect.SiteSecretServiceClient -} - -// authTransport is an http.RoundTripper that adds an Authorization header to requests -// and handles automatic token refreshing. -type authTransport struct { - apiBaseURL string - next http.RoundTripper -} - -func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Check for API key first - apiKey, err := loadAPIKey() - if err == nil && apiKey != "" { - // Use API key authentication - req.Header.Set("Authorization", "Bearer "+apiKey) - return t.next.RoundTrip(req) - } - - // Fall back to OAuth tokens - tokens, err := auth.LoadTokens() - if err != nil { - // If we can't load tokens, just proceed without auth (likely to fail) or return error? - // Let's return error as we expect to be authenticated. - return nil, fmt.Errorf("failed to load tokens: %w", err) - } - - // Add Authorization header - req.Header.Set("Authorization", "Bearer "+tokens.IDToken) - - resp, err := t.next.RoundTrip(req) - if err != nil { - return nil, err - } - - // If we get a 401, the token is invalid - user needs to re-login - if resp.StatusCode == http.StatusUnauthorized { - _ = auth.ClearTokens() - } - - return resp, nil -} - -// loadAPIKey loads the API key from ~/.sitectl/key -func loadAPIKey() (string, error) { - homeDir := os.Getenv("HOME") - if homeDir == "" { - return "", fmt.Errorf("HOME environment variable not set") - } - - keyPath := filepath.Join(homeDir, ".sitectl", "key") - data, err := os.ReadFile(keyPath) - if err != nil { - return "", err - } - - return strings.TrimSpace(string(data)), nil -} - -// NewLibopsAPIClient creates and returns a new LibopsAPIClient instance. -// It initializes all necessary service clients with authentication. -func NewLibopsAPIClient(ctx context.Context, apiBaseURL string) (*LibopsAPIClient, error) { - // Check for API key first - apiKey, err := loadAPIKey() - if err == nil && apiKey != "" { - // API key found, skip token checks - authenticatedClient := &http.Client{ - Transport: &authTransport{ - apiBaseURL: apiBaseURL, - next: http.DefaultTransport, - }, - } - - return &LibopsAPIClient{ - OrganizationService: libopsv1connect.NewOrganizationServiceClient(authenticatedClient, apiBaseURL), - ProjectService: libopsv1connect.NewProjectServiceClient(authenticatedClient, apiBaseURL), - SiteService: libopsv1connect.NewSiteServiceClient(authenticatedClient, apiBaseURL), - AccountService: libopsv1connect.NewAccountServiceClient(authenticatedClient, apiBaseURL), - - MemberService: libopsv1connect.NewMemberServiceClient(authenticatedClient, apiBaseURL), - ProjectMemberService: libopsv1connect.NewProjectMemberServiceClient(authenticatedClient, apiBaseURL), - SiteMemberService: libopsv1connect.NewSiteMemberServiceClient(authenticatedClient, apiBaseURL), - - FirewallService: libopsv1connect.NewFirewallServiceClient(authenticatedClient, apiBaseURL), - ProjectFirewallService: libopsv1connect.NewProjectFirewallServiceClient(authenticatedClient, apiBaseURL), - SiteFirewallService: libopsv1connect.NewSiteFirewallServiceClient(authenticatedClient, apiBaseURL), - - OrganizationSecretService: libopsv1connect.NewOrganizationSecretServiceClient(authenticatedClient, apiBaseURL), - ProjectSecretService: libopsv1connect.NewProjectSecretServiceClient(authenticatedClient, apiBaseURL), - SiteSecretService: libopsv1connect.NewSiteSecretServiceClient(authenticatedClient, apiBaseURL), - }, nil - } - - // Fall back to OAuth tokens - tokens, err := auth.LoadTokens() - if err != nil { - return nil, fmt.Errorf("failed to load authentication tokens: %w", err) - } - - // Check if token is expired - if tokens.IsTokenExpired() { - _ = auth.ClearTokens() - return nil, fmt.Errorf("authentication token expired, please run 'sitectl login' to re-authenticate") - } - - authenticatedClient := &http.Client{ - Transport: &authTransport{ - apiBaseURL: apiBaseURL, - next: http.DefaultTransport, - }, - } - - return &LibopsAPIClient{ - OrganizationService: libopsv1connect.NewOrganizationServiceClient(authenticatedClient, apiBaseURL), - ProjectService: libopsv1connect.NewProjectServiceClient(authenticatedClient, apiBaseURL), - SiteService: libopsv1connect.NewSiteServiceClient(authenticatedClient, apiBaseURL), - AccountService: libopsv1connect.NewAccountServiceClient(authenticatedClient, apiBaseURL), - - MemberService: libopsv1connect.NewMemberServiceClient(authenticatedClient, apiBaseURL), - ProjectMemberService: libopsv1connect.NewProjectMemberServiceClient(authenticatedClient, apiBaseURL), - SiteMemberService: libopsv1connect.NewSiteMemberServiceClient(authenticatedClient, apiBaseURL), - - FirewallService: libopsv1connect.NewFirewallServiceClient(authenticatedClient, apiBaseURL), - ProjectFirewallService: libopsv1connect.NewProjectFirewallServiceClient(authenticatedClient, apiBaseURL), - SiteFirewallService: libopsv1connect.NewSiteFirewallServiceClient(authenticatedClient, apiBaseURL), - - OrganizationSecretService: libopsv1connect.NewOrganizationSecretServiceClient(authenticatedClient, apiBaseURL), - ProjectSecretService: libopsv1connect.NewProjectSecretServiceClient(authenticatedClient, apiBaseURL), - SiteSecretService: libopsv1connect.NewSiteSecretServiceClient(authenticatedClient, apiBaseURL), - }, nil -} diff --git a/pkg/auth/client.go b/pkg/auth/client.go deleted file mode 100644 index dece7a8..0000000 --- a/pkg/auth/client.go +++ /dev/null @@ -1,247 +0,0 @@ -package auth - -import ( - "context" - "crypto/rand" - "encoding/base64" - "fmt" - "log/slog" - "net" - "net/http" - "os/exec" - "runtime" - "time" -) - -// AuthClient handles unified browser-based authentication. -type AuthClient struct { - apiBaseURL string -} - -// callbackResult holds the result of the OAuth callback. -type callbackResult struct { - tokens *TokenResponse - err error -} - -// NewAuthClient creates a new authentication client. -func NewAuthClient(apiBaseURL string) *AuthClient { - return &AuthClient{ - apiBaseURL: apiBaseURL, - } -} - -// Login opens the browser to the API's login page and waits for the callback. -func (c *AuthClient) Login(ctx context.Context) (*TokenResponse, error) { - // Start a local HTTP server on a random available port - listener, err := net.Listen("tcp", "localhost:0") - if err != nil { - return nil, fmt.Errorf("failed to start local server: %w", err) - } - defer listener.Close() - - port := listener.Addr().(*net.TCPAddr).Port - slog.Debug("Started local callback server", "port", port) - - // Generate random state for CSRF protection - state, err := generateRandomState() - if err != nil { - return nil, fmt.Errorf("failed to generate state: %w", err) - } - - // Build the login URL that points to the API's login page - // The API will show both Google and userpass options - // Pass redirect_uri so the API knows where to send the user after authentication - redirectURI := fmt.Sprintf("http://localhost:%d/callback", port) - loginURL := fmt.Sprintf("%s/login?redirect_uri=%s&state=%s", c.apiBaseURL, redirectURI, state) - - // Create a channel to receive the callback result - resultChan := make(chan callbackResult, 1) - - // Set up the HTTP server with callback handler - mux := http.NewServeMux() - mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { - c.handleCallback(w, r, state, resultChan) - }) - - server := &http.Server{ - Handler: mux, - } - - // Start the server in a goroutine - go func() { - if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { - slog.Error("Server error", "err", err) - } - }() - defer func() { - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := server.Shutdown(shutdownCtx); err != nil { - slog.Error("Failed to shutdown server", "err", err) - } - }() - - // Open browser to the login page - if err := openBrowser(loginURL); err != nil { - fmt.Printf("Failed to open browser automatically. Please visit:\n%s\n", loginURL) - } - - // Wait for callback or timeout - select { - case result := <-resultChan: - if result.err != nil { - return nil, result.err - } - return result.tokens, nil - case <-time.After(5 * time.Minute): - return nil, fmt.Errorf("authentication timed out after 5 minutes") - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// handleCallback processes the callback from the API after authentication. -func (c *AuthClient) handleCallback(w http.ResponseWriter, r *http.Request, expectedState string, resultChan chan<- callbackResult) { - state := r.URL.Query().Get("state") - errorParam := r.URL.Query().Get("error") - errorDesc := r.URL.Query().Get("error_description") - - if errorParam != "" { - resultChan <- callbackResult{ - err: fmt.Errorf("authentication error: %s - %s", errorParam, errorDesc), - } - http.Error(w, fmt.Sprintf("Authentication failed: %s", errorDesc), http.StatusBadRequest) - return - } - - if state != expectedState { - resultChan <- callbackResult{ - err: fmt.Errorf("invalid state parameter"), - } - http.Error(w, "Invalid state", http.StatusBadRequest) - return - } - - // Extract tokens from cookies set by the API's /auth/callback endpoint - var accessToken, idToken string - expiresIn := 3600 - - for _, cookie := range r.Cookies() { - switch cookie.Name { - case "vault_token": - accessToken = cookie.Value - if cookie.MaxAge > 0 { - expiresIn = cookie.MaxAge - } - case "id_token": - idToken = cookie.Value - } - } - - // If tokens aren't in cookies, try query parameters (alternative approach) - if idToken == "" { - idToken = r.URL.Query().Get("id_token") - } - if accessToken == "" { - accessToken = r.URL.Query().Get("access_token") - } - - if idToken == "" { - resultChan <- callbackResult{ - err: fmt.Errorf("authentication completed but no tokens received"), - } - http.Error(w, "No tokens received", http.StatusBadRequest) - return - } - - // Calculate expiry date - expiryDate := time.Now().Unix() + int64(expiresIn) - - tokens := &TokenResponse{ - AccessToken: accessToken, - IDToken: idToken, - TokenType: "Bearer", - ExpiryDate: expiryDate, - Scope: "openid email profile", - } - - // Send success page to browser - w.Header().Set("Content-Type", "text/html") - successHTML := ` - - - - Authentication Successful - - - -
-

Authentication Successful!

-

You can close this window and return to your terminal.

-
- - - -` - if _, err := w.Write([]byte(successHTML)); err != nil { - slog.Error("Failed to write response", "err", err) - } - - resultChan <- callbackResult{tokens: tokens} -} - -// generateRandomState generates a cryptographically secure random state parameter. -func generateRandomState() (string, error) { - bytes := make([]byte, 32) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - return base64.URLEncoding.EncodeToString(bytes)[:32], nil -} - -// openBrowser opens the specified URL in the default browser. -func openBrowser(url string) error { - var cmd *exec.Cmd - - switch runtime.GOOS { - case "darwin": - cmd = exec.Command("open", url) - case "linux": - cmd = exec.Command("xdg-open", url) - case "windows": - cmd = exec.Command("cmd", "/c", "start", url) - default: - return fmt.Errorf("unsupported platform: %s", runtime.GOOS) - } - - return cmd.Start() -} diff --git a/pkg/auth/token.go b/pkg/auth/token.go deleted file mode 100644 index aed26c9..0000000 --- a/pkg/auth/token.go +++ /dev/null @@ -1,97 +0,0 @@ -package auth - -import ( - "encoding/json" - "fmt" - "os" - "path/filepath" - "time" -) - -// TokenResponse represents the OAuth token response stored locally. -type TokenResponse struct { - AccessToken string `json:"access_token,omitempty"` - IDToken string `json:"id_token"` - TokenType string `json:"token_type"` - ExpiryDate int64 `json:"expiry_date"` - Scope string `json:"scope"` -} - -// TokenFilePath returns the path to the OAuth token file. -func TokenFilePath() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("unable to detect home directory: %w", err) - } - - baseDir := filepath.Join(homeDir, ".sitectl") - if _, err := os.Stat(baseDir); os.IsNotExist(err) { - if err := os.Mkdir(baseDir, 0700); err != nil { - return "", fmt.Errorf("unable to create ~/.sitectl directory: %w", err) - } - } - - return filepath.Join(baseDir, "oauth.json"), nil -} - -// SaveTokens saves OAuth tokens to disk with restricted permissions. -func SaveTokens(tokens *TokenResponse) error { - tokenPath, err := TokenFilePath() - if err != nil { - return err - } - - data, err := json.MarshalIndent(tokens, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal tokens: %w", err) - } - - // Write with restrictive permissions (0600 = rw-------) - if err := os.WriteFile(tokenPath, data, 0600); err != nil { - return fmt.Errorf("failed to write token file: %w", err) - } - - return nil -} - -// LoadTokens loads OAuth tokens from disk. -func LoadTokens() (*TokenResponse, error) { - tokenPath, err := TokenFilePath() - if err != nil { - return nil, err - } - - data, err := os.ReadFile(tokenPath) - if err != nil { - if os.IsNotExist(err) { - return nil, fmt.Errorf("not authenticated: run 'sitectl login' first") - } - return nil, fmt.Errorf("failed to read token file: %w", err) - } - - var tokens TokenResponse - if err := json.Unmarshal(data, &tokens); err != nil { - return nil, fmt.Errorf("failed to parse token file: %w", err) - } - - return &tokens, nil -} - -// IsTokenExpired checks if the token has expired. -func (t *TokenResponse) IsTokenExpired() bool { - return time.Now().Unix() >= t.ExpiryDate -} - -// ClearTokens removes the token file from disk. -func ClearTokens() error { - tokenPath, err := TokenFilePath() - if err != nil { - return err - } - - if err := os.Remove(tokenPath); err != nil && !os.IsNotExist(err) { - return fmt.Errorf("failed to remove token file: %w", err) - } - - return nil -} diff --git a/pkg/auth/token_test.go b/pkg/auth/token_test.go deleted file mode 100644 index 464c998..0000000 --- a/pkg/auth/token_test.go +++ /dev/null @@ -1,195 +0,0 @@ -package auth - -import ( - "os" - "path/filepath" - "testing" - "time" -) - -func TestTokenFilePath(t *testing.T) { - path, err := TokenFilePath() - if err != nil { - t.Fatalf("TokenFilePath() failed: %v", err) - } - - if path == "" { - t.Error("TokenFilePath() returned empty path") - } - - if !filepath.IsAbs(path) { - t.Errorf("TokenFilePath() returned non-absolute path: %s", path) - } - - expectedSuffix := filepath.Join(".sitectl", "oauth.json") - if !filepath.IsAbs(path) || filepath.Base(filepath.Dir(path)) != ".sitectl" { - t.Errorf("TokenFilePath() = %s, should contain %s", path, expectedSuffix) - } -} - -func TestSaveAndLoadTokens(t *testing.T) { - // Create a temporary directory for testing - tempDir := t.TempDir() - originalHome := os.Getenv("HOME") - os.Setenv("HOME", tempDir) - defer os.Setenv("HOME", originalHome) - - tokens := &TokenResponse{ - AccessToken: "test_access_token", - IDToken: "test_id_token", - TokenType: "Bearer", - ExpiryDate: time.Now().Unix() + 3600, - Scope: "openid email profile", - } - - // Test saving tokens - err := SaveTokens(tokens) - if err != nil { - t.Fatalf("SaveTokens() failed: %v", err) - } - - // Verify file was created with correct permissions - tokenPath, _ := TokenFilePath() - info, err := os.Stat(tokenPath) - if err != nil { - t.Fatalf("Token file not created: %v", err) - } - - // Check file permissions (should be 0600) - expectedPerms := os.FileMode(0600) - if info.Mode().Perm() != expectedPerms { - t.Errorf("Token file permissions = %o, want %o", info.Mode().Perm(), expectedPerms) - } - - // Test loading tokens - loadedTokens, err := LoadTokens() - if err != nil { - t.Fatalf("LoadTokens() failed: %v", err) - } - - // Verify loaded tokens match saved tokens - if loadedTokens.AccessToken != tokens.AccessToken { - t.Errorf("AccessToken = %s, want %s", loadedTokens.AccessToken, tokens.AccessToken) - } - if loadedTokens.IDToken != tokens.IDToken { - t.Errorf("IDToken = %s, want %s", loadedTokens.IDToken, tokens.IDToken) - } - if loadedTokens.TokenType != tokens.TokenType { - t.Errorf("TokenType = %s, want %s", loadedTokens.TokenType, tokens.TokenType) - } - if loadedTokens.ExpiryDate != tokens.ExpiryDate { - t.Errorf("ExpiryDate = %d, want %d", loadedTokens.ExpiryDate, tokens.ExpiryDate) - } - if loadedTokens.Scope != tokens.Scope { - t.Errorf("Scope = %s, want %s", loadedTokens.Scope, tokens.Scope) - } -} - -func TestLoadTokens_NotFound(t *testing.T) { - // Create a temporary directory for testing - tempDir := t.TempDir() - originalHome := os.Getenv("HOME") - os.Setenv("HOME", tempDir) - defer os.Setenv("HOME", originalHome) - - // Ensure no token file exists - tokenPath, _ := TokenFilePath() - os.Remove(tokenPath) - - _, err := LoadTokens() - if err == nil { - t.Error("LoadTokens() should fail when token file doesn't exist") - } -} - -func TestIsTokenExpired(t *testing.T) { - tests := []struct { - name string - expiryDate int64 - want bool - }{ - { - name: "token expired", - expiryDate: time.Now().Unix() - 3600, // 1 hour ago - want: true, - }, - { - name: "token valid", - expiryDate: time.Now().Unix() + 3600, // 1 hour from now - want: false, - }, - { - name: "token expires now", - expiryDate: time.Now().Unix(), - want: true, // Should be considered expired if exactly at expiry time - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - token := &TokenResponse{ - ExpiryDate: tt.expiryDate, - } - if got := token.IsTokenExpired(); got != tt.want { - t.Errorf("IsTokenExpired() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestClearTokens(t *testing.T) { - // Create a temporary directory for testing - tempDir := t.TempDir() - originalHome := os.Getenv("HOME") - os.Setenv("HOME", tempDir) - defer os.Setenv("HOME", originalHome) - - // Create a token file - tokens := &TokenResponse{ - AccessToken: "test_token", - IDToken: "test_id_token", - TokenType: "Bearer", - ExpiryDate: time.Now().Unix() + 3600, - Scope: "openid", - } - - err := SaveTokens(tokens) - if err != nil { - t.Fatalf("SaveTokens() failed: %v", err) - } - - // Verify file exists - tokenPath, _ := TokenFilePath() - if _, err := os.Stat(tokenPath); os.IsNotExist(err) { - t.Fatal("Token file was not created") - } - - // Clear tokens - err = ClearTokens() - if err != nil { - t.Fatalf("ClearTokens() failed: %v", err) - } - - // Verify file was removed - if _, err := os.Stat(tokenPath); !os.IsNotExist(err) { - t.Error("Token file still exists after ClearTokens()") - } -} - -func TestClearTokens_NotFound(t *testing.T) { - // Create a temporary directory for testing - tempDir := t.TempDir() - originalHome := os.Getenv("HOME") - os.Setenv("HOME", tempDir) - defer os.Setenv("HOME", originalHome) - - // Ensure no token file exists - tokenPath, _ := TokenFilePath() - os.Remove(tokenPath) - - // Clearing non-existent tokens should not error - err := ClearTokens() - if err != nil { - t.Errorf("ClearTokens() failed when no tokens exist: %v", err) - } -} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go deleted file mode 100644 index 1ef5624..0000000 --- a/pkg/cache/cache.go +++ /dev/null @@ -1,207 +0,0 @@ -package cache - -import ( - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" - "os" - "path/filepath" - "time" -) - -const ( - cacheDir = ".sitectl/cache" - cacheValidity = 12 * time.Hour -) - -// CacheKey represents a structured cache key -type CacheKey struct { - ResourceType string // "organization", "project", "site" - Operation string // "list", "get" - ParentType string // optional: parent resource type - ParentID string // optional: parent resource ID - SubResource string // optional: "firewall", "members", "secrets" - ResourceID string // optional: specific resource ID -} - -// GetCachePath returns the file path for a cache key -func (k CacheKey) GetCachePath() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } - - parts := []string{ - homeDir, - cacheDir, - k.Operation, - } - - if k.ParentType != "" && k.ParentID != "" { - // Cached sub-resource: ~/.sitectl/cache/list/organization//firewall/list.resp - parts = append(parts, k.ParentType, k.ParentID) - if k.SubResource != "" { - parts = append(parts, k.SubResource) - } - } else { - // Cached resource: ~/.sitectl/cache/list/organization/list.resp - parts = append(parts, k.ResourceType) - } - - // Determine filename - var filename string - if k.ResourceID != "" { - filename = fmt.Sprintf("%s.resp", k.ResourceID) - } else { - filename = "list.resp" - } - - parts = append(parts, filename) - return filepath.Join(parts...), nil -} - -// Get retrieves a cached value if it exists and is not expired -func Get(key CacheKey, target interface{}) (bool, error) { - path, err := key.GetCachePath() - if err != nil { - return false, err - } - - // Check if file exists - info, err := os.Stat(path) - if os.IsNotExist(err) { - return false, nil - } - if err != nil { - return false, err - } - - // Check if cache is expired - if time.Since(info.ModTime()) > cacheValidity { - // Cache expired, delete it - os.Remove(path) - return false, nil - } - - // Read cache file - data, err := os.ReadFile(path) - if err != nil { - return false, err - } - - // Unmarshal into target - if err := json.Unmarshal(data, target); err != nil { - // Cache corrupted, delete it - os.Remove(path) - return false, nil - } - - return true, nil -} - -// Set stores a value in the cache -func Set(key CacheKey, value interface{}) error { - path, err := key.GetCachePath() - if err != nil { - return err - } - - // Create directory structure - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0755); err != nil { - return fmt.Errorf("failed to create cache directory: %w", err) - } - - // Marshal value - data, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("failed to marshal cache data: %w", err) - } - - // Write to file - if err := os.WriteFile(path, data, 0600); err != nil { - return fmt.Errorf("failed to write cache file: %w", err) - } - - return nil -} - -// Invalidate removes a cached value -func Invalidate(key CacheKey) error { - path, err := key.GetCachePath() - if err != nil { - return err - } - - // Remove file if it exists - if err := os.Remove(path); err != nil && !os.IsNotExist(err) { - return err - } - - return nil -} - -// InvalidatePattern removes all cache entries matching a pattern -// This is useful for invalidating all caches related to a resource -func InvalidatePattern(resourceType, resourceID string) error { - homeDir, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("failed to get home directory: %w", err) - } - - // Invalidate list cache for this resource type - listKey := CacheKey{ - ResourceType: resourceType, - Operation: "list", - } - err = Invalidate(listKey) - if err != nil { - return fmt.Errorf("failed to invalidate cache: %w", err) - } - - // If we have a specific resource ID, invalidate its get cache and all sub-resources - if resourceID != "" { - getKey := CacheKey{ - ResourceType: resourceType, - Operation: "get", - ResourceID: resourceID, - } - err = Invalidate(getKey) - if err != nil { - return fmt.Errorf("failed to invalidate cache: %w", err) - } - - // Invalidate all sub-resource caches - subResources := []string{"firewall", "members", "secrets"} - for _, subResource := range subResources { - subCacheDir := filepath.Join(homeDir, cacheDir, "list", resourceType, resourceID, subResource) - if err := os.RemoveAll(subCacheDir); err != nil && !os.IsNotExist(err) { - return err - } - } - } - - return nil -} - -// Clear removes all cached data -func Clear() error { - homeDir, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("failed to get home directory: %w", err) - } - - cachePath := filepath.Join(homeDir, cacheDir) - if err := os.RemoveAll(cachePath); err != nil && !os.IsNotExist(err) { - return err - } - - return nil -} - -// HashID creates a short hash for cache keys (for very long IDs) -func HashID(id string) string { - hash := sha256.Sum256([]byte(id)) - return hex.EncodeToString(hash[:8]) // Use first 8 bytes (16 hex chars) -} diff --git a/pkg/config/context.go b/pkg/config/context.go index fa677a4..18ecaaf 100644 --- a/pkg/config/context.go +++ b/pkg/config/context.go @@ -29,19 +29,23 @@ const ( ) type Context struct { - Name string `yaml:"name"` - DockerHostType ContextType `mapstructure:"type" yaml:"type"` - DockerSocket string `yaml:"docker-socket"` - ProjectName string `yaml:"project-name"` - Profile string `yaml:"profile"` - ProjectDir string `yaml:"project-dir"` - SSHUser string `yaml:"ssh-user"` - SSHHostname string `yaml:"ssh-hostname,omitempty"` - SSHPort uint `yaml:"ssh-port,omitempty"` - SSHKeyPath string `yaml:"ssh-key,omitempty"` - EnvFile []string `yaml:"env-file"` - RunSudo bool `yaml:"sudo"` - UriMap map[string]string `yaml:"uriMap"` + Name string `yaml:"name"` + DockerHostType ContextType `mapstructure:"type" yaml:"type"` + DockerSocket string `yaml:"docker-socket"` + ProjectName string `yaml:"project-name"` + ProjectDir string `yaml:"project-dir"` + SSHUser string `yaml:"ssh-user"` + SSHHostname string `yaml:"ssh-hostname,omitempty"` + SSHPort uint `yaml:"ssh-port,omitempty"` + SSHKeyPath string `yaml:"ssh-key,omitempty"` + EnvFile []string `yaml:"env-file"` + RunSudo bool `yaml:"sudo"` + + // Database connection configuration + DatabaseService string `yaml:"database-service,omitempty"` + DatabaseUser string `yaml:"database-user,omitempty"` + DatabasePasswordSecret string `yaml:"database-password-secret,omitempty"` + DatabaseName string `yaml:"database-name,omitempty"` ReadSmallFileFunc func(filename string) string `yaml:"-"` } @@ -97,6 +101,20 @@ func SaveContext(ctx *Context, setDefault bool) error { return err } + // Set database defaults if not provided + if ctx.DatabaseService == "" { + ctx.DatabaseService = "mariadb" + } + if ctx.DatabaseUser == "" { + ctx.DatabaseUser = "root" + } + if ctx.DatabasePasswordSecret == "" { + ctx.DatabasePasswordSecret = "DB_ROOT_PASSWORD" + } + if ctx.DatabaseName == "" { + ctx.DatabaseName = "drupal_default" + } + updated := false for i, c := range cfg.Contexts { if c.Name == ctx.Name { @@ -437,3 +455,22 @@ func (c *Context) UploadFile(source, destination string) error { return nil } + +// GetSshUri returns an SSH connection URI +func (c *Context) GetSshUri() string { + if c.DockerHostType == ContextLocal { + return "" + } + + sshPort := c.SSHPort + if sshPort == 0 { + sshPort = 22 + } + + sshParams := fmt.Sprintf("sshHost=%s&sshUser=%s&sshPort=%d", c.SSHHostname, c.SSHUser, sshPort) + if c.SSHKeyPath != "" { + sshParams += fmt.Sprintf("&sshKeyFile=%s", c.SSHKeyPath) + } + + return sshParams +} diff --git a/pkg/config/context_test.go b/pkg/config/context_test.go index 6093da9..9a3bd00 100644 --- a/pkg/config/context_test.go +++ b/pkg/config/context_test.go @@ -77,14 +77,17 @@ func contextsEqual(a, b Context) bool { a.DockerHostType == b.DockerHostType && a.DockerSocket == b.DockerSocket && a.ProjectName == b.ProjectName && - a.Profile == b.Profile && a.ProjectDir == b.ProjectDir && a.SSHUser == b.SSHUser && a.SSHHostname == b.SSHHostname && a.SSHPort == b.SSHPort && a.SSHKeyPath == b.SSHKeyPath && len(a.EnvFile) == len(b.EnvFile) && - a.RunSudo == b.RunSudo + a.RunSudo == b.RunSudo && + a.DatabaseService == b.DatabaseService && + a.DatabaseUser == b.DatabaseUser && + a.DatabasePasswordSecret == b.DatabasePasswordSecret && + a.DatabaseName == b.DatabaseName } func TestContextString(t *testing.T) { @@ -308,7 +311,6 @@ func TestVerifyRemoteInputExistingConfig(t *testing.T) { SSHUser: "bar", SSHPort: 123, SSHKeyPath: "/assuming/we/already/checked", - Profile: "prod", ProjectName: "baz", } cc := original @@ -322,3 +324,70 @@ func TestVerifyRemoteInputExistingConfig(t *testing.T) { t.Fatalf("expected context %+v, got %+v", original, cc) } } + +func TestGetSshUri(t *testing.T) { + tests := []struct { + name string + context Context + expected string + }{ + { + name: "local context returns empty string", + context: Context{ + DockerHostType: ContextLocal, + }, + expected: "", + }, + { + name: "remote context with default port", + context: Context{ + DockerHostType: ContextRemote, + SSHHostname: "example.com", + SSHUser: "testuser", + SSHPort: 0, // Should default to 22 + }, + expected: "sshHost=example.com&sshUser=testuser&sshPort=22", + }, + { + name: "remote context with custom port", + context: Context{ + DockerHostType: ContextRemote, + SSHHostname: "example.com", + SSHUser: "testuser", + SSHPort: 2222, + }, + expected: "sshHost=example.com&sshUser=testuser&sshPort=2222", + }, + { + name: "remote context with SSH key path", + context: Context{ + DockerHostType: ContextRemote, + SSHHostname: "example.com", + SSHUser: "testuser", + SSHPort: 22, + SSHKeyPath: "/home/user/.ssh/id_rsa", + }, + expected: "sshHost=example.com&sshUser=testuser&sshPort=22&sshKeyFile=/home/user/.ssh/id_rsa", + }, + { + name: "remote context without SSH key path", + context: Context{ + DockerHostType: ContextRemote, + SSHHostname: "server.example.com", + SSHUser: "admin", + SSHPort: 22, + SSHKeyPath: "", + }, + expected: "sshHost=server.example.com&sshUser=admin&sshPort=22", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.context.GetSshUri() + if result != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} diff --git a/pkg/config/utils.go b/pkg/config/utils.go index 348ef72..3c4ff6e 100644 --- a/pkg/config/utils.go +++ b/pkg/config/utils.go @@ -162,7 +162,10 @@ func SetCommandFlags(flags *pflag.FlagSet) { flags.String("ssh-key", "", "Path to SSH private key for remote context. e.g. "+key) flags.String("project-dir", "", "Path to docker compose project directory") flags.String("project-name", "docker-compose", "Name of the docker compose project") - flags.String("profile", "", "docker compose profile") flags.Bool("sudo", false, "for remote contexts, run docker commands as sudo") flags.StringSlice("env-file", []string{}, "when running remote docker commands, the --env-file paths to pass to docker compose") + flags.String("database-service", "mariadb", "Name of the database service in Docker Compose") + flags.String("database-user", "root", "Database user to connect as (e.g. root, admin)") + flags.String("database-password-secret", "DB_ROOT_PASSWORD", "Name of the docker compose secret containing the database password") + flags.String("database-name", "drupal_default", "Name of the database to connect to (e.g. drupal_default)") } diff --git a/pkg/config/utils_test.go b/pkg/config/utils_test.go index 8efdb84..ade1719 100644 --- a/pkg/config/utils_test.go +++ b/pkg/config/utils_test.go @@ -16,7 +16,6 @@ func TestLoadFromFlags(t *testing.T) { flags := pflag.NewFlagSet("test", pflag.ContinueOnError) flags.String("docker-socket", "/var/run/docker.sock", "Path to Docker socket") flags.String("type", "local", "Context type: local or remote") - flags.String("profile", "default", "Profile name") flags.String("ssh-hostname", "example.com", "SSH host for remote context") flags.Uint("ssh-port", 22, "port") flags.String("ssh-user", "user", "SSH user for remote context") @@ -25,12 +24,15 @@ func TestLoadFromFlags(t *testing.T) { flags.String("project-name", "foo", "Composer Project Name") flags.Bool("sudo", false, "Run commands on remote hosts as sudo") flags.StringSlice("env-file", []string{}, "path to env files to pass to docker compose") + flags.String("database-service", "mariadb", "Name of the database service in Docker Compose") + flags.String("database-user", "root", "Database user to connect as") + flags.String("database-password-secret", "DB_ROOT_PASSWORD", "Name of the secret containing the database password") + flags.String("database-name", "drupal_default", "Name of the database to connect to") // Define test arguments to override defaults. args := []string{ "--docker-socket", "/custom/docker.sock", "--type", "remote", - "--profile", "prod", "--ssh-hostname", "remote.example.com", "--ssh-port", "123", "--ssh-user", "remoteuser", @@ -58,9 +60,6 @@ func TestLoadFromFlags(t *testing.T) { if ctx.DockerHostType != "remote" { t.Errorf("Expected type 'remote', got %q", ctx.DockerHostType) } - if ctx.Profile != "prod" { - t.Errorf("Expected profile 'prod', got %q", ctx.Profile) - } if ctx.SSHHostname != "remote.example.com" { t.Errorf("Expected ssh-host 'remote.example.com', got %q", ctx.SSHHostname) } diff --git a/pkg/docker/docker.go b/pkg/docker/docker.go index cd5c993..ec9f84a 100644 --- a/pkg/docker/docker.go +++ b/pkg/docker/docker.go @@ -125,15 +125,12 @@ func (d *DockerClient) GetServiceIp(ctx context.Context, c *config.Context, cont return network.IPAddress, nil } -func (d *DockerClient) GetContainerName(c *config.Context, service string, neverPrefixProfile bool) (string, error) { +func (d *DockerClient) GetContainerName(c *config.Context, service string) (string, error) { ctx := context.Background() // Define the filters based on the Docker Compose labels. filterArgs := filters.NewArgs() filterArgs.Add("label", "com.docker.compose.project="+c.ProjectName) - if c.Profile != "" && !neverPrefixProfile { - service = service + "-" + c.Profile - } filterArgs.Add("label", "com.docker.compose.service="+service) slog.Debug("Querying docker", "filters", filterArgs) @@ -286,3 +283,34 @@ func (d *DockerClient) ExecInteractive(ctx context.Context, containerID string, Tty: true, }) } + +// GetDatabaseUris constructs MySQL and SSH connection URIs for database tools like Sequel Ace +// Returns: mysqlURI, sshURI, error +func GetDatabaseUris(c *config.Context) (string, string, error) { + ctx := context.Background() + + // Get Docker client + dockerCli, err := GetDockerCli(c) + if err != nil { + return "", "", fmt.Errorf("failed to get docker client: %w", err) + } + defer dockerCli.Close() + + // Get the database container name + containerName, err := dockerCli.GetContainerName(c, c.DatabaseService) + if err != nil { + return "", "", fmt.Errorf("failed to get %s container: %w", c.DatabaseService, err) + } + if containerName == "" { + return "", "", fmt.Errorf("%s container not found", c.DatabaseService) + } + + // Get database password from container environment + password, err := GetSecret(ctx, dockerCli.CLI, c, containerName, c.DatabasePasswordSecret) + if err != nil { + return "", "", fmt.Errorf("failed to get database password from %s: %w", c.DatabasePasswordSecret, err) + } + + mysqlURI := fmt.Sprintf("mysql://%s:%s@127.0.0.1:3306/%s", c.DatabaseUser, password, c.DatabaseName) + return mysqlURI, c.GetSshUri(), nil +} diff --git a/internal/utils/helper_test.go b/pkg/helpers/helper_test.go similarity index 99% rename from internal/utils/helper_test.go rename to pkg/helpers/helper_test.go index 7af744c..d7086dc 100644 --- a/internal/utils/helper_test.go +++ b/pkg/helpers/helper_test.go @@ -1,4 +1,4 @@ -package utils +package helpers import ( "reflect" diff --git a/pkg/plugin/sdk.go b/pkg/plugin/sdk.go index 4cd7a99..78f2653 100644 --- a/pkg/plugin/sdk.go +++ b/pkg/plugin/sdk.go @@ -83,9 +83,6 @@ func (s *SDK) setupLogging(cmd *cobra.Command) error { if s.RootCmd.PersistentFlags().Lookup("context") != nil { s.Config.Context, _ = cmd.Flags().GetString("context") } - if s.RootCmd.PersistentFlags().Lookup("api-url") != nil { - s.Config.APIUrl, _ = cmd.Flags().GetString("api-url") - } if s.RootCmd.PersistentFlags().Lookup("format") != nil { s.Config.Format, _ = cmd.Flags().GetString("format") } @@ -120,15 +117,9 @@ func (s *SDK) Execute() { } } -// AddLibopsFlags adds common libops-specific flags -func (s *SDK) AddLibopsFlags(currentContext string) { - apiURL := os.Getenv("LIBOPS_API_URL") - if apiURL == "" { - apiURL = "https://api.libops.io" - } - +// AddGlobalFlags adds common libops-specific flags +func (s *SDK) AddGlobalFlags(currentContext string) { s.RootCmd.PersistentFlags().String("context", currentContext, "The sitectl context to use. See sitectl config --help for more info") - s.RootCmd.PersistentFlags().String("api-url", apiURL, "Base URL of the libops API") } // GetMetadataCommand returns a command that displays plugin metadata diff --git a/pkg/resources/resources.go b/pkg/resources/resources.go deleted file mode 100644 index 7e49c69..0000000 --- a/pkg/resources/resources.go +++ /dev/null @@ -1,249 +0,0 @@ -package resources - -import ( - "context" - "fmt" - "log/slog" - - "connectrpc.com/connect" - - libopsv1 "github.com/libops/api/proto/libops/v1" - "github.com/libops/api/proto/libops/v1/common" - "github.com/libops/sitectl/pkg/api" - "github.com/libops/sitectl/pkg/cache" -) - -// Type aliases for cleaner code -type Organization = common.FolderConfig -type Project = common.ProjectConfig -type Site = common.SiteConfig - -// ListOrganizations returns all organizations, using cache when available -func ListOrganizations(ctx context.Context, apiBaseURL string, useCache bool) ([]*Organization, error) { - cacheKey := cache.CacheKey{ - ResourceType: "organization", - Operation: "list", - } - - // Try cache first - if useCache { - var cached []*Organization - found, err := cache.Get(cacheKey, &cached) - if err != nil { - slog.Warn("Failed to read cache", "err", err) - } else if found { - slog.Debug("Using cached organizations", "count", len(cached)) - return cached, nil - } - } - - // Fetch from API - client, err := api.NewLibopsAPIClient(ctx, apiBaseURL) - if err != nil { - return nil, err - } - - resp, err := client.OrganizationService.ListOrganizations(ctx, connect.NewRequest(&libopsv1.ListOrganizationsRequest{})) - if err != nil { - return nil, fmt.Errorf("failed to list organizations: %w", err) - } - - // Cache the result - if useCache { - if err := cache.Set(cacheKey, resp.Msg.Organizations); err != nil { - slog.Warn("Failed to cache organizations", "err", err) - } - } - - return resp.Msg.Organizations, nil -} - -// ListProjects returns all projects, using cache when available -func ListProjects(ctx context.Context, apiBaseURL string, useCache bool, orgID *string) ([]*Project, error) { - cacheKey := cache.CacheKey{ - ResourceType: "project", - Operation: "list", - } - - // Try cache first - if useCache { - var cached []*Project - found, err := cache.Get(cacheKey, &cached) - if err != nil { - slog.Warn("Failed to read cache", "err", err) - } else if found { - // Filter by org if needed - if orgID != nil && *orgID != "" { - filtered := make([]*Project, 0) - for _, p := range cached { - if p.OrganizationId == *orgID { - filtered = append(filtered, p) - } - } - slog.Debug("Using cached projects (filtered)", "count", len(filtered)) - return filtered, nil - } - slog.Debug("Using cached projects", "count", len(cached)) - return cached, nil - } - } - - // Fetch from API - client, err := api.NewLibopsAPIClient(ctx, apiBaseURL) - if err != nil { - return nil, err - } - - resp, err := client.ProjectService.ListProjects(ctx, connect.NewRequest(&libopsv1.ListProjectsRequest{ - OrganizationId: orgID, - })) - if err != nil { - return nil, fmt.Errorf("failed to list projects: %w", err) - } - - // Cache the result (only if not filtered) - if useCache && (orgID == nil || *orgID == "") { - if err := cache.Set(cacheKey, resp.Msg.Projects); err != nil { - slog.Warn("Failed to cache projects", "err", err) - } - } - - return resp.Msg.Projects, nil -} - -// ListSites returns all sites, using cache when available -func ListSites(ctx context.Context, apiBaseURL string, useCache bool, orgID, projectID *string) ([]*Site, error) { - cacheKey := cache.CacheKey{ - ResourceType: "site", - Operation: "list", - } - - // Try cache first - if useCache { - var cached []*Site - found, err := cache.Get(cacheKey, &cached) - if err != nil { - slog.Warn("Failed to read cache", "err", err) - } else if found { - // Filter by org/project if needed - filtered := cached - if orgID != nil && *orgID != "" { - temp := make([]*Site, 0) - for _, s := range filtered { - if s.OrganizationId == *orgID { - temp = append(temp, s) - } - } - filtered = temp - } - if projectID != nil && *projectID != "" { - temp := make([]*Site, 0) - for _, s := range filtered { - if s.ProjectId == *projectID { - temp = append(temp, s) - } - } - filtered = temp - } - return filtered, nil - } - } - - // Fetch from API - client, err := api.NewLibopsAPIClient(ctx, apiBaseURL) - if err != nil { - return nil, err - } - - resp, err := client.SiteService.ListSites(ctx, connect.NewRequest(&libopsv1.ListSitesRequest{ - OrganizationId: orgID, - ProjectId: projectID, - })) - if err != nil { - return nil, fmt.Errorf("failed to list sites: %w", err) - } - - // Cache the result (only if not filtered) - if useCache && (orgID == nil || *orgID == "") && (projectID == nil || *projectID == "") { - if err := cache.Set(cacheKey, resp.Msg.Sites); err != nil { - slog.Warn("Failed to cache sites", "err", err) - } - } - - return resp.Msg.Sites, nil -} - -// GetOrganization returns a specific organization, using cache when available -func GetOrganization(ctx context.Context, apiBaseURL, orgID string, useCache bool) (*Organization, error) { - cacheKey := cache.CacheKey{ - ResourceType: "organization", - Operation: "get", - ResourceID: orgID, - } - - // Try cache first - if useCache { - var cached Organization - found, err := cache.Get(cacheKey, &cached) - if err != nil { - slog.Warn("Failed to read cache", "err", err) - } else if found { - slog.Debug("Using cached organization", "id", orgID) - return &cached, nil - } - } - - // Fetch from API - client, err := api.NewLibopsAPIClient(ctx, apiBaseURL) - if err != nil { - return nil, err - } - - resp, err := client.OrganizationService.GetOrganization(ctx, connect.NewRequest(&libopsv1.GetOrganizationRequest{ - OrganizationId: orgID, - })) - if err != nil { - return nil, fmt.Errorf("failed to get organization: %w", err) - } - - // The response returns a Folder which is our Organization type - org := resp.Msg.Folder - - // Cache the result - if useCache { - if err := cache.Set(cacheKey, org); err != nil { - slog.Warn("Failed to cache organization", "err", err) - } - } - - return org, nil -} - -// InvalidateOrganizationCache invalidates all caches related to an organization -func InvalidateOrganizationCache(orgID string) error { - return cache.InvalidatePattern("organization", orgID) -} - -// InvalidateProjectCache invalidates all caches related to a project -func InvalidateProjectCache(projectID string) error { - return cache.InvalidatePattern("project", projectID) -} - -// InvalidateSiteCache invalidates all caches related to a site -func InvalidateSiteCache(siteID string) error { - return cache.InvalidatePattern("site", siteID) -} - -// InvalidateAllResourceCaches invalidates all resource list caches -func InvalidateAllResourceCaches() error { - if err := cache.InvalidatePattern("organization", ""); err != nil { - return err - } - if err := cache.InvalidatePattern("project", ""); err != nil { - return err - } - if err := cache.InvalidatePattern("site", ""); err != nil { - return err - } - return nil -}