diff --git a/README.md b/README.md
index 3584171..473400e 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,18 @@
Command line utility to interact with your local and remote docker compose sites.
+## Why sitectl vs Docker Context?
+
+While [Docker's native context feature](https://docs.docker.com/engine/manage-resources/contexts/) handles basic daemon connections, `sitectl` is purpose-built for Docker Compose projects and adds:
+
+- **Enhanced remote operations**: SFTP file operations (read env files, upload/download), sudo support, and helpful SSH error messages
+- **Container utilities**: Resolve service names to containers, extract secrets/env vars to better support `exec` operations inside containers, get container IPs within Docker networks
+- **Plugin architecture**: Extend `sitectl` for project-specific needs (e.g. islandora, drupal, etc.)
+- **Service management**: Enable/disable services in docker-compose.yml with automatic cleanup of orphaned resources and Drupal configuration
+- **Compose-first design**: Set the equivalent of `DOCKER_HOST`, `COMPOSE_PROJECT_NAME`, `COMPOSE_FILE`, `COMPOSE_ENV_FILES` automatically based on sitectl context settings
+ - See [Docker's documentation](https://docs.docker.com/compose/how-tos/environment-variables/envvars/#configuration-details) for what these environment variables do
+
+
## Attribution
- The `config` commands for setting contexts were heavily inspired by `kubectl`
diff --git a/cmd/root.go b/cmd/root.go
index a351bc5..95dbc05 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -64,14 +64,8 @@ func init() {
ll = "INFO"
}
- apiURL := os.Getenv("LIBOPS_API_URL")
- if apiURL == "" {
- apiURL = "https://api.libops.io"
- }
-
RootCmd.PersistentFlags().String("context", c, "The sitectl context to use. See sitectl config --help for more info")
RootCmd.PersistentFlags().String("log-level", ll, "The logging level for the command")
- RootCmd.PersistentFlags().String("api-url", apiURL, "Base URL of the libops API")
RootCmd.PersistentFlags().String("format", "table", `Format output using a custom template:
'table': Print output in table format with column headers (default)
'table TEMPLATE': Print output in table format using the given Go template
diff --git a/cmd/sequelace.go b/cmd/sequelace.go
new file mode 100644
index 0000000..d821440
--- /dev/null
+++ b/cmd/sequelace.go
@@ -0,0 +1,58 @@
+package cmd
+
+import (
+ "fmt"
+ "log/slog"
+ "os/exec"
+ "runtime"
+
+ "github.com/libops/sitectl/pkg/config"
+ "github.com/libops/sitectl/pkg/docker"
+
+ "github.com/spf13/cobra"
+)
+
+var sequelAceCmd = &cobra.Command{
+ Use: "sequelace",
+ Short: "Connect to your MySQL/Mariadb database using Sequel Ace (Mac OS only)",
+ RunE: func(cmd *cobra.Command, args []string) error {
+ if runtime.GOOS != "darwin" {
+ return fmt.Errorf("sequelace is only supported on mac OS")
+ }
+
+ f := cmd.Flags()
+ context, err := config.CurrentContext(f)
+ if err != nil {
+ return err
+ }
+
+ sequelAcePath, err := f.GetString("sequel-ace-path")
+ if err != nil {
+ return err
+ }
+
+ mysql, ssh, err := docker.GetDatabaseUris(context)
+ if err != nil {
+ return err
+ }
+ slog.Debug("uris", "mysql", mysql, "ssh", ssh)
+ cmdArgs := []string{
+ fmt.Sprintf("%s?%s", mysql, ssh),
+ "-a",
+ sequelAcePath,
+ }
+ openCmd := exec.Command("open", cmdArgs...)
+ if err := openCmd.Run(); err != nil {
+ slog.Error("Could not open sequelace.")
+ return err
+ }
+
+ return nil
+ },
+}
+
+func init() {
+ RootCmd.AddCommand(sequelAceCmd)
+
+ sequelAceCmd.Flags().String("sequel-ace-path", "/Applications/Sequel Ace.app/Contents/MacOS/Sequel Ace", "Full path to your Sequel Ace app")
+}
diff --git a/go.mod b/go.mod
index 05e255b..9997ee1 100644
--- a/go.mod
+++ b/go.mod
@@ -3,11 +3,9 @@ module github.com/libops/sitectl
go 1.25.3
require (
- connectrpc.com/connect v1.19.1
github.com/docker/docker v28.5.2+incompatible
github.com/joho/godotenv v1.5.1
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51
- github.com/libops/api/proto v0.0.1
github.com/pkg/sftp v1.13.10
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.10
@@ -30,8 +28,6 @@ require (
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
- github.com/google/gnostic v0.7.1 // indirect
- github.com/google/gnostic-models v0.7.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
@@ -48,10 +44,8 @@ require (
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0 // indirect
go.opentelemetry.io/otel/metric v1.39.0 // indirect
go.opentelemetry.io/otel/trace v1.39.0 // indirect
- go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/sys v0.39.0 // indirect
- golang.org/x/text v0.32.0 // indirect
golang.org/x/time v0.14.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 // indirect
diff --git a/go.sum b/go.sum
index 7e328ca..45a7373 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,3 @@
-connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14=
-connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
@@ -32,12 +30,6 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
-github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
-github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
-github.com/google/gnostic v0.7.1 h1:t5Kc7j/8kYr8t2u11rykRrPPovlEMG4+xdc/SpekATs=
-github.com/google/gnostic v0.7.1/go.mod h1:KSw6sxnxEBFM8jLPfJd46xZP+yQcfE8XkiqfZx5zR28=
-github.com/google/gnostic-models v0.7.1 h1:SisTfuFKJSKM5CPZkffwi6coztzzeYUhc3v4yxLWH8c=
-github.com/google/gnostic-models v0.7.1/go.mod h1:whL5G0m6dmc5cPxKc5bdKdEN3UjI7OUGxBlw57miDrQ=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -56,8 +48,6 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/libops/api/proto v0.0.1 h1:DCpWMOJK2vUsXk9r+VhXivy2SVdsj8XNL46MwAplW54=
-github.com/libops/api/proto v0.0.1/go.mod h1:0acmYutF3M5smZ3Za8LHSWbC+ccOxfWmTtzwslJ1fMk=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
@@ -110,7 +100,6 @@ go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6
go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4=
go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4=
-go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
@@ -124,8 +113,6 @@ golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
-gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
-gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 h1:7LRqPCEdE4TP4/9psdaB7F2nhZFfBiGJomA5sojLWdU=
google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 h1:2I6GHUeJ/4shcDpoUlLs/2WPnhg7yJwvXtqcMJt9liA=
diff --git a/internal/utils/helper.go b/internal/utils/helper.go
deleted file mode 100644
index fb34f24..0000000
--- a/internal/utils/helper.go
+++ /dev/null
@@ -1,69 +0,0 @@
-package utils
-
-import (
- "fmt"
- "log/slog"
- "os"
- "os/exec"
- "runtime"
- "strings"
-
- "github.com/spf13/cobra"
-)
-
-func ExitOnError(err error) {
- slog.Error(err.Error())
- os.Exit(1)
-}
-
-// open a URL from the terminal
-func OpenURL(url string) error {
- var cmd *exec.Cmd
- switch runtime.GOOS {
- case "windows":
- cmd = exec.Command("cmd", "/c", "start", url)
- case "darwin":
- cmd = exec.Command("open", url)
- case "linux":
- cmd = exec.Command("xdg-open", url)
- default:
- return fmt.Errorf("unknown runtime command to open URL")
- }
-
- return cmd.Start()
-}
-
-// for cobra commands that allow arbitrary args to facilitate passing flags to other commands
-// strip out sitectl's context flag from the args if it was passed
-func GetContextFromArgs(cmd *cobra.Command, args []string) ([]string, string, error) {
- siteCtx, err := cmd.Root().PersistentFlags().GetString("context")
- if err != nil {
- return nil, "", err
- }
-
- // remove --context flag from the args if it exists
- // and set it as the default context if it was passed as a flag
- filteredArgs := []string{}
- skipNext := false
- for _, arg := range args {
- if arg == "--context" {
- skipNext = true
- continue
- }
- if strings.HasPrefix(arg, "--context=") {
- components := strings.Split(arg, "=")
- siteCtx = components[1]
- continue
- }
- if skipNext {
- siteCtx = arg
- skipNext = false
- continue
- }
- filteredArgs = append(filteredArgs, arg)
- }
-
- siteCtx = strings.Trim(siteCtx, `" `)
-
- return filteredArgs, siteCtx, nil
-}
diff --git a/pkg/api/client.go b/pkg/api/client.go
deleted file mode 100644
index 728dc7a..0000000
--- a/pkg/api/client.go
+++ /dev/null
@@ -1,165 +0,0 @@
-package api
-
-import (
- "context"
- "fmt"
- "net/http"
- "os"
- "path/filepath"
- "strings"
-
- "github.com/libops/api/proto/libops/v1/libopsv1connect"
- "github.com/libops/sitectl/pkg/auth"
-)
-
-// LibopsAPIClient holds all the service clients
-type LibopsAPIClient struct {
- OrganizationService libopsv1connect.OrganizationServiceClient
- ProjectService libopsv1connect.ProjectServiceClient
- SiteService libopsv1connect.SiteServiceClient
- AccountService libopsv1connect.AccountServiceClient
-
- // Members
- MemberService libopsv1connect.MemberServiceClient
- ProjectMemberService libopsv1connect.ProjectMemberServiceClient
- SiteMemberService libopsv1connect.SiteMemberServiceClient
-
- // Firewall
- FirewallService libopsv1connect.FirewallServiceClient
- ProjectFirewallService libopsv1connect.ProjectFirewallServiceClient
- SiteFirewallService libopsv1connect.SiteFirewallServiceClient
-
- // Secrets
- OrganizationSecretService libopsv1connect.OrganizationSecretServiceClient
- ProjectSecretService libopsv1connect.ProjectSecretServiceClient
- SiteSecretService libopsv1connect.SiteSecretServiceClient
-}
-
-// authTransport is an http.RoundTripper that adds an Authorization header to requests
-// and handles automatic token refreshing.
-type authTransport struct {
- apiBaseURL string
- next http.RoundTripper
-}
-
-func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
- // Check for API key first
- apiKey, err := loadAPIKey()
- if err == nil && apiKey != "" {
- // Use API key authentication
- req.Header.Set("Authorization", "Bearer "+apiKey)
- return t.next.RoundTrip(req)
- }
-
- // Fall back to OAuth tokens
- tokens, err := auth.LoadTokens()
- if err != nil {
- // If we can't load tokens, just proceed without auth (likely to fail) or return error?
- // Let's return error as we expect to be authenticated.
- return nil, fmt.Errorf("failed to load tokens: %w", err)
- }
-
- // Add Authorization header
- req.Header.Set("Authorization", "Bearer "+tokens.IDToken)
-
- resp, err := t.next.RoundTrip(req)
- if err != nil {
- return nil, err
- }
-
- // If we get a 401, the token is invalid - user needs to re-login
- if resp.StatusCode == http.StatusUnauthorized {
- _ = auth.ClearTokens()
- }
-
- return resp, nil
-}
-
-// loadAPIKey loads the API key from ~/.sitectl/key
-func loadAPIKey() (string, error) {
- homeDir := os.Getenv("HOME")
- if homeDir == "" {
- return "", fmt.Errorf("HOME environment variable not set")
- }
-
- keyPath := filepath.Join(homeDir, ".sitectl", "key")
- data, err := os.ReadFile(keyPath)
- if err != nil {
- return "", err
- }
-
- return strings.TrimSpace(string(data)), nil
-}
-
-// NewLibopsAPIClient creates and returns a new LibopsAPIClient instance.
-// It initializes all necessary service clients with authentication.
-func NewLibopsAPIClient(ctx context.Context, apiBaseURL string) (*LibopsAPIClient, error) {
- // Check for API key first
- apiKey, err := loadAPIKey()
- if err == nil && apiKey != "" {
- // API key found, skip token checks
- authenticatedClient := &http.Client{
- Transport: &authTransport{
- apiBaseURL: apiBaseURL,
- next: http.DefaultTransport,
- },
- }
-
- return &LibopsAPIClient{
- OrganizationService: libopsv1connect.NewOrganizationServiceClient(authenticatedClient, apiBaseURL),
- ProjectService: libopsv1connect.NewProjectServiceClient(authenticatedClient, apiBaseURL),
- SiteService: libopsv1connect.NewSiteServiceClient(authenticatedClient, apiBaseURL),
- AccountService: libopsv1connect.NewAccountServiceClient(authenticatedClient, apiBaseURL),
-
- MemberService: libopsv1connect.NewMemberServiceClient(authenticatedClient, apiBaseURL),
- ProjectMemberService: libopsv1connect.NewProjectMemberServiceClient(authenticatedClient, apiBaseURL),
- SiteMemberService: libopsv1connect.NewSiteMemberServiceClient(authenticatedClient, apiBaseURL),
-
- FirewallService: libopsv1connect.NewFirewallServiceClient(authenticatedClient, apiBaseURL),
- ProjectFirewallService: libopsv1connect.NewProjectFirewallServiceClient(authenticatedClient, apiBaseURL),
- SiteFirewallService: libopsv1connect.NewSiteFirewallServiceClient(authenticatedClient, apiBaseURL),
-
- OrganizationSecretService: libopsv1connect.NewOrganizationSecretServiceClient(authenticatedClient, apiBaseURL),
- ProjectSecretService: libopsv1connect.NewProjectSecretServiceClient(authenticatedClient, apiBaseURL),
- SiteSecretService: libopsv1connect.NewSiteSecretServiceClient(authenticatedClient, apiBaseURL),
- }, nil
- }
-
- // Fall back to OAuth tokens
- tokens, err := auth.LoadTokens()
- if err != nil {
- return nil, fmt.Errorf("failed to load authentication tokens: %w", err)
- }
-
- // Check if token is expired
- if tokens.IsTokenExpired() {
- _ = auth.ClearTokens()
- return nil, fmt.Errorf("authentication token expired, please run 'sitectl login' to re-authenticate")
- }
-
- authenticatedClient := &http.Client{
- Transport: &authTransport{
- apiBaseURL: apiBaseURL,
- next: http.DefaultTransport,
- },
- }
-
- return &LibopsAPIClient{
- OrganizationService: libopsv1connect.NewOrganizationServiceClient(authenticatedClient, apiBaseURL),
- ProjectService: libopsv1connect.NewProjectServiceClient(authenticatedClient, apiBaseURL),
- SiteService: libopsv1connect.NewSiteServiceClient(authenticatedClient, apiBaseURL),
- AccountService: libopsv1connect.NewAccountServiceClient(authenticatedClient, apiBaseURL),
-
- MemberService: libopsv1connect.NewMemberServiceClient(authenticatedClient, apiBaseURL),
- ProjectMemberService: libopsv1connect.NewProjectMemberServiceClient(authenticatedClient, apiBaseURL),
- SiteMemberService: libopsv1connect.NewSiteMemberServiceClient(authenticatedClient, apiBaseURL),
-
- FirewallService: libopsv1connect.NewFirewallServiceClient(authenticatedClient, apiBaseURL),
- ProjectFirewallService: libopsv1connect.NewProjectFirewallServiceClient(authenticatedClient, apiBaseURL),
- SiteFirewallService: libopsv1connect.NewSiteFirewallServiceClient(authenticatedClient, apiBaseURL),
-
- OrganizationSecretService: libopsv1connect.NewOrganizationSecretServiceClient(authenticatedClient, apiBaseURL),
- ProjectSecretService: libopsv1connect.NewProjectSecretServiceClient(authenticatedClient, apiBaseURL),
- SiteSecretService: libopsv1connect.NewSiteSecretServiceClient(authenticatedClient, apiBaseURL),
- }, nil
-}
diff --git a/pkg/auth/client.go b/pkg/auth/client.go
deleted file mode 100644
index dece7a8..0000000
--- a/pkg/auth/client.go
+++ /dev/null
@@ -1,247 +0,0 @@
-package auth
-
-import (
- "context"
- "crypto/rand"
- "encoding/base64"
- "fmt"
- "log/slog"
- "net"
- "net/http"
- "os/exec"
- "runtime"
- "time"
-)
-
-// AuthClient handles unified browser-based authentication.
-type AuthClient struct {
- apiBaseURL string
-}
-
-// callbackResult holds the result of the OAuth callback.
-type callbackResult struct {
- tokens *TokenResponse
- err error
-}
-
-// NewAuthClient creates a new authentication client.
-func NewAuthClient(apiBaseURL string) *AuthClient {
- return &AuthClient{
- apiBaseURL: apiBaseURL,
- }
-}
-
-// Login opens the browser to the API's login page and waits for the callback.
-func (c *AuthClient) Login(ctx context.Context) (*TokenResponse, error) {
- // Start a local HTTP server on a random available port
- listener, err := net.Listen("tcp", "localhost:0")
- if err != nil {
- return nil, fmt.Errorf("failed to start local server: %w", err)
- }
- defer listener.Close()
-
- port := listener.Addr().(*net.TCPAddr).Port
- slog.Debug("Started local callback server", "port", port)
-
- // Generate random state for CSRF protection
- state, err := generateRandomState()
- if err != nil {
- return nil, fmt.Errorf("failed to generate state: %w", err)
- }
-
- // Build the login URL that points to the API's login page
- // The API will show both Google and userpass options
- // Pass redirect_uri so the API knows where to send the user after authentication
- redirectURI := fmt.Sprintf("http://localhost:%d/callback", port)
- loginURL := fmt.Sprintf("%s/login?redirect_uri=%s&state=%s", c.apiBaseURL, redirectURI, state)
-
- // Create a channel to receive the callback result
- resultChan := make(chan callbackResult, 1)
-
- // Set up the HTTP server with callback handler
- mux := http.NewServeMux()
- mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
- c.handleCallback(w, r, state, resultChan)
- })
-
- server := &http.Server{
- Handler: mux,
- }
-
- // Start the server in a goroutine
- go func() {
- if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
- slog.Error("Server error", "err", err)
- }
- }()
- defer func() {
- shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- if err := server.Shutdown(shutdownCtx); err != nil {
- slog.Error("Failed to shutdown server", "err", err)
- }
- }()
-
- // Open browser to the login page
- if err := openBrowser(loginURL); err != nil {
- fmt.Printf("Failed to open browser automatically. Please visit:\n%s\n", loginURL)
- }
-
- // Wait for callback or timeout
- select {
- case result := <-resultChan:
- if result.err != nil {
- return nil, result.err
- }
- return result.tokens, nil
- case <-time.After(5 * time.Minute):
- return nil, fmt.Errorf("authentication timed out after 5 minutes")
- case <-ctx.Done():
- return nil, ctx.Err()
- }
-}
-
-// handleCallback processes the callback from the API after authentication.
-func (c *AuthClient) handleCallback(w http.ResponseWriter, r *http.Request, expectedState string, resultChan chan<- callbackResult) {
- state := r.URL.Query().Get("state")
- errorParam := r.URL.Query().Get("error")
- errorDesc := r.URL.Query().Get("error_description")
-
- if errorParam != "" {
- resultChan <- callbackResult{
- err: fmt.Errorf("authentication error: %s - %s", errorParam, errorDesc),
- }
- http.Error(w, fmt.Sprintf("Authentication failed: %s", errorDesc), http.StatusBadRequest)
- return
- }
-
- if state != expectedState {
- resultChan <- callbackResult{
- err: fmt.Errorf("invalid state parameter"),
- }
- http.Error(w, "Invalid state", http.StatusBadRequest)
- return
- }
-
- // Extract tokens from cookies set by the API's /auth/callback endpoint
- var accessToken, idToken string
- expiresIn := 3600
-
- for _, cookie := range r.Cookies() {
- switch cookie.Name {
- case "vault_token":
- accessToken = cookie.Value
- if cookie.MaxAge > 0 {
- expiresIn = cookie.MaxAge
- }
- case "id_token":
- idToken = cookie.Value
- }
- }
-
- // If tokens aren't in cookies, try query parameters (alternative approach)
- if idToken == "" {
- idToken = r.URL.Query().Get("id_token")
- }
- if accessToken == "" {
- accessToken = r.URL.Query().Get("access_token")
- }
-
- if idToken == "" {
- resultChan <- callbackResult{
- err: fmt.Errorf("authentication completed but no tokens received"),
- }
- http.Error(w, "No tokens received", http.StatusBadRequest)
- return
- }
-
- // Calculate expiry date
- expiryDate := time.Now().Unix() + int64(expiresIn)
-
- tokens := &TokenResponse{
- AccessToken: accessToken,
- IDToken: idToken,
- TokenType: "Bearer",
- ExpiryDate: expiryDate,
- Scope: "openid email profile",
- }
-
- // Send success page to browser
- w.Header().Set("Content-Type", "text/html")
- successHTML := `
-
-
-
- Authentication Successful
-
-
-
-
-
Authentication Successful!
-
You can close this window and return to your terminal.
-
-
-
-
-`
- if _, err := w.Write([]byte(successHTML)); err != nil {
- slog.Error("Failed to write response", "err", err)
- }
-
- resultChan <- callbackResult{tokens: tokens}
-}
-
-// generateRandomState generates a cryptographically secure random state parameter.
-func generateRandomState() (string, error) {
- bytes := make([]byte, 32)
- if _, err := rand.Read(bytes); err != nil {
- return "", err
- }
- return base64.URLEncoding.EncodeToString(bytes)[:32], nil
-}
-
-// openBrowser opens the specified URL in the default browser.
-func openBrowser(url string) error {
- var cmd *exec.Cmd
-
- switch runtime.GOOS {
- case "darwin":
- cmd = exec.Command("open", url)
- case "linux":
- cmd = exec.Command("xdg-open", url)
- case "windows":
- cmd = exec.Command("cmd", "/c", "start", url)
- default:
- return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
- }
-
- return cmd.Start()
-}
diff --git a/pkg/auth/token.go b/pkg/auth/token.go
deleted file mode 100644
index aed26c9..0000000
--- a/pkg/auth/token.go
+++ /dev/null
@@ -1,97 +0,0 @@
-package auth
-
-import (
- "encoding/json"
- "fmt"
- "os"
- "path/filepath"
- "time"
-)
-
-// TokenResponse represents the OAuth token response stored locally.
-type TokenResponse struct {
- AccessToken string `json:"access_token,omitempty"`
- IDToken string `json:"id_token"`
- TokenType string `json:"token_type"`
- ExpiryDate int64 `json:"expiry_date"`
- Scope string `json:"scope"`
-}
-
-// TokenFilePath returns the path to the OAuth token file.
-func TokenFilePath() (string, error) {
- homeDir, err := os.UserHomeDir()
- if err != nil {
- return "", fmt.Errorf("unable to detect home directory: %w", err)
- }
-
- baseDir := filepath.Join(homeDir, ".sitectl")
- if _, err := os.Stat(baseDir); os.IsNotExist(err) {
- if err := os.Mkdir(baseDir, 0700); err != nil {
- return "", fmt.Errorf("unable to create ~/.sitectl directory: %w", err)
- }
- }
-
- return filepath.Join(baseDir, "oauth.json"), nil
-}
-
-// SaveTokens saves OAuth tokens to disk with restricted permissions.
-func SaveTokens(tokens *TokenResponse) error {
- tokenPath, err := TokenFilePath()
- if err != nil {
- return err
- }
-
- data, err := json.MarshalIndent(tokens, "", " ")
- if err != nil {
- return fmt.Errorf("failed to marshal tokens: %w", err)
- }
-
- // Write with restrictive permissions (0600 = rw-------)
- if err := os.WriteFile(tokenPath, data, 0600); err != nil {
- return fmt.Errorf("failed to write token file: %w", err)
- }
-
- return nil
-}
-
-// LoadTokens loads OAuth tokens from disk.
-func LoadTokens() (*TokenResponse, error) {
- tokenPath, err := TokenFilePath()
- if err != nil {
- return nil, err
- }
-
- data, err := os.ReadFile(tokenPath)
- if err != nil {
- if os.IsNotExist(err) {
- return nil, fmt.Errorf("not authenticated: run 'sitectl login' first")
- }
- return nil, fmt.Errorf("failed to read token file: %w", err)
- }
-
- var tokens TokenResponse
- if err := json.Unmarshal(data, &tokens); err != nil {
- return nil, fmt.Errorf("failed to parse token file: %w", err)
- }
-
- return &tokens, nil
-}
-
-// IsTokenExpired checks if the token has expired.
-func (t *TokenResponse) IsTokenExpired() bool {
- return time.Now().Unix() >= t.ExpiryDate
-}
-
-// ClearTokens removes the token file from disk.
-func ClearTokens() error {
- tokenPath, err := TokenFilePath()
- if err != nil {
- return err
- }
-
- if err := os.Remove(tokenPath); err != nil && !os.IsNotExist(err) {
- return fmt.Errorf("failed to remove token file: %w", err)
- }
-
- return nil
-}
diff --git a/pkg/auth/token_test.go b/pkg/auth/token_test.go
deleted file mode 100644
index 464c998..0000000
--- a/pkg/auth/token_test.go
+++ /dev/null
@@ -1,195 +0,0 @@
-package auth
-
-import (
- "os"
- "path/filepath"
- "testing"
- "time"
-)
-
-func TestTokenFilePath(t *testing.T) {
- path, err := TokenFilePath()
- if err != nil {
- t.Fatalf("TokenFilePath() failed: %v", err)
- }
-
- if path == "" {
- t.Error("TokenFilePath() returned empty path")
- }
-
- if !filepath.IsAbs(path) {
- t.Errorf("TokenFilePath() returned non-absolute path: %s", path)
- }
-
- expectedSuffix := filepath.Join(".sitectl", "oauth.json")
- if !filepath.IsAbs(path) || filepath.Base(filepath.Dir(path)) != ".sitectl" {
- t.Errorf("TokenFilePath() = %s, should contain %s", path, expectedSuffix)
- }
-}
-
-func TestSaveAndLoadTokens(t *testing.T) {
- // Create a temporary directory for testing
- tempDir := t.TempDir()
- originalHome := os.Getenv("HOME")
- os.Setenv("HOME", tempDir)
- defer os.Setenv("HOME", originalHome)
-
- tokens := &TokenResponse{
- AccessToken: "test_access_token",
- IDToken: "test_id_token",
- TokenType: "Bearer",
- ExpiryDate: time.Now().Unix() + 3600,
- Scope: "openid email profile",
- }
-
- // Test saving tokens
- err := SaveTokens(tokens)
- if err != nil {
- t.Fatalf("SaveTokens() failed: %v", err)
- }
-
- // Verify file was created with correct permissions
- tokenPath, _ := TokenFilePath()
- info, err := os.Stat(tokenPath)
- if err != nil {
- t.Fatalf("Token file not created: %v", err)
- }
-
- // Check file permissions (should be 0600)
- expectedPerms := os.FileMode(0600)
- if info.Mode().Perm() != expectedPerms {
- t.Errorf("Token file permissions = %o, want %o", info.Mode().Perm(), expectedPerms)
- }
-
- // Test loading tokens
- loadedTokens, err := LoadTokens()
- if err != nil {
- t.Fatalf("LoadTokens() failed: %v", err)
- }
-
- // Verify loaded tokens match saved tokens
- if loadedTokens.AccessToken != tokens.AccessToken {
- t.Errorf("AccessToken = %s, want %s", loadedTokens.AccessToken, tokens.AccessToken)
- }
- if loadedTokens.IDToken != tokens.IDToken {
- t.Errorf("IDToken = %s, want %s", loadedTokens.IDToken, tokens.IDToken)
- }
- if loadedTokens.TokenType != tokens.TokenType {
- t.Errorf("TokenType = %s, want %s", loadedTokens.TokenType, tokens.TokenType)
- }
- if loadedTokens.ExpiryDate != tokens.ExpiryDate {
- t.Errorf("ExpiryDate = %d, want %d", loadedTokens.ExpiryDate, tokens.ExpiryDate)
- }
- if loadedTokens.Scope != tokens.Scope {
- t.Errorf("Scope = %s, want %s", loadedTokens.Scope, tokens.Scope)
- }
-}
-
-func TestLoadTokens_NotFound(t *testing.T) {
- // Create a temporary directory for testing
- tempDir := t.TempDir()
- originalHome := os.Getenv("HOME")
- os.Setenv("HOME", tempDir)
- defer os.Setenv("HOME", originalHome)
-
- // Ensure no token file exists
- tokenPath, _ := TokenFilePath()
- os.Remove(tokenPath)
-
- _, err := LoadTokens()
- if err == nil {
- t.Error("LoadTokens() should fail when token file doesn't exist")
- }
-}
-
-func TestIsTokenExpired(t *testing.T) {
- tests := []struct {
- name string
- expiryDate int64
- want bool
- }{
- {
- name: "token expired",
- expiryDate: time.Now().Unix() - 3600, // 1 hour ago
- want: true,
- },
- {
- name: "token valid",
- expiryDate: time.Now().Unix() + 3600, // 1 hour from now
- want: false,
- },
- {
- name: "token expires now",
- expiryDate: time.Now().Unix(),
- want: true, // Should be considered expired if exactly at expiry time
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- token := &TokenResponse{
- ExpiryDate: tt.expiryDate,
- }
- if got := token.IsTokenExpired(); got != tt.want {
- t.Errorf("IsTokenExpired() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestClearTokens(t *testing.T) {
- // Create a temporary directory for testing
- tempDir := t.TempDir()
- originalHome := os.Getenv("HOME")
- os.Setenv("HOME", tempDir)
- defer os.Setenv("HOME", originalHome)
-
- // Create a token file
- tokens := &TokenResponse{
- AccessToken: "test_token",
- IDToken: "test_id_token",
- TokenType: "Bearer",
- ExpiryDate: time.Now().Unix() + 3600,
- Scope: "openid",
- }
-
- err := SaveTokens(tokens)
- if err != nil {
- t.Fatalf("SaveTokens() failed: %v", err)
- }
-
- // Verify file exists
- tokenPath, _ := TokenFilePath()
- if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
- t.Fatal("Token file was not created")
- }
-
- // Clear tokens
- err = ClearTokens()
- if err != nil {
- t.Fatalf("ClearTokens() failed: %v", err)
- }
-
- // Verify file was removed
- if _, err := os.Stat(tokenPath); !os.IsNotExist(err) {
- t.Error("Token file still exists after ClearTokens()")
- }
-}
-
-func TestClearTokens_NotFound(t *testing.T) {
- // Create a temporary directory for testing
- tempDir := t.TempDir()
- originalHome := os.Getenv("HOME")
- os.Setenv("HOME", tempDir)
- defer os.Setenv("HOME", originalHome)
-
- // Ensure no token file exists
- tokenPath, _ := TokenFilePath()
- os.Remove(tokenPath)
-
- // Clearing non-existent tokens should not error
- err := ClearTokens()
- if err != nil {
- t.Errorf("ClearTokens() failed when no tokens exist: %v", err)
- }
-}
diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go
deleted file mode 100644
index 1ef5624..0000000
--- a/pkg/cache/cache.go
+++ /dev/null
@@ -1,207 +0,0 @@
-package cache
-
-import (
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "fmt"
- "os"
- "path/filepath"
- "time"
-)
-
-const (
- cacheDir = ".sitectl/cache"
- cacheValidity = 12 * time.Hour
-)
-
-// CacheKey represents a structured cache key
-type CacheKey struct {
- ResourceType string // "organization", "project", "site"
- Operation string // "list", "get"
- ParentType string // optional: parent resource type
- ParentID string // optional: parent resource ID
- SubResource string // optional: "firewall", "members", "secrets"
- ResourceID string // optional: specific resource ID
-}
-
-// GetCachePath returns the file path for a cache key
-func (k CacheKey) GetCachePath() (string, error) {
- homeDir, err := os.UserHomeDir()
- if err != nil {
- return "", fmt.Errorf("failed to get home directory: %w", err)
- }
-
- parts := []string{
- homeDir,
- cacheDir,
- k.Operation,
- }
-
- if k.ParentType != "" && k.ParentID != "" {
- // Cached sub-resource: ~/.sitectl/cache/list/organization//firewall/list.resp
- parts = append(parts, k.ParentType, k.ParentID)
- if k.SubResource != "" {
- parts = append(parts, k.SubResource)
- }
- } else {
- // Cached resource: ~/.sitectl/cache/list/organization/list.resp
- parts = append(parts, k.ResourceType)
- }
-
- // Determine filename
- var filename string
- if k.ResourceID != "" {
- filename = fmt.Sprintf("%s.resp", k.ResourceID)
- } else {
- filename = "list.resp"
- }
-
- parts = append(parts, filename)
- return filepath.Join(parts...), nil
-}
-
-// Get retrieves a cached value if it exists and is not expired
-func Get(key CacheKey, target interface{}) (bool, error) {
- path, err := key.GetCachePath()
- if err != nil {
- return false, err
- }
-
- // Check if file exists
- info, err := os.Stat(path)
- if os.IsNotExist(err) {
- return false, nil
- }
- if err != nil {
- return false, err
- }
-
- // Check if cache is expired
- if time.Since(info.ModTime()) > cacheValidity {
- // Cache expired, delete it
- os.Remove(path)
- return false, nil
- }
-
- // Read cache file
- data, err := os.ReadFile(path)
- if err != nil {
- return false, err
- }
-
- // Unmarshal into target
- if err := json.Unmarshal(data, target); err != nil {
- // Cache corrupted, delete it
- os.Remove(path)
- return false, nil
- }
-
- return true, nil
-}
-
-// Set stores a value in the cache
-func Set(key CacheKey, value interface{}) error {
- path, err := key.GetCachePath()
- if err != nil {
- return err
- }
-
- // Create directory structure
- dir := filepath.Dir(path)
- if err := os.MkdirAll(dir, 0755); err != nil {
- return fmt.Errorf("failed to create cache directory: %w", err)
- }
-
- // Marshal value
- data, err := json.Marshal(value)
- if err != nil {
- return fmt.Errorf("failed to marshal cache data: %w", err)
- }
-
- // Write to file
- if err := os.WriteFile(path, data, 0600); err != nil {
- return fmt.Errorf("failed to write cache file: %w", err)
- }
-
- return nil
-}
-
-// Invalidate removes a cached value
-func Invalidate(key CacheKey) error {
- path, err := key.GetCachePath()
- if err != nil {
- return err
- }
-
- // Remove file if it exists
- if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
- return err
- }
-
- return nil
-}
-
-// InvalidatePattern removes all cache entries matching a pattern
-// This is useful for invalidating all caches related to a resource
-func InvalidatePattern(resourceType, resourceID string) error {
- homeDir, err := os.UserHomeDir()
- if err != nil {
- return fmt.Errorf("failed to get home directory: %w", err)
- }
-
- // Invalidate list cache for this resource type
- listKey := CacheKey{
- ResourceType: resourceType,
- Operation: "list",
- }
- err = Invalidate(listKey)
- if err != nil {
- return fmt.Errorf("failed to invalidate cache: %w", err)
- }
-
- // If we have a specific resource ID, invalidate its get cache and all sub-resources
- if resourceID != "" {
- getKey := CacheKey{
- ResourceType: resourceType,
- Operation: "get",
- ResourceID: resourceID,
- }
- err = Invalidate(getKey)
- if err != nil {
- return fmt.Errorf("failed to invalidate cache: %w", err)
- }
-
- // Invalidate all sub-resource caches
- subResources := []string{"firewall", "members", "secrets"}
- for _, subResource := range subResources {
- subCacheDir := filepath.Join(homeDir, cacheDir, "list", resourceType, resourceID, subResource)
- if err := os.RemoveAll(subCacheDir); err != nil && !os.IsNotExist(err) {
- return err
- }
- }
- }
-
- return nil
-}
-
-// Clear removes all cached data
-func Clear() error {
- homeDir, err := os.UserHomeDir()
- if err != nil {
- return fmt.Errorf("failed to get home directory: %w", err)
- }
-
- cachePath := filepath.Join(homeDir, cacheDir)
- if err := os.RemoveAll(cachePath); err != nil && !os.IsNotExist(err) {
- return err
- }
-
- return nil
-}
-
-// HashID creates a short hash for cache keys (for very long IDs)
-func HashID(id string) string {
- hash := sha256.Sum256([]byte(id))
- return hex.EncodeToString(hash[:8]) // Use first 8 bytes (16 hex chars)
-}
diff --git a/pkg/config/context.go b/pkg/config/context.go
index fa677a4..18ecaaf 100644
--- a/pkg/config/context.go
+++ b/pkg/config/context.go
@@ -29,19 +29,23 @@ const (
)
type Context struct {
- Name string `yaml:"name"`
- DockerHostType ContextType `mapstructure:"type" yaml:"type"`
- DockerSocket string `yaml:"docker-socket"`
- ProjectName string `yaml:"project-name"`
- Profile string `yaml:"profile"`
- ProjectDir string `yaml:"project-dir"`
- SSHUser string `yaml:"ssh-user"`
- SSHHostname string `yaml:"ssh-hostname,omitempty"`
- SSHPort uint `yaml:"ssh-port,omitempty"`
- SSHKeyPath string `yaml:"ssh-key,omitempty"`
- EnvFile []string `yaml:"env-file"`
- RunSudo bool `yaml:"sudo"`
- UriMap map[string]string `yaml:"uriMap"`
+ Name string `yaml:"name"`
+ DockerHostType ContextType `mapstructure:"type" yaml:"type"`
+ DockerSocket string `yaml:"docker-socket"`
+ ProjectName string `yaml:"project-name"`
+ ProjectDir string `yaml:"project-dir"`
+ SSHUser string `yaml:"ssh-user"`
+ SSHHostname string `yaml:"ssh-hostname,omitempty"`
+ SSHPort uint `yaml:"ssh-port,omitempty"`
+ SSHKeyPath string `yaml:"ssh-key,omitempty"`
+ EnvFile []string `yaml:"env-file"`
+ RunSudo bool `yaml:"sudo"`
+
+ // Database connection configuration
+ DatabaseService string `yaml:"database-service,omitempty"`
+ DatabaseUser string `yaml:"database-user,omitempty"`
+ DatabasePasswordSecret string `yaml:"database-password-secret,omitempty"`
+ DatabaseName string `yaml:"database-name,omitempty"`
ReadSmallFileFunc func(filename string) string `yaml:"-"`
}
@@ -97,6 +101,20 @@ func SaveContext(ctx *Context, setDefault bool) error {
return err
}
+ // Set database defaults if not provided
+ if ctx.DatabaseService == "" {
+ ctx.DatabaseService = "mariadb"
+ }
+ if ctx.DatabaseUser == "" {
+ ctx.DatabaseUser = "root"
+ }
+ if ctx.DatabasePasswordSecret == "" {
+ ctx.DatabasePasswordSecret = "DB_ROOT_PASSWORD"
+ }
+ if ctx.DatabaseName == "" {
+ ctx.DatabaseName = "drupal_default"
+ }
+
updated := false
for i, c := range cfg.Contexts {
if c.Name == ctx.Name {
@@ -437,3 +455,22 @@ func (c *Context) UploadFile(source, destination string) error {
return nil
}
+
+// GetSshUri returns an SSH connection URI
+func (c *Context) GetSshUri() string {
+ if c.DockerHostType == ContextLocal {
+ return ""
+ }
+
+ sshPort := c.SSHPort
+ if sshPort == 0 {
+ sshPort = 22
+ }
+
+ sshParams := fmt.Sprintf("sshHost=%s&sshUser=%s&sshPort=%d", c.SSHHostname, c.SSHUser, sshPort)
+ if c.SSHKeyPath != "" {
+ sshParams += fmt.Sprintf("&sshKeyFile=%s", c.SSHKeyPath)
+ }
+
+ return sshParams
+}
diff --git a/pkg/config/context_test.go b/pkg/config/context_test.go
index 6093da9..9a3bd00 100644
--- a/pkg/config/context_test.go
+++ b/pkg/config/context_test.go
@@ -77,14 +77,17 @@ func contextsEqual(a, b Context) bool {
a.DockerHostType == b.DockerHostType &&
a.DockerSocket == b.DockerSocket &&
a.ProjectName == b.ProjectName &&
- a.Profile == b.Profile &&
a.ProjectDir == b.ProjectDir &&
a.SSHUser == b.SSHUser &&
a.SSHHostname == b.SSHHostname &&
a.SSHPort == b.SSHPort &&
a.SSHKeyPath == b.SSHKeyPath &&
len(a.EnvFile) == len(b.EnvFile) &&
- a.RunSudo == b.RunSudo
+ a.RunSudo == b.RunSudo &&
+ a.DatabaseService == b.DatabaseService &&
+ a.DatabaseUser == b.DatabaseUser &&
+ a.DatabasePasswordSecret == b.DatabasePasswordSecret &&
+ a.DatabaseName == b.DatabaseName
}
func TestContextString(t *testing.T) {
@@ -308,7 +311,6 @@ func TestVerifyRemoteInputExistingConfig(t *testing.T) {
SSHUser: "bar",
SSHPort: 123,
SSHKeyPath: "/assuming/we/already/checked",
- Profile: "prod",
ProjectName: "baz",
}
cc := original
@@ -322,3 +324,70 @@ func TestVerifyRemoteInputExistingConfig(t *testing.T) {
t.Fatalf("expected context %+v, got %+v", original, cc)
}
}
+
+func TestGetSshUri(t *testing.T) {
+ tests := []struct {
+ name string
+ context Context
+ expected string
+ }{
+ {
+ name: "local context returns empty string",
+ context: Context{
+ DockerHostType: ContextLocal,
+ },
+ expected: "",
+ },
+ {
+ name: "remote context with default port",
+ context: Context{
+ DockerHostType: ContextRemote,
+ SSHHostname: "example.com",
+ SSHUser: "testuser",
+ SSHPort: 0, // Should default to 22
+ },
+ expected: "sshHost=example.com&sshUser=testuser&sshPort=22",
+ },
+ {
+ name: "remote context with custom port",
+ context: Context{
+ DockerHostType: ContextRemote,
+ SSHHostname: "example.com",
+ SSHUser: "testuser",
+ SSHPort: 2222,
+ },
+ expected: "sshHost=example.com&sshUser=testuser&sshPort=2222",
+ },
+ {
+ name: "remote context with SSH key path",
+ context: Context{
+ DockerHostType: ContextRemote,
+ SSHHostname: "example.com",
+ SSHUser: "testuser",
+ SSHPort: 22,
+ SSHKeyPath: "/home/user/.ssh/id_rsa",
+ },
+ expected: "sshHost=example.com&sshUser=testuser&sshPort=22&sshKeyFile=/home/user/.ssh/id_rsa",
+ },
+ {
+ name: "remote context without SSH key path",
+ context: Context{
+ DockerHostType: ContextRemote,
+ SSHHostname: "server.example.com",
+ SSHUser: "admin",
+ SSHPort: 22,
+ SSHKeyPath: "",
+ },
+ expected: "sshHost=server.example.com&sshUser=admin&sshPort=22",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.context.GetSshUri()
+ if result != tt.expected {
+ t.Errorf("expected %q, got %q", tt.expected, result)
+ }
+ })
+ }
+}
diff --git a/pkg/config/utils.go b/pkg/config/utils.go
index 348ef72..3c4ff6e 100644
--- a/pkg/config/utils.go
+++ b/pkg/config/utils.go
@@ -162,7 +162,10 @@ func SetCommandFlags(flags *pflag.FlagSet) {
flags.String("ssh-key", "", "Path to SSH private key for remote context. e.g. "+key)
flags.String("project-dir", "", "Path to docker compose project directory")
flags.String("project-name", "docker-compose", "Name of the docker compose project")
- flags.String("profile", "", "docker compose profile")
flags.Bool("sudo", false, "for remote contexts, run docker commands as sudo")
flags.StringSlice("env-file", []string{}, "when running remote docker commands, the --env-file paths to pass to docker compose")
+ flags.String("database-service", "mariadb", "Name of the database service in Docker Compose")
+ flags.String("database-user", "root", "Database user to connect as (e.g. root, admin)")
+ flags.String("database-password-secret", "DB_ROOT_PASSWORD", "Name of the docker compose secret containing the database password")
+ flags.String("database-name", "drupal_default", "Name of the database to connect to (e.g. drupal_default)")
}
diff --git a/pkg/config/utils_test.go b/pkg/config/utils_test.go
index 8efdb84..ade1719 100644
--- a/pkg/config/utils_test.go
+++ b/pkg/config/utils_test.go
@@ -16,7 +16,6 @@ func TestLoadFromFlags(t *testing.T) {
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
flags.String("docker-socket", "/var/run/docker.sock", "Path to Docker socket")
flags.String("type", "local", "Context type: local or remote")
- flags.String("profile", "default", "Profile name")
flags.String("ssh-hostname", "example.com", "SSH host for remote context")
flags.Uint("ssh-port", 22, "port")
flags.String("ssh-user", "user", "SSH user for remote context")
@@ -25,12 +24,15 @@ func TestLoadFromFlags(t *testing.T) {
flags.String("project-name", "foo", "Composer Project Name")
flags.Bool("sudo", false, "Run commands on remote hosts as sudo")
flags.StringSlice("env-file", []string{}, "path to env files to pass to docker compose")
+ flags.String("database-service", "mariadb", "Name of the database service in Docker Compose")
+ flags.String("database-user", "root", "Database user to connect as")
+ flags.String("database-password-secret", "DB_ROOT_PASSWORD", "Name of the secret containing the database password")
+ flags.String("database-name", "drupal_default", "Name of the database to connect to")
// Define test arguments to override defaults.
args := []string{
"--docker-socket", "/custom/docker.sock",
"--type", "remote",
- "--profile", "prod",
"--ssh-hostname", "remote.example.com",
"--ssh-port", "123",
"--ssh-user", "remoteuser",
@@ -58,9 +60,6 @@ func TestLoadFromFlags(t *testing.T) {
if ctx.DockerHostType != "remote" {
t.Errorf("Expected type 'remote', got %q", ctx.DockerHostType)
}
- if ctx.Profile != "prod" {
- t.Errorf("Expected profile 'prod', got %q", ctx.Profile)
- }
if ctx.SSHHostname != "remote.example.com" {
t.Errorf("Expected ssh-host 'remote.example.com', got %q", ctx.SSHHostname)
}
diff --git a/pkg/docker/docker.go b/pkg/docker/docker.go
index cd5c993..ec9f84a 100644
--- a/pkg/docker/docker.go
+++ b/pkg/docker/docker.go
@@ -125,15 +125,12 @@ func (d *DockerClient) GetServiceIp(ctx context.Context, c *config.Context, cont
return network.IPAddress, nil
}
-func (d *DockerClient) GetContainerName(c *config.Context, service string, neverPrefixProfile bool) (string, error) {
+func (d *DockerClient) GetContainerName(c *config.Context, service string) (string, error) {
ctx := context.Background()
// Define the filters based on the Docker Compose labels.
filterArgs := filters.NewArgs()
filterArgs.Add("label", "com.docker.compose.project="+c.ProjectName)
- if c.Profile != "" && !neverPrefixProfile {
- service = service + "-" + c.Profile
- }
filterArgs.Add("label", "com.docker.compose.service="+service)
slog.Debug("Querying docker", "filters", filterArgs)
@@ -286,3 +283,34 @@ func (d *DockerClient) ExecInteractive(ctx context.Context, containerID string,
Tty: true,
})
}
+
+// GetDatabaseUris constructs MySQL and SSH connection URIs for database tools like Sequel Ace
+// Returns: mysqlURI, sshURI, error
+func GetDatabaseUris(c *config.Context) (string, string, error) {
+ ctx := context.Background()
+
+ // Get Docker client
+ dockerCli, err := GetDockerCli(c)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to get docker client: %w", err)
+ }
+ defer dockerCli.Close()
+
+ // Get the database container name
+ containerName, err := dockerCli.GetContainerName(c, c.DatabaseService)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to get %s container: %w", c.DatabaseService, err)
+ }
+ if containerName == "" {
+ return "", "", fmt.Errorf("%s container not found", c.DatabaseService)
+ }
+
+ // Get database password from container environment
+ password, err := GetSecret(ctx, dockerCli.CLI, c, containerName, c.DatabasePasswordSecret)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to get database password from %s: %w", c.DatabasePasswordSecret, err)
+ }
+
+ mysqlURI := fmt.Sprintf("mysql://%s:%s@127.0.0.1:3306/%s", c.DatabaseUser, password, c.DatabaseName)
+ return mysqlURI, c.GetSshUri(), nil
+}
diff --git a/internal/utils/helper_test.go b/pkg/helpers/helper_test.go
similarity index 99%
rename from internal/utils/helper_test.go
rename to pkg/helpers/helper_test.go
index 7af744c..d7086dc 100644
--- a/internal/utils/helper_test.go
+++ b/pkg/helpers/helper_test.go
@@ -1,4 +1,4 @@
-package utils
+package helpers
import (
"reflect"
diff --git a/pkg/plugin/sdk.go b/pkg/plugin/sdk.go
index 4cd7a99..78f2653 100644
--- a/pkg/plugin/sdk.go
+++ b/pkg/plugin/sdk.go
@@ -83,9 +83,6 @@ func (s *SDK) setupLogging(cmd *cobra.Command) error {
if s.RootCmd.PersistentFlags().Lookup("context") != nil {
s.Config.Context, _ = cmd.Flags().GetString("context")
}
- if s.RootCmd.PersistentFlags().Lookup("api-url") != nil {
- s.Config.APIUrl, _ = cmd.Flags().GetString("api-url")
- }
if s.RootCmd.PersistentFlags().Lookup("format") != nil {
s.Config.Format, _ = cmd.Flags().GetString("format")
}
@@ -120,15 +117,9 @@ func (s *SDK) Execute() {
}
}
-// AddLibopsFlags adds common libops-specific flags
-func (s *SDK) AddLibopsFlags(currentContext string) {
- apiURL := os.Getenv("LIBOPS_API_URL")
- if apiURL == "" {
- apiURL = "https://api.libops.io"
- }
-
+// AddGlobalFlags adds common libops-specific flags
+func (s *SDK) AddGlobalFlags(currentContext string) {
s.RootCmd.PersistentFlags().String("context", currentContext, "The sitectl context to use. See sitectl config --help for more info")
- s.RootCmd.PersistentFlags().String("api-url", apiURL, "Base URL of the libops API")
}
// GetMetadataCommand returns a command that displays plugin metadata
diff --git a/pkg/resources/resources.go b/pkg/resources/resources.go
deleted file mode 100644
index 7e49c69..0000000
--- a/pkg/resources/resources.go
+++ /dev/null
@@ -1,249 +0,0 @@
-package resources
-
-import (
- "context"
- "fmt"
- "log/slog"
-
- "connectrpc.com/connect"
-
- libopsv1 "github.com/libops/api/proto/libops/v1"
- "github.com/libops/api/proto/libops/v1/common"
- "github.com/libops/sitectl/pkg/api"
- "github.com/libops/sitectl/pkg/cache"
-)
-
-// Type aliases for cleaner code
-type Organization = common.FolderConfig
-type Project = common.ProjectConfig
-type Site = common.SiteConfig
-
-// ListOrganizations returns all organizations, using cache when available
-func ListOrganizations(ctx context.Context, apiBaseURL string, useCache bool) ([]*Organization, error) {
- cacheKey := cache.CacheKey{
- ResourceType: "organization",
- Operation: "list",
- }
-
- // Try cache first
- if useCache {
- var cached []*Organization
- found, err := cache.Get(cacheKey, &cached)
- if err != nil {
- slog.Warn("Failed to read cache", "err", err)
- } else if found {
- slog.Debug("Using cached organizations", "count", len(cached))
- return cached, nil
- }
- }
-
- // Fetch from API
- client, err := api.NewLibopsAPIClient(ctx, apiBaseURL)
- if err != nil {
- return nil, err
- }
-
- resp, err := client.OrganizationService.ListOrganizations(ctx, connect.NewRequest(&libopsv1.ListOrganizationsRequest{}))
- if err != nil {
- return nil, fmt.Errorf("failed to list organizations: %w", err)
- }
-
- // Cache the result
- if useCache {
- if err := cache.Set(cacheKey, resp.Msg.Organizations); err != nil {
- slog.Warn("Failed to cache organizations", "err", err)
- }
- }
-
- return resp.Msg.Organizations, nil
-}
-
-// ListProjects returns all projects, using cache when available
-func ListProjects(ctx context.Context, apiBaseURL string, useCache bool, orgID *string) ([]*Project, error) {
- cacheKey := cache.CacheKey{
- ResourceType: "project",
- Operation: "list",
- }
-
- // Try cache first
- if useCache {
- var cached []*Project
- found, err := cache.Get(cacheKey, &cached)
- if err != nil {
- slog.Warn("Failed to read cache", "err", err)
- } else if found {
- // Filter by org if needed
- if orgID != nil && *orgID != "" {
- filtered := make([]*Project, 0)
- for _, p := range cached {
- if p.OrganizationId == *orgID {
- filtered = append(filtered, p)
- }
- }
- slog.Debug("Using cached projects (filtered)", "count", len(filtered))
- return filtered, nil
- }
- slog.Debug("Using cached projects", "count", len(cached))
- return cached, nil
- }
- }
-
- // Fetch from API
- client, err := api.NewLibopsAPIClient(ctx, apiBaseURL)
- if err != nil {
- return nil, err
- }
-
- resp, err := client.ProjectService.ListProjects(ctx, connect.NewRequest(&libopsv1.ListProjectsRequest{
- OrganizationId: orgID,
- }))
- if err != nil {
- return nil, fmt.Errorf("failed to list projects: %w", err)
- }
-
- // Cache the result (only if not filtered)
- if useCache && (orgID == nil || *orgID == "") {
- if err := cache.Set(cacheKey, resp.Msg.Projects); err != nil {
- slog.Warn("Failed to cache projects", "err", err)
- }
- }
-
- return resp.Msg.Projects, nil
-}
-
-// ListSites returns all sites, using cache when available
-func ListSites(ctx context.Context, apiBaseURL string, useCache bool, orgID, projectID *string) ([]*Site, error) {
- cacheKey := cache.CacheKey{
- ResourceType: "site",
- Operation: "list",
- }
-
- // Try cache first
- if useCache {
- var cached []*Site
- found, err := cache.Get(cacheKey, &cached)
- if err != nil {
- slog.Warn("Failed to read cache", "err", err)
- } else if found {
- // Filter by org/project if needed
- filtered := cached
- if orgID != nil && *orgID != "" {
- temp := make([]*Site, 0)
- for _, s := range filtered {
- if s.OrganizationId == *orgID {
- temp = append(temp, s)
- }
- }
- filtered = temp
- }
- if projectID != nil && *projectID != "" {
- temp := make([]*Site, 0)
- for _, s := range filtered {
- if s.ProjectId == *projectID {
- temp = append(temp, s)
- }
- }
- filtered = temp
- }
- return filtered, nil
- }
- }
-
- // Fetch from API
- client, err := api.NewLibopsAPIClient(ctx, apiBaseURL)
- if err != nil {
- return nil, err
- }
-
- resp, err := client.SiteService.ListSites(ctx, connect.NewRequest(&libopsv1.ListSitesRequest{
- OrganizationId: orgID,
- ProjectId: projectID,
- }))
- if err != nil {
- return nil, fmt.Errorf("failed to list sites: %w", err)
- }
-
- // Cache the result (only if not filtered)
- if useCache && (orgID == nil || *orgID == "") && (projectID == nil || *projectID == "") {
- if err := cache.Set(cacheKey, resp.Msg.Sites); err != nil {
- slog.Warn("Failed to cache sites", "err", err)
- }
- }
-
- return resp.Msg.Sites, nil
-}
-
-// GetOrganization returns a specific organization, using cache when available
-func GetOrganization(ctx context.Context, apiBaseURL, orgID string, useCache bool) (*Organization, error) {
- cacheKey := cache.CacheKey{
- ResourceType: "organization",
- Operation: "get",
- ResourceID: orgID,
- }
-
- // Try cache first
- if useCache {
- var cached Organization
- found, err := cache.Get(cacheKey, &cached)
- if err != nil {
- slog.Warn("Failed to read cache", "err", err)
- } else if found {
- slog.Debug("Using cached organization", "id", orgID)
- return &cached, nil
- }
- }
-
- // Fetch from API
- client, err := api.NewLibopsAPIClient(ctx, apiBaseURL)
- if err != nil {
- return nil, err
- }
-
- resp, err := client.OrganizationService.GetOrganization(ctx, connect.NewRequest(&libopsv1.GetOrganizationRequest{
- OrganizationId: orgID,
- }))
- if err != nil {
- return nil, fmt.Errorf("failed to get organization: %w", err)
- }
-
- // The response returns a Folder which is our Organization type
- org := resp.Msg.Folder
-
- // Cache the result
- if useCache {
- if err := cache.Set(cacheKey, org); err != nil {
- slog.Warn("Failed to cache organization", "err", err)
- }
- }
-
- return org, nil
-}
-
-// InvalidateOrganizationCache invalidates all caches related to an organization
-func InvalidateOrganizationCache(orgID string) error {
- return cache.InvalidatePattern("organization", orgID)
-}
-
-// InvalidateProjectCache invalidates all caches related to a project
-func InvalidateProjectCache(projectID string) error {
- return cache.InvalidatePattern("project", projectID)
-}
-
-// InvalidateSiteCache invalidates all caches related to a site
-func InvalidateSiteCache(siteID string) error {
- return cache.InvalidatePattern("site", siteID)
-}
-
-// InvalidateAllResourceCaches invalidates all resource list caches
-func InvalidateAllResourceCaches() error {
- if err := cache.InvalidatePattern("organization", ""); err != nil {
- return err
- }
- if err := cache.InvalidatePattern("project", ""); err != nil {
- return err
- }
- if err := cache.InvalidatePattern("site", ""); err != nil {
- return err
- }
- return nil
-}