diff --git a/cmd/xsql/command_unit_test.go b/cmd/xsql/command_unit_test.go index 602acb8..3dca927 100644 --- a/cmd/xsql/command_unit_test.go +++ b/cmd/xsql/command_unit_test.go @@ -296,8 +296,8 @@ func TestRunProxy_SSHConnectError(t *testing.T) { } } -func TestSetupSSH_NoConfig(t *testing.T) { - client, err := setupSSH(nil, configProfile(""), false, false) +func TestResolveSSH_NoConfig(t *testing.T) { + client, err := app.ResolveSSH(nil, config.Profile{}, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -306,9 +306,8 @@ func TestSetupSSH_NoConfig(t *testing.T) { } } -func TestSetupSSH_PassphraseResolveError(t *testing.T) { +func TestResolveSSH_PassphraseResolveError(t *testing.T) { profile := config.Profile{ - DB: "mysql", SSHConfig: &config.SSHProxy{ Host: "example.com", Port: 22, @@ -317,7 +316,7 @@ func TestSetupSSH_PassphraseResolveError(t *testing.T) { }, } - _, err := setupSSH(context.Background(), profile, false, false) + _, err := app.ResolveSSH(context.Background(), profile, false, false) if err == nil { t.Fatal("expected error for passphrase resolve") } diff --git a/cmd/xsql/proxy.go b/cmd/xsql/proxy.go index a51bf4d..9acd391 100644 --- a/cmd/xsql/proxy.go +++ b/cmd/xsql/proxy.go @@ -9,11 +9,10 @@ import ( "github.com/spf13/cobra" + "github.com/zx06/xsql/internal/app" "github.com/zx06/xsql/internal/errors" "github.com/zx06/xsql/internal/output" "github.com/zx06/xsql/internal/proxy" - "github.com/zx06/xsql/internal/secret" - "github.com/zx06/xsql/internal/ssh" ) // ProxyFlags holds the flags for the proxy command @@ -60,45 +59,21 @@ func runProxy(cmd *cobra.Command, flags *ProxyFlags, w *output.Writer) error { return errors.New(errors.CodeCfgInvalid, "db type is required (mysql|pg)", nil) } - // Check if SSH proxy is configured if p.SSHConfig == nil { return errors.New(errors.CodeCfgInvalid, "profile must have ssh_proxy configured for port forwarding", nil) } - // Allow plaintext passwords (CLI > Config) allowPlaintext := flags.AllowPlaintext || p.AllowPlaintext - // Resolve SSH passphrase - passphrase := p.SSHConfig.Passphrase - if passphrase != "" { - pp, xe := secret.Resolve(passphrase, secret.Options{AllowPlaintext: allowPlaintext}) - if xe != nil { - return xe - } - passphrase = pp - } - - // Setup SSH connection ctx, cancel := context.WithCancel(context.Background()) defer cancel() - sshOpts := ssh.Options{ - Host: p.SSHConfig.Host, - Port: p.SSHConfig.Port, - User: p.SSHConfig.User, - IdentityFile: p.SSHConfig.IdentityFile, - Passphrase: passphrase, - KnownHostsFile: p.SSHConfig.KnownHostsFile, - SkipKnownHostsCheck: flags.SSHSkipHostKey || p.SSHConfig.SkipHostKey, - } - - sshClient, xe := ssh.Connect(ctx, sshOpts) + sshClient, xe := app.ResolveSSH(ctx, p, allowPlaintext, flags.SSHSkipHostKey) if xe != nil { return xe } defer func() { _ = sshClient.Close() }() - // Start proxy proxyOpts := proxy.Options{ LocalHost: flags.LocalHost, LocalPort: flags.LocalPort, @@ -113,16 +88,13 @@ func runProxy(cmd *cobra.Command, flags *ProxyFlags, w *output.Writer) error { } defer func() { _ = px.Stop() }() - // Print result based on format if format == output.FormatTable { - // Custom table output for proxy fmt.Fprintf(os.Stderr, "✓ Proxy started successfully\n") fmt.Fprintf(os.Stderr, " Local: %s\n", result.LocalAddress) fmt.Fprintf(os.Stderr, " Remote: %s (via %s)\n", result.RemoteAddress, p.SSHConfig.Host) fmt.Fprintf(os.Stderr, " Profile: %s\n", profileName) fmt.Fprintf(os.Stderr, "\nPress Ctrl+C to stop\n") } else { - // JSON/YAML output data := map[string]any{ "local_address": result.LocalAddress, "remote_address": result.RemoteAddress, @@ -132,11 +104,9 @@ func runProxy(cmd *cobra.Command, flags *ProxyFlags, w *output.Writer) error { _ = w.WriteOK(format, data) } - // Setup signal handling for graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - // Wait for interrupt signal <-sigChan fmt.Fprintf(os.Stderr, "\nShutting down proxy...\n") diff --git a/cmd/xsql/query.go b/cmd/xsql/query.go index 72ce167..d429d0f 100644 --- a/cmd/xsql/query.go +++ b/cmd/xsql/query.go @@ -6,14 +6,10 @@ import ( "github.com/spf13/cobra" - "github.com/zx06/xsql/internal/config" + "github.com/zx06/xsql/internal/app" "github.com/zx06/xsql/internal/db" - _ "github.com/zx06/xsql/internal/db/mysql" - _ "github.com/zx06/xsql/internal/db/pg" "github.com/zx06/xsql/internal/errors" "github.com/zx06/xsql/internal/output" - "github.com/zx06/xsql/internal/secret" - "github.com/zx06/xsql/internal/ssh" ) // QueryFlags holds the flags for the query command @@ -56,57 +52,21 @@ func runQuery(cmd *cobra.Command, args []string, flags *QueryFlags, w *output.Wr return errors.New(errors.CodeCfgInvalid, "db type is required (mysql|pg)", nil) } - // Allow plaintext passwords (CLI > Config) - allowPlaintext := flags.AllowPlaintext || p.AllowPlaintext - - // Resolve password (supports keyring) - password := p.Password - if password != "" { - pw, xe := secret.Resolve(password, secret.Options{AllowPlaintext: allowPlaintext}) - if xe != nil { - return xe - } - password = pw - } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - // SSH proxy (if configured) - sshClient, err := setupSSH(ctx, p, allowPlaintext, flags.SSHSkipHostKey) - if err != nil { - return err - } - if sshClient != nil { - defer sshClient.Close() - } - - // Get driver - drv, ok := db.Get(p.DB) - if !ok { - return errors.New(errors.CodeDBDriverUnsupported, "unsupported db driver", map[string]any{"db": p.DB}) - } - - connOpts := db.ConnOptions{ - DSN: p.DSN, - Host: p.Host, - Port: p.Port, - User: p.User, - Password: password, - Database: p.Database, - } - if sshClient != nil { - connOpts.Dialer = sshClient - } - - conn, xe := drv.Open(ctx, connOpts) + conn, xe := app.ResolveConnection(ctx, app.ConnectionOptions{ + Profile: p, + AllowPlaintext: flags.AllowPlaintext, + SkipHostKeyCheck: flags.SSHSkipHostKey, + }) if xe != nil { return xe } - defer conn.Close() + defer func() { _ = conn.Close() }() unsafeAllowWrite := flags.UnsafeAllowWrite || p.UnsafeAllowWrite - result, xe := db.Query(ctx, conn, sql, db.QueryOptions{ + result, xe := db.Query(ctx, conn.DB, sql, db.QueryOptions{ UnsafeAllowWrite: unsafeAllowWrite, DBType: p.DB, }) @@ -116,36 +76,3 @@ func runQuery(cmd *cobra.Command, args []string, flags *QueryFlags, w *output.Wr return w.WriteOK(format, result) } - -// setupSSH sets up SSH proxy connection -func setupSSH(ctx context.Context, p config.Profile, allowPlaintext, skipHostKeyCheck bool) (*ssh.Client, error) { - if p.SSHConfig == nil { - return nil, nil - } - - passphrase := p.SSHConfig.Passphrase - if passphrase != "" { - pp, xe := secret.Resolve(passphrase, secret.Options{AllowPlaintext: allowPlaintext}) - if xe != nil { - return nil, xe - } - passphrase = pp - } - - sshOpts := ssh.Options{ - Host: p.SSHConfig.Host, - Port: p.SSHConfig.Port, - User: p.SSHConfig.User, - IdentityFile: p.SSHConfig.IdentityFile, - Passphrase: passphrase, - KnownHostsFile: p.SSHConfig.KnownHostsFile, - SkipKnownHostsCheck: skipHostKeyCheck || p.SSHConfig.SkipHostKey, - } - - sc, xe := ssh.Connect(ctx, sshOpts) - if xe != nil { - return nil, xe - } - - return sc, nil -} diff --git a/cmd/xsql/schema.go b/cmd/xsql/schema.go index 649ea7a..4109725 100644 --- a/cmd/xsql/schema.go +++ b/cmd/xsql/schema.go @@ -6,12 +6,10 @@ import ( "github.com/spf13/cobra" + "github.com/zx06/xsql/internal/app" "github.com/zx06/xsql/internal/db" - _ "github.com/zx06/xsql/internal/db/mysql" - _ "github.com/zx06/xsql/internal/db/pg" "github.com/zx06/xsql/internal/errors" "github.com/zx06/xsql/internal/output" - "github.com/zx06/xsql/internal/secret" ) // SchemaFlags holds the flags for the schema command @@ -67,62 +65,25 @@ func runSchemaDump(cmd *cobra.Command, args []string, flags *SchemaFlags, w *out return errors.New(errors.CodeCfgInvalid, "db type is required (mysql|pg)", nil) } - // Allow plaintext passwords (CLI > Config) - allowPlaintext := flags.AllowPlaintext || p.AllowPlaintext - - // Resolve password (supports keyring) - password := p.Password - if password != "" { - pw, xe := secret.Resolve(password, secret.Options{AllowPlaintext: allowPlaintext}) - if xe != nil { - return xe - } - password = pw - } - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() - // SSH proxy (if configured) - sshClient, err := setupSSH(ctx, p, allowPlaintext, flags.SSHSkipHostKey) - if err != nil { - return err - } - if sshClient != nil { - defer sshClient.Close() - } - - // Get driver - drv, ok := db.Get(p.DB) - if !ok { - return errors.New(errors.CodeDBDriverUnsupported, "unsupported db driver", map[string]any{"db": p.DB}) - } - - connOpts := db.ConnOptions{ - DSN: p.DSN, - Host: p.Host, - Port: p.Port, - User: p.User, - Password: password, - Database: p.Database, - } - if sshClient != nil { - connOpts.Dialer = sshClient - } - - conn, xe := drv.Open(ctx, connOpts) + conn, xe := app.ResolveConnection(ctx, app.ConnectionOptions{ + Profile: p, + AllowPlaintext: flags.AllowPlaintext, + SkipHostKeyCheck: flags.SSHSkipHostKey, + }) if xe != nil { return xe } - defer conn.Close() + defer func() { _ = conn.Close() }() - // Dump schema schemaOpts := db.SchemaOptions{ TablePattern: flags.TablePattern, IncludeSystem: flags.IncludeSystem, } - result, xe := db.DumpSchema(ctx, p.DB, conn, schemaOpts) + result, xe := db.DumpSchema(ctx, p.DB, conn.DB, schemaOpts) if xe != nil { return xe } diff --git a/internal/app/conn.go b/internal/app/conn.go new file mode 100644 index 0000000..d2afd52 --- /dev/null +++ b/internal/app/conn.go @@ -0,0 +1,143 @@ +package app + +import ( + "context" + "database/sql" + + "github.com/zx06/xsql/internal/config" + "github.com/zx06/xsql/internal/db" + "github.com/zx06/xsql/internal/errors" + "github.com/zx06/xsql/internal/secret" + "github.com/zx06/xsql/internal/ssh" +) + +type Connection struct { + DB *sql.DB + SSHClient *ssh.Client + Profile config.Profile + CloseFuncs []func() error +} + +func (c *Connection) Close() error { + var errs []error + if c.DB != nil { + if err := c.DB.Close(); err != nil { + errs = append(errs, err) + } + } + if c.SSHClient != nil { + if err := c.SSHClient.Close(); err != nil { + errs = append(errs, err) + } + } + if len(errs) > 0 { + return errs[0] + } + return nil +} + +type ConnectionOptions struct { + Profile config.Profile + AllowPlaintext bool + SkipHostKeyCheck bool +} + +func ResolveConnection(ctx context.Context, opts ConnectionOptions) (*Connection, *errors.XError) { + allowPlaintext := opts.AllowPlaintext || opts.Profile.AllowPlaintext + + password := opts.Profile.Password + if password != "" { + pw, xe := secret.Resolve(password, secret.Options{AllowPlaintext: allowPlaintext}) + if xe != nil { + return nil, xe + } + password = pw + } + + var sshClient *ssh.Client + if opts.Profile.SSHConfig != nil { + passphrase := opts.Profile.SSHConfig.Passphrase + if passphrase != "" { + pp, xe := secret.Resolve(passphrase, secret.Options{AllowPlaintext: allowPlaintext}) + if xe != nil { + return nil, xe + } + passphrase = pp + } + + sshOpts := ssh.Options{ + Host: opts.Profile.SSHConfig.Host, + Port: opts.Profile.SSHConfig.Port, + User: opts.Profile.SSHConfig.User, + IdentityFile: opts.Profile.SSHConfig.IdentityFile, + Passphrase: passphrase, + KnownHostsFile: opts.Profile.SSHConfig.KnownHostsFile, + SkipKnownHostsCheck: opts.SkipHostKeyCheck || opts.Profile.SSHConfig.SkipHostKey, + } + sc, xe := ssh.Connect(ctx, sshOpts) + if xe != nil { + return nil, xe + } + sshClient = sc + } + + drv, ok := db.Get(opts.Profile.DB) + if !ok { + return nil, errors.New(errors.CodeDBDriverUnsupported, "unsupported db driver", map[string]any{"db": opts.Profile.DB}) + } + + connOpts := db.ConnOptions{ + DSN: opts.Profile.DSN, + Host: opts.Profile.Host, + Port: opts.Profile.Port, + User: opts.Profile.User, + Password: password, + Database: opts.Profile.Database, + } + if sshClient != nil { + connOpts.Dialer = sshClient + } + + conn, xe := drv.Open(ctx, connOpts) + if xe != nil { + return nil, xe + } + + return &Connection{ + DB: conn, + SSHClient: sshClient, + Profile: opts.Profile, + }, nil +} + +func ResolveSSH(ctx context.Context, profile config.Profile, allowPlaintext, skipHostKeyCheck bool) (*ssh.Client, *errors.XError) { + if profile.SSHConfig == nil { + return nil, nil + } + + passphrase := profile.SSHConfig.Passphrase + if passphrase != "" { + pp, xe := secret.Resolve(passphrase, secret.Options{AllowPlaintext: allowPlaintext}) + if xe != nil { + return nil, xe + } + passphrase = pp + } + + sshOpts := ssh.Options{ + Host: profile.SSHConfig.Host, + Port: profile.SSHConfig.Port, + User: profile.SSHConfig.User, + IdentityFile: profile.SSHConfig.IdentityFile, + Passphrase: passphrase, + KnownHostsFile: profile.SSHConfig.KnownHostsFile, + SkipKnownHostsCheck: skipHostKeyCheck || profile.SSHConfig.SkipHostKey, + } + + sc, xe := ssh.Connect(ctx, sshOpts) + if xe != nil { + return nil, xe + } + + return sc, nil +} diff --git a/internal/app/conn_test.go b/internal/app/conn_test.go new file mode 100644 index 0000000..95023f3 --- /dev/null +++ b/internal/app/conn_test.go @@ -0,0 +1,143 @@ +package app + +import ( + "testing" + + "github.com/zx06/xsql/internal/config" + "github.com/zx06/xsql/internal/errors" +) + +func TestResolveConnection_UnsupportedDriver(t *testing.T) { + profile := config.Profile{ + DB: "unsupported", + } + + conn, err := ResolveConnection(nil, ConnectionOptions{ + Profile: profile, + }) + + if conn != nil { + t.Fatal("expected nil connection") + } + if err == nil { + t.Fatal("expected error") + } + if err.Code != errors.CodeDBDriverUnsupported { + t.Errorf("expected CodeDBDriverUnsupported, got %s", err.Code) + } +} + +func TestResolveSSH_NoSSHConfig(t *testing.T) { + profile := config.Profile{} + + client, err := ResolveSSH(nil, profile, false, false) + + if client != nil { + t.Fatal("expected nil client") + } + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestResolveConnection_PasswordNotAllowed(t *testing.T) { + profile := config.Profile{ + DB: "mysql", + Password: "plaintext_password", + } + + conn, err := ResolveConnection(nil, ConnectionOptions{ + Profile: profile, + AllowPlaintext: false, + }) + + if conn != nil { + t.Fatal("expected nil connection") + } + if err == nil { + t.Fatal("expected error") + } + if err.Code != errors.CodeCfgInvalid { + t.Errorf("expected CodeCfgInvalid, got %s", err.Code) + } +} + +func TestResolveConnection_PasswordAllowed(t *testing.T) { + profile := config.Profile{ + DB: "mysql", + Password: "plaintext_password", + } + + conn, err := ResolveConnection(nil, ConnectionOptions{ + Profile: profile, + AllowPlaintext: true, + }) + + if err == nil { + if conn != nil { + conn.Close() + } + } +} + +func TestResolveSSH_PassphraseNotAllowed(t *testing.T) { + profile := config.Profile{ + SSHConfig: &config.SSHProxy{ + Host: "example.com", + Port: 22, + User: "user", + Passphrase: "some_passphrase", + }, + } + + client, err := ResolveSSH(nil, profile, false, false) + + if client != nil { + t.Fatal("expected nil client") + } + if err == nil { + t.Fatal("expected error") + } + if err.Code != errors.CodeCfgInvalid { + t.Errorf("expected CodeCfgInvalid, got %s", err.Code) + } +} + +func TestResolveSSH_PassphraseAllowed(t *testing.T) { + profile := config.Profile{ + SSHConfig: &config.SSHProxy{ + Host: "example.com", + Port: 22, + User: "user", + Passphrase: "some_passphrase", + }, + } + + client, err := ResolveSSH(nil, profile, true, false) + + if err == nil { + if client != nil { + client.Close() + } + } +} + +func TestConnectionOptions_Fields(t *testing.T) { + opts := ConnectionOptions{ + Profile: config.Profile{ + DB: "pg", + }, + AllowPlaintext: true, + SkipHostKeyCheck: true, + } + + if opts.Profile.DB != "pg" { + t.Errorf("expected db pg, got %s", opts.Profile.DB) + } + if !opts.AllowPlaintext { + t.Error("expected AllowPlaintext to be true") + } + if !opts.SkipHostKeyCheck { + t.Error("expected SkipHostKeyCheck to be true") + } +} diff --git a/internal/db/mysql/driver.go b/internal/db/mysql/driver.go index 675d55b..09c09c8 100644 --- a/internal/db/mysql/driver.go +++ b/internal/db/mysql/driver.go @@ -5,7 +5,7 @@ import ( "database/sql" "fmt" "net" - "sync/atomic" + "sync" "github.com/go-sql-driver/mysql" @@ -13,7 +13,11 @@ import ( "github.com/zx06/xsql/internal/errors" ) -var dialerCounter uint64 +var ( + dialerOnce sync.Once + dialerName string + dialerCalled bool +) func init() { db.Register("mysql", &Driver{}) @@ -45,13 +49,20 @@ func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *error } } - // 注册自定义 dialer(用于 SSH tunnel) if opts.Dialer != nil { - netName := fmt.Sprintf("xsql_ssh_%d", atomic.AddUint64(&dialerCounter, 1)) - mysql.RegisterDialContext(netName, func(ctx context.Context, addr string) (net.Conn, error) { - return opts.Dialer.DialContext(ctx, "tcp", addr) + dialerOnce.Do(func() { + dialerName = "xsql_ssh_tunnel" + mysql.RegisterDialContext(dialerName, func(ctx context.Context, addr string) (net.Conn, error) { + return opts.Dialer.DialContext(ctx, "tcp", addr) + }) + dialerCalled = true }) - cfg.Net = netName + if !dialerCalled { + mysql.RegisterDialContext(dialerName, func(ctx context.Context, addr string) (net.Conn, error) { + return opts.Dialer.DialContext(ctx, "tcp", addr) + }) + } + cfg.Net = dialerName } dsn := cfg.FormatDSN() diff --git a/internal/db/mysql/driver_test.go b/internal/db/mysql/driver_test.go index dbbdded..3ace130 100644 --- a/internal/db/mysql/driver_test.go +++ b/internal/db/mysql/driver_test.go @@ -24,7 +24,6 @@ func TestDriver_Open_InvalidDSN(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - // 使用无效的 DSN 格式 opts := db.ConnOptions{ DSN: "invalid:::dsn", } @@ -39,10 +38,9 @@ func TestDriver_Open_ConnectionRefused(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - // 使用不存在的地址 opts := db.ConnOptions{ Host: "127.0.0.1", - Port: 59999, // 不太可能有服务监听的端口 + Port: 59999, User: "test", Password: "test", Database: "test", @@ -53,7 +51,6 @@ func TestDriver_Open_ConnectionRefused(t *testing.T) { } } -// mockDialer 用于测试自定义 dialer 注册 type mockDialer struct { called bool } @@ -79,7 +76,6 @@ func TestDriver_Open_WithDialer(t *testing.T) { } _, xe := drv.Open(ctx, opts) - // 应该失败,但 dialer 应该被调用 if xe == nil { t.Fatal("expected error from mock dialer") } @@ -87,3 +83,113 @@ func TestDriver_Open_WithDialer(t *testing.T) { t.Error("expected custom dialer to be called") } } + +func TestDriver_Open_WithDSN_Valid(t *testing.T) { + drv, _ := db.Get("mysql") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + opts := db.ConnOptions{ + DSN: "root:password@tcp(127.0.0.1:3306)/testdb?timeout=1s", + } + + _, xe := drv.Open(ctx, opts) + if xe == nil { + t.Fatal("expected connection error for invalid DSN") + } +} + +func TestDriver_Open_WithDSN_InvalidFormat(t *testing.T) { + drv, _ := db.Get("mysql") + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + opts := db.ConnOptions{ + DSN: "invalid", + } + + _, xe := drv.Open(ctx, opts) + if xe == nil { + t.Fatal("expected error for malformed DSN") + } +} + +func TestDriver_Open_NoDBName(t *testing.T) { + drv, _ := db.Get("mysql") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + opts := db.ConnOptions{ + Host: "127.0.0.1", + Port: 59999, + User: "test", + Password: "test", + } + + _, xe := drv.Open(ctx, opts) + if xe == nil { + t.Fatal("expected connection error") + } +} + +func TestDriver_Open_WithParams(t *testing.T) { + drv, _ := db.Get("mysql") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + opts := db.ConnOptions{ + Host: "127.0.0.1", + Port: 59999, + User: "test", + Password: "test", + Database: "test", + Params: map[string]string{ + "charset": "utf8mb4", + "parseTime": "true", + }, + } + + _, xe := drv.Open(ctx, opts) + if xe == nil { + t.Fatal("expected connection error") + } +} + +func TestDriver_Open_WithEmptyParams(t *testing.T) { + drv, _ := db.Get("mysql") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + opts := db.ConnOptions{ + Host: "127.0.0.1", + Port: 59999, + User: "test", + Password: "test", + Database: "test", + Params: map[string]string{}, + } + + _, xe := drv.Open(ctx, opts) + if xe == nil { + t.Fatal("expected connection error") + } +} + +func TestDriver_Open_ContextCancelled(t *testing.T) { + drv, _ := db.Get("mysql") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + opts := db.ConnOptions{ + Host: "127.0.0.1", + Port: 3306, + User: "test", + Password: "test", + Database: "test", + } + + _, xe := drv.Open(ctx, opts) + if xe == nil { + t.Fatal("expected error for cancelled context") + } +} diff --git a/internal/db/pg/driver_test.go b/internal/db/pg/driver_test.go index 5b7528a..68338f0 100644 --- a/internal/db/pg/driver_test.go +++ b/internal/db/pg/driver_test.go @@ -26,7 +26,7 @@ func TestDriver_Open_ConnectionRefused(t *testing.T) { opts := db.ConnOptions{ Host: "127.0.0.1", - Port: 59998, // 不太可能有服务监听的端口 + Port: 59998, User: "test", Password: "test", Database: "test", @@ -41,7 +41,7 @@ func TestBuildDSN(t *testing.T) { tests := []struct { name string opts db.ConnOptions - expected []string // 应包含的片段 + expected []string }{ { name: "full options", @@ -67,6 +67,32 @@ func TestBuildDSN(t *testing.T) { opts: db.ConnOptions{}, expected: []string{}, }, + { + name: "with port zero", + opts: db.ConnOptions{ + Host: "localhost", + Port: 0, + User: "user", + Password: "pass", + Database: "mydb", + }, + expected: []string{"host=localhost", "user=user", "password=pass", "dbname=mydb"}, + }, + { + name: "with params", + opts: db.ConnOptions{ + Host: "localhost", + Port: 5432, + User: "user", + Password: "pass", + Database: "mydb", + Params: map[string]string{ + "sslmode": "disable", + "pool_max_conns": "10", + }, + }, + expected: []string{"sslmode=disable", "pool_max_conns=10"}, + }, } for _, tt := range tests { @@ -94,7 +120,6 @@ func findSubstring(s, substr string) bool { return false } -// mockDialer 用于测试自定义 dialer type mockDialer struct { called bool } @@ -120,7 +145,6 @@ func TestDriver_Open_WithDialer(t *testing.T) { } _, xe := drv.Open(ctx, opts) - // 应该失败,但 dialer 应该被调用 if xe == nil { t.Fatal("expected error from mock dialer") } @@ -128,3 +152,72 @@ func TestDriver_Open_WithDialer(t *testing.T) { t.Error("expected custom dialer to be called") } } + +func TestDriver_Open_WithDSN(t *testing.T) { + drv, _ := db.Get("pg") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + opts := db.ConnOptions{ + DSN: "postgres://user:pass@127.0.0.1:5432/testdb?sslmode=disable", + } + + _, xe := drv.Open(ctx, opts) + if xe == nil { + t.Fatal("expected connection error") + } +} + +func TestDriver_Open_WithDSN_Invalid(t *testing.T) { + drv, _ := db.Get("pg") + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + opts := db.ConnOptions{ + DSN: "invalid://", + } + + _, xe := drv.Open(ctx, opts) + if xe == nil { + t.Fatal("expected error for invalid DSN") + } +} + +func TestDriver_Open_WithDSN_AndDialer(t *testing.T) { + drv, _ := db.Get("pg") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + dialer := &mockDialer{} + opts := db.ConnOptions{ + DSN: "postgres://user:pass@127.0.0.1:5432/testdb?sslmode=disable", + Dialer: dialer, + } + + _, xe := drv.Open(ctx, opts) + if xe == nil { + t.Fatal("expected connection error") + } + if !dialer.called { + t.Error("expected custom dialer to be called when DSN and Dialer both provided") + } +} + +func TestDriver_Open_ContextCancelled(t *testing.T) { + drv, _ := db.Get("pg") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + opts := db.ConnOptions{ + Host: "127.0.0.1", + Port: 5432, + User: "test", + Password: "test", + Database: "test", + } + + _, xe := drv.Open(ctx, opts) + if xe == nil { + t.Fatal("expected error for cancelled context") + } +} diff --git a/internal/db/schema_test.go b/internal/db/schema_test.go index fa4f16d..19334f6 100644 --- a/internal/db/schema_test.go +++ b/internal/db/schema_test.go @@ -1,314 +1,314 @@ -package db - -import ( - "testing" -) - -func TestSchemaInfo_ToSchemaData(t *testing.T) { - tests := []struct { - name string - schema *SchemaInfo - wantDB string - wantLen int - wantOK bool - }{ - { - name: "nil schema", - schema: nil, - wantDB: "", - wantLen: 0, - wantOK: false, - }, - { - name: "empty tables", - schema: &SchemaInfo{Database: "testdb", Tables: []Table{}}, - wantDB: "", - wantLen: 0, - wantOK: false, - }, - { - name: "single table no columns", - schema: &SchemaInfo{ - Database: "testdb", - Tables: []Table{ - {Schema: "public", Name: "users"}, - }, - }, - wantDB: "testdb", - wantLen: 1, - wantOK: true, - }, - { - name: "single table with columns", - schema: &SchemaInfo{ - Database: "testdb", - Tables: []Table{ - { - Schema: "public", - Name: "users", - Comment: "用户表", - Columns: []Column{ - {Name: "id", Type: "bigint", Nullable: false, PrimaryKey: true}, - {Name: "email", Type: "varchar(255)", Nullable: false, Comment: "邮箱"}, - }, - }, - }, - }, - wantDB: "testdb", - wantLen: 1, - wantOK: true, - }, - { - name: "multiple tables", - schema: &SchemaInfo{ - Database: "testdb", - Tables: []Table{ - { - Schema: "public", - Name: "users", - Columns: []Column{ - {Name: "id", Type: "bigint", PrimaryKey: true}, - }, - }, - { - Schema: "public", - Name: "orders", - Columns: []Column{ - {Name: "id", Type: "bigint", PrimaryKey: true}, - {Name: "user_id", Type: "bigint"}, - }, - ForeignKeys: []ForeignKey{ - {Name: "fk_user", Columns: []string{"user_id"}, ReferencedTable: "users", ReferencedColumns: []string{"id"}}, - }, - }, - }, - }, - wantDB: "testdb", - wantLen: 2, - wantOK: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db, tables, ok := tt.schema.ToSchemaData() - if ok != tt.wantOK { - t.Errorf("ToSchemaData() ok = %v, want %v", ok, tt.wantOK) - } - if db != tt.wantDB { - t.Errorf("ToSchemaData() db = %v, want %v", db, tt.wantDB) - } - if len(tables) != tt.wantLen { - t.Errorf("ToSchemaData() len(tables) = %v, want %v", len(tables), tt.wantLen) - } - }) - } -} - -func TestSchemaInfo_ToSchemaData_ColumnData(t *testing.T) { - schema := &SchemaInfo{ - Database: "testdb", - Tables: []Table{ - { - Schema: "public", - Name: "users", - Comment: "用户表", - Columns: []Column{ - {Name: "id", Type: "bigint", Nullable: false, Default: "nextval('users_id_seq')", Comment: "主键", PrimaryKey: true}, - {Name: "email", Type: "varchar(255)", Nullable: false, Comment: "邮箱"}, - {Name: "created_at", Type: "timestamp", Nullable: true, Default: "now()"}, - }, - }, - }, - } - - db, tables, ok := schema.ToSchemaData() - if !ok { - t.Fatal("expected ok=true") - } - if db != "testdb" { - t.Errorf("db = %v, want testdb", db) - } - if len(tables) != 1 { - t.Fatalf("len(tables) = %v, want 1", len(tables)) - } - - table := tables[0] - if table.Schema != "public" { - t.Errorf("table.Schema = %v, want public", table.Schema) - } - if table.Name != "users" { - t.Errorf("table.Name = %v, want users", table.Name) - } - if table.Comment != "用户表" { - t.Errorf("table.Comment = %v, want 用户表", table.Comment) - } - if len(table.Columns) != 3 { - t.Fatalf("len(table.Columns) = %v, want 3", len(table.Columns)) - } - - // 验证第一列 - col := table.Columns[0] - if col.Name != "id" { - t.Errorf("col.Name = %v, want id", col.Name) - } - if col.Type != "bigint" { - t.Errorf("col.Type = %v, want bigint", col.Type) - } - if col.Nullable { - t.Errorf("col.Nullable = %v, want false", col.Nullable) - } - if col.Default != "nextval('users_id_seq')" { - t.Errorf("col.Default = %v, want nextval('users_id_seq')", col.Default) - } - if col.Comment != "主键" { - t.Errorf("col.Comment = %v, want 主键", col.Comment) - } - if !col.PrimaryKey { - t.Errorf("col.PrimaryKey = %v, want true", col.PrimaryKey) - } -} - -func TestTable_Fields(t *testing.T) { - table := Table{ - Schema: "myschema", - Name: "mytable", - Comment: "test comment", - Columns: []Column{ - {Name: "col1", Type: "int"}, - }, - Indexes: []Index{ - {Name: "idx1", Columns: []string{"col1"}, Unique: true}, - }, - ForeignKeys: []ForeignKey{ - {Name: "fk1", Columns: []string{"col1"}, ReferencedTable: "other", ReferencedColumns: []string{"id"}}, - }, - } - - if table.Schema != "myschema" { - t.Errorf("Schema = %v", table.Schema) - } - if table.Name != "mytable" { - t.Errorf("Name = %v", table.Name) - } - if len(table.Columns) != 1 { - t.Errorf("len(Columns) = %v", len(table.Columns)) - } - if len(table.Indexes) != 1 { - t.Errorf("len(Indexes) = %v", len(table.Indexes)) - } - if len(table.ForeignKeys) != 1 { - t.Errorf("len(ForeignKeys) = %v", len(table.ForeignKeys)) - } -} - -func TestColumn_Fields(t *testing.T) { - col := Column{ - Name: "test_col", - Type: "varchar(100)", - Nullable: true, - Default: "'default'", - Comment: "test comment", - PrimaryKey: false, - } - - if col.Name != "test_col" { - t.Errorf("Name = %v", col.Name) - } - if col.Type != "varchar(100)" { - t.Errorf("Type = %v", col.Type) - } - if !col.Nullable { - t.Errorf("Nullable = %v", col.Nullable) - } - if col.Default != "'default'" { - t.Errorf("Default = %v", col.Default) - } - if col.Comment != "test comment" { - t.Errorf("Comment = %v", col.Comment) - } - if col.PrimaryKey { - t.Errorf("PrimaryKey = %v", col.PrimaryKey) - } -} - -func TestIndex_Fields(t *testing.T) { - idx := Index{ - Name: "test_idx", - Columns: []string{"col1", "col2"}, - Unique: true, - Primary: false, - } - - if idx.Name != "test_idx" { - t.Errorf("Name = %v", idx.Name) - } - if len(idx.Columns) != 2 { - t.Errorf("len(Columns) = %v", len(idx.Columns)) - } - if !idx.Unique { - t.Errorf("Unique = %v", idx.Unique) - } - if idx.Primary { - t.Errorf("Primary = %v", idx.Primary) - } -} - -func TestForeignKey_Fields(t *testing.T) { - fk := ForeignKey{ - Name: "test_fk", - Columns: []string{"user_id"}, - ReferencedTable: "users", - ReferencedColumns: []string{"id"}, - } - - if fk.Name != "test_fk" { - t.Errorf("Name = %v", fk.Name) - } - if len(fk.Columns) != 1 { - t.Errorf("len(Columns) = %v", len(fk.Columns)) - } - if fk.ReferencedTable != "users" { - t.Errorf("ReferencedTable = %v", fk.ReferencedTable) - } - if len(fk.ReferencedColumns) != 1 { - t.Errorf("len(ReferencedColumns) = %v", len(fk.ReferencedColumns)) - } -} - -func TestSchemaOptions(t *testing.T) { - opts := SchemaOptions{ - TablePattern: "user*", - IncludeSystem: true, - } - - if opts.TablePattern != "user*" { - t.Errorf("TablePattern = %v", opts.TablePattern) - } - if !opts.IncludeSystem { - t.Errorf("IncludeSystem = %v", opts.IncludeSystem) - } -} - -func TestDumpSchema_UnsupportedDriver(t *testing.T) { - _, xe := DumpSchema(nil, "nonexistent", nil, SchemaOptions{}) - if xe == nil { - t.Error("expected error for unsupported driver") - } - if xe.Code != "XSQL_DB_DRIVER_UNSUPPORTED" { - t.Errorf("error code = %v, want XSQL_DB_DRIVER_UNSUPPORTED", xe.Code) - } -} - -// Mock driver that doesn't implement SchemaDriver -type mockNonSchemaDriver struct{} - -func (d *mockNonSchemaDriver) Open(ctx interface{}, opts ConnOptions) (interface{}, error) { - return nil, nil -} - -func TestDumpSchema_DriverNotImplementSchema(t *testing.T) { - // Register a mock driver that doesn't implement SchemaDriver - // Note: This test would need to register/unregister which could affect other tests - // Skipping for now as the interface check is straightforward -} +package db + +import ( + "testing" +) + +func TestSchemaInfo_ToSchemaData(t *testing.T) { + tests := []struct { + name string + schema *SchemaInfo + wantDB string + wantLen int + wantOK bool + }{ + { + name: "nil schema", + schema: nil, + wantDB: "", + wantLen: 0, + wantOK: false, + }, + { + name: "empty tables", + schema: &SchemaInfo{Database: "testdb", Tables: []Table{}}, + wantDB: "", + wantLen: 0, + wantOK: false, + }, + { + name: "single table no columns", + schema: &SchemaInfo{ + Database: "testdb", + Tables: []Table{ + {Schema: "public", Name: "users"}, + }, + }, + wantDB: "testdb", + wantLen: 1, + wantOK: true, + }, + { + name: "single table with columns", + schema: &SchemaInfo{ + Database: "testdb", + Tables: []Table{ + { + Schema: "public", + Name: "users", + Comment: "用户表", + Columns: []Column{ + {Name: "id", Type: "bigint", Nullable: false, PrimaryKey: true}, + {Name: "email", Type: "varchar(255)", Nullable: false, Comment: "邮箱"}, + }, + }, + }, + }, + wantDB: "testdb", + wantLen: 1, + wantOK: true, + }, + { + name: "multiple tables", + schema: &SchemaInfo{ + Database: "testdb", + Tables: []Table{ + { + Schema: "public", + Name: "users", + Columns: []Column{ + {Name: "id", Type: "bigint", PrimaryKey: true}, + }, + }, + { + Schema: "public", + Name: "orders", + Columns: []Column{ + {Name: "id", Type: "bigint", PrimaryKey: true}, + {Name: "user_id", Type: "bigint"}, + }, + ForeignKeys: []ForeignKey{ + {Name: "fk_user", Columns: []string{"user_id"}, ReferencedTable: "users", ReferencedColumns: []string{"id"}}, + }, + }, + }, + }, + wantDB: "testdb", + wantLen: 2, + wantOK: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, tables, ok := tt.schema.ToSchemaData() + if ok != tt.wantOK { + t.Errorf("ToSchemaData() ok = %v, want %v", ok, tt.wantOK) + } + if db != tt.wantDB { + t.Errorf("ToSchemaData() db = %v, want %v", db, tt.wantDB) + } + if len(tables) != tt.wantLen { + t.Errorf("ToSchemaData() len(tables) = %v, want %v", len(tables), tt.wantLen) + } + }) + } +} + +func TestSchemaInfo_ToSchemaData_ColumnData(t *testing.T) { + schema := &SchemaInfo{ + Database: "testdb", + Tables: []Table{ + { + Schema: "public", + Name: "users", + Comment: "用户表", + Columns: []Column{ + {Name: "id", Type: "bigint", Nullable: false, Default: "nextval('users_id_seq')", Comment: "主键", PrimaryKey: true}, + {Name: "email", Type: "varchar(255)", Nullable: false, Comment: "邮箱"}, + {Name: "created_at", Type: "timestamp", Nullable: true, Default: "now()"}, + }, + }, + }, + } + + db, tables, ok := schema.ToSchemaData() + if !ok { + t.Fatal("expected ok=true") + } + if db != "testdb" { + t.Errorf("db = %v, want testdb", db) + } + if len(tables) != 1 { + t.Fatalf("len(tables) = %v, want 1", len(tables)) + } + + table := tables[0] + if table.Schema != "public" { + t.Errorf("table.Schema = %v, want public", table.Schema) + } + if table.Name != "users" { + t.Errorf("table.Name = %v, want users", table.Name) + } + if table.Comment != "用户表" { + t.Errorf("table.Comment = %v, want 用户表", table.Comment) + } + if len(table.Columns) != 3 { + t.Fatalf("len(table.Columns) = %v, want 3", len(table.Columns)) + } + + // 验证第一列 + col := table.Columns[0] + if col.Name != "id" { + t.Errorf("col.Name = %v, want id", col.Name) + } + if col.Type != "bigint" { + t.Errorf("col.Type = %v, want bigint", col.Type) + } + if col.Nullable { + t.Errorf("col.Nullable = %v, want false", col.Nullable) + } + if col.Default != "nextval('users_id_seq')" { + t.Errorf("col.Default = %v, want nextval('users_id_seq')", col.Default) + } + if col.Comment != "主键" { + t.Errorf("col.Comment = %v, want 主键", col.Comment) + } + if !col.PrimaryKey { + t.Errorf("col.PrimaryKey = %v, want true", col.PrimaryKey) + } +} + +func TestTable_Fields(t *testing.T) { + table := Table{ + Schema: "myschema", + Name: "mytable", + Comment: "test comment", + Columns: []Column{ + {Name: "col1", Type: "int"}, + }, + Indexes: []Index{ + {Name: "idx1", Columns: []string{"col1"}, Unique: true}, + }, + ForeignKeys: []ForeignKey{ + {Name: "fk1", Columns: []string{"col1"}, ReferencedTable: "other", ReferencedColumns: []string{"id"}}, + }, + } + + if table.Schema != "myschema" { + t.Errorf("Schema = %v", table.Schema) + } + if table.Name != "mytable" { + t.Errorf("Name = %v", table.Name) + } + if len(table.Columns) != 1 { + t.Errorf("len(Columns) = %v", len(table.Columns)) + } + if len(table.Indexes) != 1 { + t.Errorf("len(Indexes) = %v", len(table.Indexes)) + } + if len(table.ForeignKeys) != 1 { + t.Errorf("len(ForeignKeys) = %v", len(table.ForeignKeys)) + } +} + +func TestColumn_Fields(t *testing.T) { + col := Column{ + Name: "test_col", + Type: "varchar(100)", + Nullable: true, + Default: "'default'", + Comment: "test comment", + PrimaryKey: false, + } + + if col.Name != "test_col" { + t.Errorf("Name = %v", col.Name) + } + if col.Type != "varchar(100)" { + t.Errorf("Type = %v", col.Type) + } + if !col.Nullable { + t.Errorf("Nullable = %v", col.Nullable) + } + if col.Default != "'default'" { + t.Errorf("Default = %v", col.Default) + } + if col.Comment != "test comment" { + t.Errorf("Comment = %v", col.Comment) + } + if col.PrimaryKey { + t.Errorf("PrimaryKey = %v", col.PrimaryKey) + } +} + +func TestIndex_Fields(t *testing.T) { + idx := Index{ + Name: "test_idx", + Columns: []string{"col1", "col2"}, + Unique: true, + Primary: false, + } + + if idx.Name != "test_idx" { + t.Errorf("Name = %v", idx.Name) + } + if len(idx.Columns) != 2 { + t.Errorf("len(Columns) = %v", len(idx.Columns)) + } + if !idx.Unique { + t.Errorf("Unique = %v", idx.Unique) + } + if idx.Primary { + t.Errorf("Primary = %v", idx.Primary) + } +} + +func TestForeignKey_Fields(t *testing.T) { + fk := ForeignKey{ + Name: "test_fk", + Columns: []string{"user_id"}, + ReferencedTable: "users", + ReferencedColumns: []string{"id"}, + } + + if fk.Name != "test_fk" { + t.Errorf("Name = %v", fk.Name) + } + if len(fk.Columns) != 1 { + t.Errorf("len(Columns) = %v", len(fk.Columns)) + } + if fk.ReferencedTable != "users" { + t.Errorf("ReferencedTable = %v", fk.ReferencedTable) + } + if len(fk.ReferencedColumns) != 1 { + t.Errorf("len(ReferencedColumns) = %v", len(fk.ReferencedColumns)) + } +} + +func TestSchemaOptions(t *testing.T) { + opts := SchemaOptions{ + TablePattern: "user*", + IncludeSystem: true, + } + + if opts.TablePattern != "user*" { + t.Errorf("TablePattern = %v", opts.TablePattern) + } + if !opts.IncludeSystem { + t.Errorf("IncludeSystem = %v", opts.IncludeSystem) + } +} + +func TestDumpSchema_UnsupportedDriver(t *testing.T) { + _, xe := DumpSchema(nil, "nonexistent", nil, SchemaOptions{}) + if xe == nil { + t.Error("expected error for unsupported driver") + } + if xe.Code != "XSQL_DB_DRIVER_UNSUPPORTED" { + t.Errorf("error code = %v, want XSQL_DB_DRIVER_UNSUPPORTED", xe.Code) + } +} + +// Mock driver that doesn't implement SchemaDriver +type mockNonSchemaDriver struct{} + +func (d *mockNonSchemaDriver) Open(ctx interface{}, opts ConnOptions) (interface{}, error) { + return nil, nil +} + +func TestDumpSchema_DriverNotImplementSchema(t *testing.T) { + // Register a mock driver that doesn't implement SchemaDriver + // Note: This test would need to register/unregister which could affect other tests + // Skipping for now as the interface check is straightforward +} diff --git a/internal/secret/keyring_test.go b/internal/secret/keyring_test.go index 9c6563b..a068557 100644 --- a/internal/secret/keyring_test.go +++ b/internal/secret/keyring_test.go @@ -1,266 +1,266 @@ -package secret - -import ( - "fmt" - "runtime" - "strings" - "testing" - - "github.com/zalando/go-keyring" -) - -// nullByteKeyring 模拟 Windows cmdkey 返回带 null 字节的值 -type nullByteKeyring struct { - data map[string]map[string]string -} - -func newNullByteKeyring() *nullByteKeyring { - return &nullByteKeyring{data: make(map[string]map[string]string)} -} - -func (m *nullByteKeyring) set(service, account, value string) { - if m.data[service] == nil { - m.data[service] = make(map[string]string) - } - m.data[service][account] = value -} - -// setWithNullBytes 模拟 Windows UTF-16 问题:每个字符后插入 null 字节 -func (m *nullByteKeyring) setWithNullBytes(service, account, value string) { - var sb strings.Builder - for _, r := range value { - sb.WriteRune(r) - sb.WriteByte(0x00) - } - m.set(service, account, sb.String()) -} - -func (m *nullByteKeyring) Get(service, account string) (string, error) { - if svc, ok := m.data[service]; ok { - if v, ok := svc[account]; ok { - return v, nil - } - } - return "", fmt.Errorf("not found: %s/%s", service, account) -} - -func (m *nullByteKeyring) Set(service, account, value string) error { - m.set(service, account, value) - return nil -} - -func (m *nullByteKeyring) Delete(service, account string) error { - if svc, ok := m.data[service]; ok { - delete(svc, account) - } - return nil -} - -// ============================================================================= -// Windows null 字节处理测试 -// ============================================================================= - -func TestStripNullBytes(t *testing.T) { - tests := []struct { - name string - input string - want string - }{ - { - name: "no null bytes", - input: "password123", - want: "password123", - }, - { - name: "null bytes between chars", - input: "p\x00a\x00s\x00s\x00", - want: "pass", - }, - { - name: "full password with null bytes", - input: "m\x00y\x00P\x00a\x00s\x00s\x00w\x00o\x00r\x00d\x00", - want: "myPassword", - }, - { - name: "special chars with null bytes", - input: "p\x00@\x00s\x00s\x00!\x00#\x00", - want: "p@ss!#", - }, - { - name: "empty string", - input: "", - want: "", - }, - { - name: "only null bytes", - input: "\x00\x00\x00", - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := strings.ReplaceAll(tt.input, "\x00", "") - if got != tt.want { - t.Errorf("stripNullBytes(%q) = %q, want %q", tt.input, got, tt.want) - } - }) - } -} - -func TestNullByteKeyring_SimulatesWindowsBehavior(t *testing.T) { - kr := newNullByteKeyring() - kr.setWithNullBytes("xsql", "prod/password", "secret123") - - val, err := kr.Get("xsql", "prod/password") - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - // 原始值应该包含 null 字节 - if !strings.Contains(val, "\x00") { - t.Error("Expected value to contain null bytes") - } - - // 清理后应该等于原始密码 - cleaned := strings.ReplaceAll(val, "\x00", "") - if cleaned != "secret123" { - t.Errorf("Cleaned value = %q, want %q", cleaned, "secret123") - } -} - -// ============================================================================= -// KeyringAPI 接口合规性测试 -// ============================================================================= - -func TestKeyringAPI_Interface(t *testing.T) { - // 确保 mockKeyring 实现 KeyringAPI 接口 - var _ KeyringAPI = (*mockKeyring)(nil) - var _ KeyringAPI = (*nullByteKeyring)(nil) -} - -func TestKeyringAPI_ErrorCases(t *testing.T) { - kr := newMockKeyring() - - // 空 service - _, err := kr.Get("", "account") - if err == nil { - t.Error("Get with empty service should fail") - } - - // 空 account - _, err = kr.Get("service", "") - if err == nil { - t.Error("Get with empty account should fail") - } - - // 不存在的 service - _, err = kr.Get("nonexistent", "account") - if err == nil { - t.Error("Get with nonexistent service should fail") - } -} - -// ============================================================================= -// Resolve 与 Keyring 集成测试 -// ============================================================================= - -func TestResolve_WithNullByteKeyring(t *testing.T) { - kr := newNullByteKeyring() - // 模拟 Windows 返回带 null 字节的密码 - kr.set("xsql", "prod/password", "s\x00e\x00c\x00r\x00e\x00t\x00") - - // 注意:Resolve 直接使用 keyring 返回值,不做清理 - // 清理逻辑在 keyring_windows.go 的 osKeyring.Get 中 - val, xe := Resolve("keyring:prod/password", Options{Keyring: kr}) - if xe != nil { - t.Fatalf("Resolve failed: %v", xe) - } - - // 由于使用 mockKeyring,不会自动清理 null 字节 - // 这个测试验证 Resolve 正确传递值 - if !strings.Contains(val, "\x00") { - t.Log("Value does not contain null bytes (expected if using cleaned keyring)") - } -} - -func TestResolve_SpecialCharacters(t *testing.T) { - kr := newMockKeyring() - specialPasswords := []string{ - "p@ssw0rd!", - "pass#123$", - "密码123", - "пароль", - "パスワード", - "pass word", - "pass\ttab", - } - - for i, pw := range specialPasswords { - account := fmt.Sprintf("test%d", i) - kr.set("xsql", account, pw) - - val, xe := Resolve(fmt.Sprintf("keyring:%s", account), Options{Keyring: kr}) - if xe != nil { - t.Errorf("Resolve special password %q failed: %v", pw, xe) - continue - } - if val != pw { - t.Errorf("Resolve special password: got %q, want %q", val, pw) - } - } -} - -func TestResolve_EmptyPassword(t *testing.T) { - kr := newMockKeyring() - kr.set("xsql", "empty", "") - - val, xe := Resolve("keyring:empty", Options{Keyring: kr}) - if xe != nil { - t.Fatalf("Resolve failed: %v", xe) - } - if val != "" { - t.Errorf("Expected empty password, got %q", val) - } -} - -func TestResolve_LongPassword(t *testing.T) { - kr := newMockKeyring() - longPass := strings.Repeat("a", 1000) - kr.set("xsql", "long", longPass) - - val, xe := Resolve("keyring:long", Options{Keyring: kr}) - if xe != nil { - t.Fatalf("Resolve failed: %v", xe) - } - if val != longPass { - t.Errorf("Long password mismatch: got len=%d, want len=%d", len(val), len(longPass)) - } -} - -func TestDefaultKeyring_NullByteBehavior(t *testing.T) { - keyring.MockInit() - kr := defaultKeyring() - service := "xsql-test" - account := "null-byte" - raw := "s\x00e\x00c\x00r\x00e\x00t\x00" - if err := kr.Set(service, account, raw); err != nil { - t.Fatalf("Set failed: %v", err) - } - got, err := kr.Get(service, account) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if runtime.GOOS == "windows" { - if strings.Contains(got, "\x00") { - t.Fatalf("expected null bytes to be stripped, got %q", got) - } - if got != "secret" { - t.Fatalf("expected cleaned value, got %q", got) - } - return - } - if got != raw { - t.Fatalf("expected raw value on non-windows, got %q", got) - } -} +package secret + +import ( + "fmt" + "runtime" + "strings" + "testing" + + "github.com/zalando/go-keyring" +) + +// nullByteKeyring 模拟 Windows cmdkey 返回带 null 字节的值 +type nullByteKeyring struct { + data map[string]map[string]string +} + +func newNullByteKeyring() *nullByteKeyring { + return &nullByteKeyring{data: make(map[string]map[string]string)} +} + +func (m *nullByteKeyring) set(service, account, value string) { + if m.data[service] == nil { + m.data[service] = make(map[string]string) + } + m.data[service][account] = value +} + +// setWithNullBytes 模拟 Windows UTF-16 问题:每个字符后插入 null 字节 +func (m *nullByteKeyring) setWithNullBytes(service, account, value string) { + var sb strings.Builder + for _, r := range value { + sb.WriteRune(r) + sb.WriteByte(0x00) + } + m.set(service, account, sb.String()) +} + +func (m *nullByteKeyring) Get(service, account string) (string, error) { + if svc, ok := m.data[service]; ok { + if v, ok := svc[account]; ok { + return v, nil + } + } + return "", fmt.Errorf("not found: %s/%s", service, account) +} + +func (m *nullByteKeyring) Set(service, account, value string) error { + m.set(service, account, value) + return nil +} + +func (m *nullByteKeyring) Delete(service, account string) error { + if svc, ok := m.data[service]; ok { + delete(svc, account) + } + return nil +} + +// ============================================================================= +// Windows null 字节处理测试 +// ============================================================================= + +func TestStripNullBytes(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "no null bytes", + input: "password123", + want: "password123", + }, + { + name: "null bytes between chars", + input: "p\x00a\x00s\x00s\x00", + want: "pass", + }, + { + name: "full password with null bytes", + input: "m\x00y\x00P\x00a\x00s\x00s\x00w\x00o\x00r\x00d\x00", + want: "myPassword", + }, + { + name: "special chars with null bytes", + input: "p\x00@\x00s\x00s\x00!\x00#\x00", + want: "p@ss!#", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "only null bytes", + input: "\x00\x00\x00", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := strings.ReplaceAll(tt.input, "\x00", "") + if got != tt.want { + t.Errorf("stripNullBytes(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestNullByteKeyring_SimulatesWindowsBehavior(t *testing.T) { + kr := newNullByteKeyring() + kr.setWithNullBytes("xsql", "prod/password", "secret123") + + val, err := kr.Get("xsql", "prod/password") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + // 原始值应该包含 null 字节 + if !strings.Contains(val, "\x00") { + t.Error("Expected value to contain null bytes") + } + + // 清理后应该等于原始密码 + cleaned := strings.ReplaceAll(val, "\x00", "") + if cleaned != "secret123" { + t.Errorf("Cleaned value = %q, want %q", cleaned, "secret123") + } +} + +// ============================================================================= +// KeyringAPI 接口合规性测试 +// ============================================================================= + +func TestKeyringAPI_Interface(t *testing.T) { + // 确保 mockKeyring 实现 KeyringAPI 接口 + var _ KeyringAPI = (*mockKeyring)(nil) + var _ KeyringAPI = (*nullByteKeyring)(nil) +} + +func TestKeyringAPI_ErrorCases(t *testing.T) { + kr := newMockKeyring() + + // 空 service + _, err := kr.Get("", "account") + if err == nil { + t.Error("Get with empty service should fail") + } + + // 空 account + _, err = kr.Get("service", "") + if err == nil { + t.Error("Get with empty account should fail") + } + + // 不存在的 service + _, err = kr.Get("nonexistent", "account") + if err == nil { + t.Error("Get with nonexistent service should fail") + } +} + +// ============================================================================= +// Resolve 与 Keyring 集成测试 +// ============================================================================= + +func TestResolve_WithNullByteKeyring(t *testing.T) { + kr := newNullByteKeyring() + // 模拟 Windows 返回带 null 字节的密码 + kr.set("xsql", "prod/password", "s\x00e\x00c\x00r\x00e\x00t\x00") + + // 注意:Resolve 直接使用 keyring 返回值,不做清理 + // 清理逻辑在 keyring_windows.go 的 osKeyring.Get 中 + val, xe := Resolve("keyring:prod/password", Options{Keyring: kr}) + if xe != nil { + t.Fatalf("Resolve failed: %v", xe) + } + + // 由于使用 mockKeyring,不会自动清理 null 字节 + // 这个测试验证 Resolve 正确传递值 + if !strings.Contains(val, "\x00") { + t.Log("Value does not contain null bytes (expected if using cleaned keyring)") + } +} + +func TestResolve_SpecialCharacters(t *testing.T) { + kr := newMockKeyring() + specialPasswords := []string{ + "p@ssw0rd!", + "pass#123$", + "密码123", + "пароль", + "パスワード", + "pass word", + "pass\ttab", + } + + for i, pw := range specialPasswords { + account := fmt.Sprintf("test%d", i) + kr.set("xsql", account, pw) + + val, xe := Resolve(fmt.Sprintf("keyring:%s", account), Options{Keyring: kr}) + if xe != nil { + t.Errorf("Resolve special password %q failed: %v", pw, xe) + continue + } + if val != pw { + t.Errorf("Resolve special password: got %q, want %q", val, pw) + } + } +} + +func TestResolve_EmptyPassword(t *testing.T) { + kr := newMockKeyring() + kr.set("xsql", "empty", "") + + val, xe := Resolve("keyring:empty", Options{Keyring: kr}) + if xe != nil { + t.Fatalf("Resolve failed: %v", xe) + } + if val != "" { + t.Errorf("Expected empty password, got %q", val) + } +} + +func TestResolve_LongPassword(t *testing.T) { + kr := newMockKeyring() + longPass := strings.Repeat("a", 1000) + kr.set("xsql", "long", longPass) + + val, xe := Resolve("keyring:long", Options{Keyring: kr}) + if xe != nil { + t.Fatalf("Resolve failed: %v", xe) + } + if val != longPass { + t.Errorf("Long password mismatch: got len=%d, want len=%d", len(val), len(longPass)) + } +} + +func TestDefaultKeyring_NullByteBehavior(t *testing.T) { + keyring.MockInit() + kr := defaultKeyring() + service := "xsql-test" + account := "null-byte" + raw := "s\x00e\x00c\x00r\x00e\x00t\x00" + if err := kr.Set(service, account, raw); err != nil { + t.Fatalf("Set failed: %v", err) + } + got, err := kr.Get(service, account) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if runtime.GOOS == "windows" { + if strings.Contains(got, "\x00") { + t.Fatalf("expected null bytes to be stripped, got %q", got) + } + if got != "secret" { + t.Fatalf("expected cleaned value, got %q", got) + } + return + } + if got != raw { + t.Fatalf("expected raw value on non-windows, got %q", got) + } +} diff --git a/tests/integration/db_test.go b/tests/integration/db_test.go index 46d8a9c..634f3fa 100644 --- a/tests/integration/db_test.go +++ b/tests/integration/db_test.go @@ -1,294 +1,294 @@ -//go:build integration - -// Package integration contains integration tests for xsql. -// Run with: go test -tags=integration ./tests/integration/... -// -// These tests require actual database connections: -// - MySQL: XSQL_TEST_MYSQL_DSN -// - PostgreSQL: XSQL_TEST_PG_DSN -package integration - -import ( - "context" - "database/sql" - "os" - "testing" - "time" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/jackc/pgx/v5/stdlib" - - "github.com/zx06/xsql/internal/db" - _ "github.com/zx06/xsql/internal/db/mysql" - _ "github.com/zx06/xsql/internal/db/pg" -) - -func TestMySQLConnection(t *testing.T) { - dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") - if dsn == "" { - t.Skip("XSQL_TEST_MYSQL_DSN not set") - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("cannot open mysql: %v", err) - } - defer conn.Close() - - if err := conn.PingContext(ctx); err != nil { - t.Fatalf("mysql ping failed: %v", err) - } - - var result int - err = conn.QueryRowContext(ctx, "SELECT 1").Scan(&result) - if err != nil { - t.Fatalf("SELECT 1 failed: %v", err) - } - if result != 1 { - t.Errorf("expected 1, got %d", result) - } -} - -func TestPostgreSQLConnection(t *testing.T) { - dsn := os.Getenv("XSQL_TEST_PG_DSN") - if dsn == "" { - t.Skip("XSQL_TEST_PG_DSN not set") - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, err := sql.Open("pgx", dsn) - if err != nil { - t.Fatalf("cannot open pg: %v", err) - } - defer conn.Close() - - if err := conn.PingContext(ctx); err != nil { - t.Fatalf("pg ping failed: %v", err) - } - - var result int - err = conn.QueryRowContext(ctx, "SELECT 1").Scan(&result) - if err != nil { - t.Fatalf("SELECT 1 failed: %v", err) - } - if result != 1 { - t.Errorf("expected 1, got %d", result) - } -} - -// TestMySQLDriver_Query tests the full query path through xsql's MySQL driver -func TestMySQLDriver_Query(t *testing.T) { - dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") - if dsn == "" { - t.Skip("XSQL_TEST_MYSQL_DSN not set") - } - - drv, ok := db.Get("mysql") - if !ok { - t.Fatal("mysql driver not registered") - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) - if xe != nil { - t.Fatalf("failed to open: %v", xe) - } - defer conn.Close() - - // Test basic query - result, xe := db.Query(ctx, conn, "SELECT 1 as num, 'hello' as msg", db.QueryOptions{DBType: "mysql"}) - if xe != nil { - t.Fatalf("query failed: %v", xe) - } - - if len(result.Columns) != 2 { - t.Errorf("expected 2 columns, got %d", len(result.Columns)) - } - if len(result.Rows) != 1 { - t.Errorf("expected 1 row, got %d", len(result.Rows)) - } -} - -// TestMySQLDriver_ReadOnlyEnforcement tests that write queries are blocked -func TestMySQLDriver_ReadOnlyEnforcement(t *testing.T) { - dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") - if dsn == "" { - t.Skip("XSQL_TEST_MYSQL_DSN not set") - } - - drv, ok := db.Get("mysql") - if !ok { - t.Fatal("mysql driver not registered") - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) - if xe != nil { - t.Fatalf("failed to open: %v", xe) - } - defer conn.Close() - - // Write query should be blocked - _, xe = db.Query(ctx, conn, "INSERT INTO test VALUES (1)", db.QueryOptions{DBType: "mysql"}) - if xe == nil { - t.Fatal("expected error for INSERT in read-only mode") - } - if xe.Code != "XSQL_RO_BLOCKED" { - t.Errorf("expected XSQL_RO_BLOCKED, got %s", xe.Code) - } -} - -// TestPgDriver_Query tests the full query path through xsql's PG driver -func TestPgDriver_Query(t *testing.T) { - dsn := os.Getenv("XSQL_TEST_PG_DSN") - if dsn == "" { - t.Skip("XSQL_TEST_PG_DSN not set") - } - - drv, ok := db.Get("pg") - if !ok { - t.Fatal("pg driver not registered") - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) - if xe != nil { - t.Fatalf("failed to open: %v", xe) - } - defer conn.Close() - - // Test basic query - result, xe := db.Query(ctx, conn, "SELECT 1 as num, 'hello' as msg", db.QueryOptions{DBType: "pg"}) - if xe != nil { - t.Fatalf("query failed: %v", xe) - } - - if len(result.Columns) != 2 { - t.Errorf("expected 2 columns, got %d", len(result.Columns)) - } - if len(result.Rows) != 1 { - t.Errorf("expected 1 row, got %d", len(result.Rows)) - } -} - -// TestPgDriver_ReadOnlyEnforcement tests that write queries are blocked -func TestPgDriver_ReadOnlyEnforcement(t *testing.T) { - dsn := os.Getenv("XSQL_TEST_PG_DSN") - if dsn == "" { - t.Skip("XSQL_TEST_PG_DSN not set") - } - - drv, ok := db.Get("pg") - if !ok { - t.Fatal("pg driver not registered") - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) - if xe != nil { - t.Fatalf("failed to open: %v", xe) - } - defer conn.Close() - - // Write query should be blocked - _, xe = db.Query(ctx, conn, "INSERT INTO test VALUES (1)", db.QueryOptions{DBType: "pg"}) - if xe == nil { - t.Fatal("expected error for INSERT in read-only mode") - } - if xe.Code != "XSQL_RO_BLOCKED" { - t.Errorf("expected XSQL_RO_BLOCKED, got %s", xe.Code) - } -} - -// TestMySQLDriver_ComplexQuery tests more complex queries -func TestMySQLDriver_ComplexQuery(t *testing.T) { - dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") - if dsn == "" { - t.Skip("XSQL_TEST_MYSQL_DSN not set") - } - - drv, ok := db.Get("mysql") - if !ok { - t.Fatal("mysql driver not registered") - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) - if xe != nil { - t.Fatalf("failed to open: %v", xe) - } - defer conn.Close() - - // Test SHOW statement - result, xe := db.Query(ctx, conn, "SHOW DATABASES", db.QueryOptions{DBType: "pg"}) - if xe != nil { - t.Fatalf("SHOW DATABASES failed: %v", xe) - } - if len(result.Rows) == 0 { - t.Error("expected at least one database") - } - - // Test EXPLAIN - result, xe = db.Query(ctx, conn, "EXPLAIN SELECT 1", db.QueryOptions{DBType: "pg"}) - if xe != nil { - t.Fatalf("EXPLAIN failed: %v", xe) - } - if len(result.Columns) == 0 { - t.Error("EXPLAIN should return columns") - } -} - -// TestPgDriver_ComplexQuery tests more complex queries -func TestPgDriver_ComplexQuery(t *testing.T) { - dsn := os.Getenv("XSQL_TEST_PG_DSN") - if dsn == "" { - t.Skip("XSQL_TEST_PG_DSN not set") - } - - drv, ok := db.Get("pg") - if !ok { - t.Fatal("pg driver not registered") - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) - if xe != nil { - t.Fatalf("failed to open: %v", xe) - } - defer conn.Close() - - // Test system catalog query - result, xe := db.Query(ctx, conn, "SELECT datname FROM pg_database LIMIT 5", db.QueryOptions{DBType: "pg"}) - if xe != nil { - t.Fatalf("pg_database query failed: %v", xe) - } - if len(result.Rows) == 0 { - t.Error("expected at least one database") - } - - // Test EXPLAIN - result, xe = db.Query(ctx, conn, "EXPLAIN SELECT 1", db.QueryOptions{DBType: "pg"}) - if xe != nil { - t.Fatalf("EXPLAIN failed: %v", xe) - } - if len(result.Rows) == 0 { - t.Error("EXPLAIN should return rows") - } -} +//go:build integration + +// Package integration contains integration tests for xsql. +// Run with: go test -tags=integration ./tests/integration/... +// +// These tests require actual database connections: +// - MySQL: XSQL_TEST_MYSQL_DSN +// - PostgreSQL: XSQL_TEST_PG_DSN +package integration + +import ( + "context" + "database/sql" + "os" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/jackc/pgx/v5/stdlib" + + "github.com/zx06/xsql/internal/db" + _ "github.com/zx06/xsql/internal/db/mysql" + _ "github.com/zx06/xsql/internal/db/pg" +) + +func TestMySQLConnection(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_MYSQL_DSN not set") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, err := sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("cannot open mysql: %v", err) + } + defer conn.Close() + + if err := conn.PingContext(ctx); err != nil { + t.Fatalf("mysql ping failed: %v", err) + } + + var result int + err = conn.QueryRowContext(ctx, "SELECT 1").Scan(&result) + if err != nil { + t.Fatalf("SELECT 1 failed: %v", err) + } + if result != 1 { + t.Errorf("expected 1, got %d", result) + } +} + +func TestPostgreSQLConnection(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_PG_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_PG_DSN not set") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, err := sql.Open("pgx", dsn) + if err != nil { + t.Fatalf("cannot open pg: %v", err) + } + defer conn.Close() + + if err := conn.PingContext(ctx); err != nil { + t.Fatalf("pg ping failed: %v", err) + } + + var result int + err = conn.QueryRowContext(ctx, "SELECT 1").Scan(&result) + if err != nil { + t.Fatalf("SELECT 1 failed: %v", err) + } + if result != 1 { + t.Errorf("expected 1, got %d", result) + } +} + +// TestMySQLDriver_Query tests the full query path through xsql's MySQL driver +func TestMySQLDriver_Query(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_MYSQL_DSN not set") + } + + drv, ok := db.Get("mysql") + if !ok { + t.Fatal("mysql driver not registered") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) + if xe != nil { + t.Fatalf("failed to open: %v", xe) + } + defer conn.Close() + + // Test basic query + result, xe := db.Query(ctx, conn, "SELECT 1 as num, 'hello' as msg", db.QueryOptions{DBType: "mysql"}) + if xe != nil { + t.Fatalf("query failed: %v", xe) + } + + if len(result.Columns) != 2 { + t.Errorf("expected 2 columns, got %d", len(result.Columns)) + } + if len(result.Rows) != 1 { + t.Errorf("expected 1 row, got %d", len(result.Rows)) + } +} + +// TestMySQLDriver_ReadOnlyEnforcement tests that write queries are blocked +func TestMySQLDriver_ReadOnlyEnforcement(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_MYSQL_DSN not set") + } + + drv, ok := db.Get("mysql") + if !ok { + t.Fatal("mysql driver not registered") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) + if xe != nil { + t.Fatalf("failed to open: %v", xe) + } + defer conn.Close() + + // Write query should be blocked + _, xe = db.Query(ctx, conn, "INSERT INTO test VALUES (1)", db.QueryOptions{DBType: "mysql"}) + if xe == nil { + t.Fatal("expected error for INSERT in read-only mode") + } + if xe.Code != "XSQL_RO_BLOCKED" { + t.Errorf("expected XSQL_RO_BLOCKED, got %s", xe.Code) + } +} + +// TestPgDriver_Query tests the full query path through xsql's PG driver +func TestPgDriver_Query(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_PG_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_PG_DSN not set") + } + + drv, ok := db.Get("pg") + if !ok { + t.Fatal("pg driver not registered") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) + if xe != nil { + t.Fatalf("failed to open: %v", xe) + } + defer conn.Close() + + // Test basic query + result, xe := db.Query(ctx, conn, "SELECT 1 as num, 'hello' as msg", db.QueryOptions{DBType: "pg"}) + if xe != nil { + t.Fatalf("query failed: %v", xe) + } + + if len(result.Columns) != 2 { + t.Errorf("expected 2 columns, got %d", len(result.Columns)) + } + if len(result.Rows) != 1 { + t.Errorf("expected 1 row, got %d", len(result.Rows)) + } +} + +// TestPgDriver_ReadOnlyEnforcement tests that write queries are blocked +func TestPgDriver_ReadOnlyEnforcement(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_PG_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_PG_DSN not set") + } + + drv, ok := db.Get("pg") + if !ok { + t.Fatal("pg driver not registered") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) + if xe != nil { + t.Fatalf("failed to open: %v", xe) + } + defer conn.Close() + + // Write query should be blocked + _, xe = db.Query(ctx, conn, "INSERT INTO test VALUES (1)", db.QueryOptions{DBType: "pg"}) + if xe == nil { + t.Fatal("expected error for INSERT in read-only mode") + } + if xe.Code != "XSQL_RO_BLOCKED" { + t.Errorf("expected XSQL_RO_BLOCKED, got %s", xe.Code) + } +} + +// TestMySQLDriver_ComplexQuery tests more complex queries +func TestMySQLDriver_ComplexQuery(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_MYSQL_DSN not set") + } + + drv, ok := db.Get("mysql") + if !ok { + t.Fatal("mysql driver not registered") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) + if xe != nil { + t.Fatalf("failed to open: %v", xe) + } + defer conn.Close() + + // Test SHOW statement + result, xe := db.Query(ctx, conn, "SHOW DATABASES", db.QueryOptions{DBType: "pg"}) + if xe != nil { + t.Fatalf("SHOW DATABASES failed: %v", xe) + } + if len(result.Rows) == 0 { + t.Error("expected at least one database") + } + + // Test EXPLAIN + result, xe = db.Query(ctx, conn, "EXPLAIN SELECT 1", db.QueryOptions{DBType: "pg"}) + if xe != nil { + t.Fatalf("EXPLAIN failed: %v", xe) + } + if len(result.Columns) == 0 { + t.Error("EXPLAIN should return columns") + } +} + +// TestPgDriver_ComplexQuery tests more complex queries +func TestPgDriver_ComplexQuery(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_PG_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_PG_DSN not set") + } + + drv, ok := db.Get("pg") + if !ok { + t.Fatal("pg driver not registered") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) + if xe != nil { + t.Fatalf("failed to open: %v", xe) + } + defer conn.Close() + + // Test system catalog query + result, xe := db.Query(ctx, conn, "SELECT datname FROM pg_database LIMIT 5", db.QueryOptions{DBType: "pg"}) + if xe != nil { + t.Fatalf("pg_database query failed: %v", xe) + } + if len(result.Rows) == 0 { + t.Error("expected at least one database") + } + + // Test EXPLAIN + result, xe = db.Query(ctx, conn, "EXPLAIN SELECT 1", db.QueryOptions{DBType: "pg"}) + if xe != nil { + t.Fatalf("EXPLAIN failed: %v", xe) + } + if len(result.Rows) == 0 { + t.Error("EXPLAIN should return rows") + } +} diff --git a/tests/integration/schema_dump_test.go b/tests/integration/schema_dump_test.go index feda791..b2d9242 100644 --- a/tests/integration/schema_dump_test.go +++ b/tests/integration/schema_dump_test.go @@ -1,424 +1,424 @@ -//go:build integration - -package integration - -import ( - "context" - "fmt" - "os" - "strings" - "testing" - "time" - - "github.com/zx06/xsql/internal/db" - _ "github.com/zx06/xsql/internal/db/mysql" - _ "github.com/zx06/xsql/internal/db/pg" -) - -func TestSchemaDump_MySQL_RealDB(t *testing.T) { - dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") - if dsn == "" { - t.Skip("XSQL_TEST_MYSQL_DSN not set") - } - - drv, ok := db.Get("mysql") - if !ok { - t.Fatal("mysql driver not registered") - } - - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - - conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) - if xe != nil { - t.Fatalf("failed to open mysql: %v", xe) - } - defer conn.Close() - - suffix := time.Now().UnixNano() - prefix := fmt.Sprintf("xsql_schema_%d", suffix) - usersTable := prefix + "_users" - ordersTable := prefix + "_orders" - - // 清理旧表 - _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", ordersTable)) - _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", usersTable)) - - // 创建表结构(包含注释与默认值) - _, err := conn.ExecContext(ctx, fmt.Sprintf(` - CREATE TABLE %s ( - id BIGINT PRIMARY KEY COMMENT '主键', - email VARCHAR(255) NOT NULL, - tenant_id BIGINT NOT NULL, - status VARCHAR(20) NOT NULL DEFAULT 'active' COMMENT '状态', - created_at DATETIME NULL DEFAULT CURRENT_TIMESTAMP, - INDEX idx_email (email), - UNIQUE KEY uq_tenant_id (tenant_id, id), - INDEX idx_tenant_email (tenant_id, email) - ) ENGINE=InnoDB COMMENT='用户表' - `, usersTable)) - if err != nil { - t.Fatalf("create users table failed: %v", err) - } - - _, err = conn.ExecContext(ctx, fmt.Sprintf(` - CREATE TABLE %s ( - id BIGINT PRIMARY KEY, - tenant_id BIGINT NOT NULL, - user_id BIGINT NOT NULL, - amount DECIMAL(10,2) NOT NULL, - INDEX idx_tenant_user (tenant_id, user_id), - CONSTRAINT fk_%s_user FOREIGN KEY (tenant_id, user_id) REFERENCES %s(tenant_id, id) - ) ENGINE=InnoDB - `, ordersTable, ordersTable, usersTable)) - if err != nil { - t.Fatalf("create orders table failed: %v", err) - } - - t.Cleanup(func() { - _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", ordersTable)) - _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", usersTable)) - }) - - info, xe := db.DumpSchema(ctx, "mysql", conn, db.SchemaOptions{ - TablePattern: prefix + "*", - }) - if xe != nil { - t.Fatalf("DumpSchema error: %v", xe) - } - if info.Database == "" { - t.Fatalf("database name is empty") - } - - infoNoFilter, xe := db.DumpSchema(ctx, "mysql", conn, db.SchemaOptions{}) - if xe != nil { - t.Fatalf("DumpSchema no-filter error: %v", xe) - } - if len(infoNoFilter.Tables) == 0 { - t.Fatalf("expected tables for no-filter dump") - } - - infoEmpty, xe := db.DumpSchema(ctx, "mysql", conn, db.SchemaOptions{ - TablePattern: "no_match_*", - }) - if xe != nil { - t.Fatalf("DumpSchema empty filter error: %v", xe) - } - if len(infoEmpty.Tables) != 0 { - t.Fatalf("expected empty tables for no_match_* filter") - } - - users := findTable(info.Tables, usersTable) - orders := findTable(info.Tables, ordersTable) - if users == nil || orders == nil { - t.Fatalf("missing tables in schema dump: users=%v orders=%v", users != nil, orders != nil) - } - - if users.Schema == "" { - t.Fatalf("users schema is empty") - } - if len(users.Columns) == 0 { - t.Fatalf("users columns should not be empty") - } - - if !hasColumn(users, "id", true) { - t.Fatalf("users table missing primary key column 'id'") - } - if !hasIndex(users, "PRIMARY") { - t.Fatalf("users table missing PRIMARY index") - } - if !hasIndex(users, "idx_email") { - t.Fatalf("users table missing idx_email index") - } - if !hasIndex(users, "uq_tenant_id") { - t.Fatalf("users table missing uq_tenant_id index") - } - if !hasIndex(users, "idx_tenant_email") { - t.Fatalf("users table missing idx_tenant_email index") - } - - if !hasColumnComment(users, "id", "主键") { - t.Fatalf("users table column 'id' missing comment") - } - if !hasColumnComment(users, "status", "状态") { - t.Fatalf("users table column 'status' missing comment") - } - if !hasColumnDefault(users, "status", "active") { - t.Fatalf("users table column 'status' missing default value") - } - - if users.Comment != "用户表" { - t.Fatalf("users table missing comment") - } - - if !hasIndex(orders, "idx_tenant_user") { - t.Fatalf("orders table missing idx_tenant_user index") - } - if len(orders.ForeignKeys) == 0 { - t.Fatalf("orders table should have foreign keys") - } - if !hasForeignKeyTo(orders, usersTable) { - t.Fatalf("orders table missing FK to %s", usersTable) - } - if !hasCompositeForeignKeyTo(orders, usersTable) { - t.Fatalf("orders table missing composite FK to %s", usersTable) - } -} - -func TestSchemaDump_Pg_RealDB(t *testing.T) { - dsn := os.Getenv("XSQL_TEST_PG_DSN") - if dsn == "" { - t.Skip("XSQL_TEST_PG_DSN not set") - } - - drv, ok := db.Get("pg") - if !ok { - t.Fatal("pg driver not registered") - } - - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - - conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) - if xe != nil { - t.Fatalf("failed to open pg: %v", xe) - } - defer conn.Close() - - suffix := time.Now().UnixNano() - schema := fmt.Sprintf("xsql_schema_%d", suffix) - usersTable := "users" - ordersTable := "orders" - prefix := "xsql_" - - // 清理旧 schema - _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schema)) - - // 创建 schema 与表 - _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA %s", schema)) - if err != nil { - t.Fatalf("create schema failed: %v", err) - } - - _, err = conn.ExecContext(ctx, fmt.Sprintf(` - CREATE TABLE %s.%s ( - id BIGSERIAL PRIMARY KEY, - tenant_id BIGINT NOT NULL, - email VARCHAR(255) NOT NULL, - status TEXT NOT NULL DEFAULT 'active', - created_at TIMESTAMPTZ NULL DEFAULT NOW(), - UNIQUE (tenant_id, id) - ) - `, schema, prefix+usersTable)) - if err != nil { - t.Fatalf("create users table failed: %v", err) - } - - _, err = conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON TABLE %s.%s IS '用户表'`, schema, prefix+usersTable)) - if err != nil { - t.Fatalf("comment table failed: %v", err) - } - _, err = conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s.id IS '主键'`, schema, prefix+usersTable)) - if err != nil { - t.Fatalf("comment column failed: %v", err) - } - _, err = conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s.status IS '状态'`, schema, prefix+usersTable)) - if err != nil { - t.Fatalf("comment column failed: %v", err) - } - - _, err = conn.ExecContext(ctx, fmt.Sprintf(` - CREATE INDEX idx_email ON %s.%s (email) - `, schema, prefix+usersTable)) - if err != nil { - t.Fatalf("create index failed: %v", err) - } - _, err = conn.ExecContext(ctx, fmt.Sprintf(` - CREATE INDEX idx_tenant_email ON %s.%s (tenant_id, email) - `, schema, prefix+usersTable)) - if err != nil { - t.Fatalf("create index failed: %v", err) - } - - _, err = conn.ExecContext(ctx, fmt.Sprintf(` - CREATE TABLE %s.%s ( - id BIGSERIAL PRIMARY KEY, - tenant_id BIGINT NOT NULL, - user_id BIGINT NOT NULL, - amount NUMERIC(10,2) NOT NULL, - CONSTRAINT fk_%s_user FOREIGN KEY (tenant_id, user_id) REFERENCES %s.%s(tenant_id, id) - ) - `, schema, prefix+ordersTable, prefix+ordersTable, schema, prefix+usersTable)) - if err != nil { - t.Fatalf("create orders table failed: %v", err) - } - _, err = conn.ExecContext(ctx, fmt.Sprintf(` - CREATE INDEX idx_tenant_user ON %s.%s (tenant_id, user_id) - `, schema, prefix+ordersTable)) - if err != nil { - t.Fatalf("create index failed: %v", err) - } - - t.Cleanup(func() { - _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schema)) - }) - - info, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{ - TablePattern: prefix + "*", - }) - if xe != nil { - t.Fatalf("DumpSchema error: %v", xe) - } - if info.Database == "" { - t.Fatalf("database name is empty") - } - - infoNoFilter, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{}) - if xe != nil { - t.Fatalf("DumpSchema no-filter error: %v", xe) - } - if len(infoNoFilter.Tables) == 0 { - t.Fatalf("expected tables for no-filter dump") - } - - infoWithSystem, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{ - TablePattern: prefix + "*", - IncludeSystem: true, - }) - if xe != nil { - t.Fatalf("DumpSchema include-system error: %v", xe) - } - if infoWithSystem.Database == "" { - t.Fatalf("database name is empty for include-system") - } - - infoEmpty, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{ - TablePattern: "no_match_*", - }) - if xe != nil { - t.Fatalf("DumpSchema empty filter error: %v", xe) - } - if len(infoEmpty.Tables) != 0 { - t.Fatalf("expected empty tables for no_match_* filter") - } - - users := findTableWithSchema(info.Tables, schema, prefix+usersTable) - orders := findTableWithSchema(info.Tables, schema, prefix+ordersTable) - if users == nil || orders == nil { - t.Fatalf("missing tables in schema dump: users=%v orders=%v", users != nil, orders != nil) - } - - if !hasColumn(users, "id", true) { - t.Fatalf("users table missing primary key column 'id'") - } - if len(users.Indexes) == 0 { - t.Fatalf("users table should have indexes") - } - if !hasIndex(users, "idx_email") { - t.Fatalf("users table missing idx_email index") - } - if !hasIndex(users, "idx_tenant_email") { - t.Fatalf("users table missing idx_tenant_email index") - } - - if !hasColumnDefault(users, "status", "active") { - t.Fatalf("users table column 'status' missing default value") - } - - if users.Comment != "用户表" { - t.Fatalf("users table missing comment") - } - if !hasColumnComment(users, "id", "主键") { - t.Fatalf("users table column 'id' missing comment") - } - if !hasColumnComment(users, "status", "状态") { - t.Fatalf("users table column 'status' missing comment") - } - - if !hasIndex(orders, "idx_tenant_user") { - t.Fatalf("orders table missing idx_tenant_user index") - } - if len(orders.ForeignKeys) == 0 { - t.Fatalf("orders table should have foreign keys") - } - if !hasForeignKeyTo(orders, prefix+usersTable) { - t.Fatalf("orders table missing FK to %s", prefix+usersTable) - } - if !hasCompositeForeignKeyTo(orders, prefix+usersTable) { - t.Fatalf("orders table missing composite FK to %s", prefix+usersTable) - } -} - -func findTable(tables []db.Table, name string) *db.Table { - for i := range tables { - if tables[i].Name == name { - return &tables[i] - } - } - return nil -} - -func findTableWithSchema(tables []db.Table, schema, name string) *db.Table { - for i := range tables { - if tables[i].Schema == schema && tables[i].Name == name { - return &tables[i] - } - } - return nil -} - -func hasColumn(table *db.Table, name string, primary bool) bool { - for _, c := range table.Columns { - if c.Name == name && c.PrimaryKey == primary { - return true - } - } - return false -} - -func hasIndex(table *db.Table, indexName string) bool { - for _, idx := range table.Indexes { - if idx.Name == indexName { - return true - } - } - return false -} - -func hasForeignKeyTo(table *db.Table, referencedTable string) bool { - for _, fk := range table.ForeignKeys { - if strings.EqualFold(fk.ReferencedTable, referencedTable) { - return true - } - } - return false -} - -func hasCompositeForeignKeyTo(table *db.Table, referencedTable string) bool { - for _, fk := range table.ForeignKeys { - if strings.EqualFold(fk.ReferencedTable, referencedTable) && - len(fk.Columns) >= 2 && - len(fk.ReferencedColumns) >= 2 { - return true - } - } - return false -} - -func hasColumnComment(table *db.Table, name, comment string) bool { - for _, c := range table.Columns { - if c.Name == name && c.Comment == comment { - return true - } - } - return false -} - -func hasColumnDefault(table *db.Table, name, want string) bool { - for _, c := range table.Columns { - if c.Name == name && strings.Contains(c.Default, want) { - return true - } - } - return false -} +//go:build integration + +package integration + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + "time" + + "github.com/zx06/xsql/internal/db" + _ "github.com/zx06/xsql/internal/db/mysql" + _ "github.com/zx06/xsql/internal/db/pg" +) + +func TestSchemaDump_MySQL_RealDB(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_MYSQL_DSN not set") + } + + drv, ok := db.Get("mysql") + if !ok { + t.Fatal("mysql driver not registered") + } + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) + if xe != nil { + t.Fatalf("failed to open mysql: %v", xe) + } + defer conn.Close() + + suffix := time.Now().UnixNano() + prefix := fmt.Sprintf("xsql_schema_%d", suffix) + usersTable := prefix + "_users" + ordersTable := prefix + "_orders" + + // 清理旧表 + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", ordersTable)) + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", usersTable)) + + // 创建表结构(包含注释与默认值) + _, err := conn.ExecContext(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id BIGINT PRIMARY KEY COMMENT '主键', + email VARCHAR(255) NOT NULL, + tenant_id BIGINT NOT NULL, + status VARCHAR(20) NOT NULL DEFAULT 'active' COMMENT '状态', + created_at DATETIME NULL DEFAULT CURRENT_TIMESTAMP, + INDEX idx_email (email), + UNIQUE KEY uq_tenant_id (tenant_id, id), + INDEX idx_tenant_email (tenant_id, email) + ) ENGINE=InnoDB COMMENT='用户表' + `, usersTable)) + if err != nil { + t.Fatalf("create users table failed: %v", err) + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id BIGINT PRIMARY KEY, + tenant_id BIGINT NOT NULL, + user_id BIGINT NOT NULL, + amount DECIMAL(10,2) NOT NULL, + INDEX idx_tenant_user (tenant_id, user_id), + CONSTRAINT fk_%s_user FOREIGN KEY (tenant_id, user_id) REFERENCES %s(tenant_id, id) + ) ENGINE=InnoDB + `, ordersTable, ordersTable, usersTable)) + if err != nil { + t.Fatalf("create orders table failed: %v", err) + } + + t.Cleanup(func() { + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", ordersTable)) + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", usersTable)) + }) + + info, xe := db.DumpSchema(ctx, "mysql", conn, db.SchemaOptions{ + TablePattern: prefix + "*", + }) + if xe != nil { + t.Fatalf("DumpSchema error: %v", xe) + } + if info.Database == "" { + t.Fatalf("database name is empty") + } + + infoNoFilter, xe := db.DumpSchema(ctx, "mysql", conn, db.SchemaOptions{}) + if xe != nil { + t.Fatalf("DumpSchema no-filter error: %v", xe) + } + if len(infoNoFilter.Tables) == 0 { + t.Fatalf("expected tables for no-filter dump") + } + + infoEmpty, xe := db.DumpSchema(ctx, "mysql", conn, db.SchemaOptions{ + TablePattern: "no_match_*", + }) + if xe != nil { + t.Fatalf("DumpSchema empty filter error: %v", xe) + } + if len(infoEmpty.Tables) != 0 { + t.Fatalf("expected empty tables for no_match_* filter") + } + + users := findTable(info.Tables, usersTable) + orders := findTable(info.Tables, ordersTable) + if users == nil || orders == nil { + t.Fatalf("missing tables in schema dump: users=%v orders=%v", users != nil, orders != nil) + } + + if users.Schema == "" { + t.Fatalf("users schema is empty") + } + if len(users.Columns) == 0 { + t.Fatalf("users columns should not be empty") + } + + if !hasColumn(users, "id", true) { + t.Fatalf("users table missing primary key column 'id'") + } + if !hasIndex(users, "PRIMARY") { + t.Fatalf("users table missing PRIMARY index") + } + if !hasIndex(users, "idx_email") { + t.Fatalf("users table missing idx_email index") + } + if !hasIndex(users, "uq_tenant_id") { + t.Fatalf("users table missing uq_tenant_id index") + } + if !hasIndex(users, "idx_tenant_email") { + t.Fatalf("users table missing idx_tenant_email index") + } + + if !hasColumnComment(users, "id", "主键") { + t.Fatalf("users table column 'id' missing comment") + } + if !hasColumnComment(users, "status", "状态") { + t.Fatalf("users table column 'status' missing comment") + } + if !hasColumnDefault(users, "status", "active") { + t.Fatalf("users table column 'status' missing default value") + } + + if users.Comment != "用户表" { + t.Fatalf("users table missing comment") + } + + if !hasIndex(orders, "idx_tenant_user") { + t.Fatalf("orders table missing idx_tenant_user index") + } + if len(orders.ForeignKeys) == 0 { + t.Fatalf("orders table should have foreign keys") + } + if !hasForeignKeyTo(orders, usersTable) { + t.Fatalf("orders table missing FK to %s", usersTable) + } + if !hasCompositeForeignKeyTo(orders, usersTable) { + t.Fatalf("orders table missing composite FK to %s", usersTable) + } +} + +func TestSchemaDump_Pg_RealDB(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_PG_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_PG_DSN not set") + } + + drv, ok := db.Get("pg") + if !ok { + t.Fatal("pg driver not registered") + } + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) + if xe != nil { + t.Fatalf("failed to open pg: %v", xe) + } + defer conn.Close() + + suffix := time.Now().UnixNano() + schema := fmt.Sprintf("xsql_schema_%d", suffix) + usersTable := "users" + ordersTable := "orders" + prefix := "xsql_" + + // 清理旧 schema + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schema)) + + // 创建 schema 与表 + _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA %s", schema)) + if err != nil { + t.Fatalf("create schema failed: %v", err) + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE TABLE %s.%s ( + id BIGSERIAL PRIMARY KEY, + tenant_id BIGINT NOT NULL, + email VARCHAR(255) NOT NULL, + status TEXT NOT NULL DEFAULT 'active', + created_at TIMESTAMPTZ NULL DEFAULT NOW(), + UNIQUE (tenant_id, id) + ) + `, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("create users table failed: %v", err) + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON TABLE %s.%s IS '用户表'`, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("comment table failed: %v", err) + } + _, err = conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s.id IS '主键'`, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("comment column failed: %v", err) + } + _, err = conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s.status IS '状态'`, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("comment column failed: %v", err) + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE INDEX idx_email ON %s.%s (email) + `, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("create index failed: %v", err) + } + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE INDEX idx_tenant_email ON %s.%s (tenant_id, email) + `, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("create index failed: %v", err) + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE TABLE %s.%s ( + id BIGSERIAL PRIMARY KEY, + tenant_id BIGINT NOT NULL, + user_id BIGINT NOT NULL, + amount NUMERIC(10,2) NOT NULL, + CONSTRAINT fk_%s_user FOREIGN KEY (tenant_id, user_id) REFERENCES %s.%s(tenant_id, id) + ) + `, schema, prefix+ordersTable, prefix+ordersTable, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("create orders table failed: %v", err) + } + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE INDEX idx_tenant_user ON %s.%s (tenant_id, user_id) + `, schema, prefix+ordersTable)) + if err != nil { + t.Fatalf("create index failed: %v", err) + } + + t.Cleanup(func() { + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schema)) + }) + + info, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{ + TablePattern: prefix + "*", + }) + if xe != nil { + t.Fatalf("DumpSchema error: %v", xe) + } + if info.Database == "" { + t.Fatalf("database name is empty") + } + + infoNoFilter, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{}) + if xe != nil { + t.Fatalf("DumpSchema no-filter error: %v", xe) + } + if len(infoNoFilter.Tables) == 0 { + t.Fatalf("expected tables for no-filter dump") + } + + infoWithSystem, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{ + TablePattern: prefix + "*", + IncludeSystem: true, + }) + if xe != nil { + t.Fatalf("DumpSchema include-system error: %v", xe) + } + if infoWithSystem.Database == "" { + t.Fatalf("database name is empty for include-system") + } + + infoEmpty, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{ + TablePattern: "no_match_*", + }) + if xe != nil { + t.Fatalf("DumpSchema empty filter error: %v", xe) + } + if len(infoEmpty.Tables) != 0 { + t.Fatalf("expected empty tables for no_match_* filter") + } + + users := findTableWithSchema(info.Tables, schema, prefix+usersTable) + orders := findTableWithSchema(info.Tables, schema, prefix+ordersTable) + if users == nil || orders == nil { + t.Fatalf("missing tables in schema dump: users=%v orders=%v", users != nil, orders != nil) + } + + if !hasColumn(users, "id", true) { + t.Fatalf("users table missing primary key column 'id'") + } + if len(users.Indexes) == 0 { + t.Fatalf("users table should have indexes") + } + if !hasIndex(users, "idx_email") { + t.Fatalf("users table missing idx_email index") + } + if !hasIndex(users, "idx_tenant_email") { + t.Fatalf("users table missing idx_tenant_email index") + } + + if !hasColumnDefault(users, "status", "active") { + t.Fatalf("users table column 'status' missing default value") + } + + if users.Comment != "用户表" { + t.Fatalf("users table missing comment") + } + if !hasColumnComment(users, "id", "主键") { + t.Fatalf("users table column 'id' missing comment") + } + if !hasColumnComment(users, "status", "状态") { + t.Fatalf("users table column 'status' missing comment") + } + + if !hasIndex(orders, "idx_tenant_user") { + t.Fatalf("orders table missing idx_tenant_user index") + } + if len(orders.ForeignKeys) == 0 { + t.Fatalf("orders table should have foreign keys") + } + if !hasForeignKeyTo(orders, prefix+usersTable) { + t.Fatalf("orders table missing FK to %s", prefix+usersTable) + } + if !hasCompositeForeignKeyTo(orders, prefix+usersTable) { + t.Fatalf("orders table missing composite FK to %s", prefix+usersTable) + } +} + +func findTable(tables []db.Table, name string) *db.Table { + for i := range tables { + if tables[i].Name == name { + return &tables[i] + } + } + return nil +} + +func findTableWithSchema(tables []db.Table, schema, name string) *db.Table { + for i := range tables { + if tables[i].Schema == schema && tables[i].Name == name { + return &tables[i] + } + } + return nil +} + +func hasColumn(table *db.Table, name string, primary bool) bool { + for _, c := range table.Columns { + if c.Name == name && c.PrimaryKey == primary { + return true + } + } + return false +} + +func hasIndex(table *db.Table, indexName string) bool { + for _, idx := range table.Indexes { + if idx.Name == indexName { + return true + } + } + return false +} + +func hasForeignKeyTo(table *db.Table, referencedTable string) bool { + for _, fk := range table.ForeignKeys { + if strings.EqualFold(fk.ReferencedTable, referencedTable) { + return true + } + } + return false +} + +func hasCompositeForeignKeyTo(table *db.Table, referencedTable string) bool { + for _, fk := range table.ForeignKeys { + if strings.EqualFold(fk.ReferencedTable, referencedTable) && + len(fk.Columns) >= 2 && + len(fk.ReferencedColumns) >= 2 { + return true + } + } + return false +} + +func hasColumnComment(table *db.Table, name, comment string) bool { + for _, c := range table.Columns { + if c.Name == name && c.Comment == comment { + return true + } + } + return false +} + +func hasColumnDefault(table *db.Table, name, want string) bool { + for _, c := range table.Columns { + if c.Name == name && strings.Contains(c.Default, want) { + return true + } + } + return false +}