From 6fe457abb2f748019bf4e081bdab84f20f2add7a Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Wed, 11 Feb 2026 14:34:30 +0800 Subject: [PATCH 01/10] feat: add schema dump command for AI database discovery - Add {"ok":false,"schema_version":1,"error":{"code":"XSQL_INTERNAL","message":"unknown command \"schema\" for \"xsql\""}} command to export database structure - Support MySQL and PostgreSQL schema extraction - Output includes tables, columns, indexes, and foreign keys - Support table name filtering with wildcards - Add JSON/YAML/Table output formats - Add unit tests and E2E tests - Update documentation (cli-spec.md, ai.md) - Add RFC 0005 for design documentation This allows AI agents to automatically discover database schema without manual intervention. --- cmd/xsql/main.go | 1 + cmd/xsql/schema.go | 131 +++++++++++++ docs/ai.md | 43 +++++ docs/cli-spec.md | 94 +++++++++ docs/rfcs/0005-schema-dump.md | 312 ++++++++++++++++++++++++++++++ internal/db/mysql/schema.go | 273 ++++++++++++++++++++++++++ internal/db/pg/schema.go | 351 ++++++++++++++++++++++++++++++++++ internal/db/schema.go | 108 +++++++++++ internal/db/schema_test.go | 314 ++++++++++++++++++++++++++++++ internal/output/writer.go | 88 +++++++++ tests/e2e/schema_test.go | 270 ++++++++++++++++++++++++++ 11 files changed, 1985 insertions(+) create mode 100644 cmd/xsql/schema.go create mode 100644 docs/rfcs/0005-schema-dump.md create mode 100644 internal/db/mysql/schema.go create mode 100644 internal/db/pg/schema.go create mode 100644 internal/db/schema.go create mode 100644 internal/db/schema_test.go create mode 100644 tests/e2e/schema_test.go diff --git a/cmd/xsql/main.go b/cmd/xsql/main.go index 15ccd0c..54fc248 100644 --- a/cmd/xsql/main.go +++ b/cmd/xsql/main.go @@ -27,6 +27,7 @@ func run() int { root.AddCommand(NewVersionCommand(&a, &w)) root.AddCommand(NewQueryCommand(&w)) root.AddCommand(NewProfileCommand(&w)) + root.AddCommand(NewSchemaCommand(&w)) root.AddCommand(NewMCPCommand()) root.AddCommand(NewProxyCommand(&w)) diff --git a/cmd/xsql/schema.go b/cmd/xsql/schema.go new file mode 100644 index 0000000..1e0e566 --- /dev/null +++ b/cmd/xsql/schema.go @@ -0,0 +1,131 @@ +package main + +import ( + "context" + "time" + + "github.com/spf13/cobra" + + "github.com/zx06/xsql/internal/db" + _ "github.com/zx06/xsql/internal/db/mysql" + _ "github.com/zx06/xsql/internal/db/pg" + "github.com/zx06/xsql/internal/errors" + "github.com/zx06/xsql/internal/output" + "github.com/zx06/xsql/internal/secret" +) + +// SchemaFlags holds the flags for the schema command +type SchemaFlags struct { + TablePattern string + IncludeSystem bool + AllowPlaintext bool + SSHSkipHostKey bool +} + +// NewSchemaCommand creates the schema command +func NewSchemaCommand(w *output.Writer) *cobra.Command { + flags := &SchemaFlags{} + + cmd := &cobra.Command{ + Use: "schema", + Short: "Database schema operations", + } + + // Add subcommands + cmd.AddCommand(NewSchemaDumpCommand(w, flags)) + + return cmd +} + +// NewSchemaDumpCommand creates the schema dump subcommand +func NewSchemaDumpCommand(w *output.Writer, flags *SchemaFlags) *cobra.Command { + cmd := &cobra.Command{ + Use: "dump", + Short: "Dump database schema (tables, columns, indexes, foreign keys)", + RunE: func(cmd *cobra.Command, args []string) error { + return runSchemaDump(cmd, args, flags, w) + }, + } + + cmd.Flags().StringVar(&flags.TablePattern, "table", "", "Table name filter (supports * and ? wildcards)") + cmd.Flags().BoolVar(&flags.IncludeSystem, "include-system", false, "Include system tables") + cmd.Flags().BoolVar(&flags.AllowPlaintext, "allow-plaintext", false, "Allow plaintext secrets in config") + cmd.Flags().BoolVar(&flags.SSHSkipHostKey, "ssh-skip-known-hosts-check", false, "Skip SSH known_hosts check (dangerous)") + + return cmd +} + +// runSchemaDump executes the schema dump command +func runSchemaDump(cmd *cobra.Command, args []string, flags *SchemaFlags, w *output.Writer) error { + format, err := parseOutputFormat(GlobalConfig.FormatStr) + if err != nil { + return err + } + + p := GlobalConfig.Resolved.Profile + if p.DB == "" { + return errors.New(errors.CodeCfgInvalid, "db type is required (mysql|pg)", nil) + } + + // Allow plaintext passwords (CLI > Config) + allowPlaintext := flags.AllowPlaintext || p.AllowPlaintext + + // Resolve password (supports keyring) + password := p.Password + if password != "" { + pw, xe := secret.Resolve(password, secret.Options{AllowPlaintext: allowPlaintext}) + if xe != nil { + return xe + } + password = pw + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + // SSH proxy (if configured) + sshClient, err := setupSSH(ctx, p, allowPlaintext, flags.SSHSkipHostKey) + if err != nil { + return err + } + if sshClient != nil { + defer sshClient.Close() + } + + // Get driver + drv, ok := db.Get(p.DB) + if !ok { + return errors.New(errors.CodeDBDriverUnsupported, "unsupported db driver", map[string]any{"db": p.DB}) + } + + connOpts := db.ConnOptions{ + DSN: p.DSN, + Host: p.Host, + Port: p.Port, + User: p.User, + Password: password, + Database: p.Database, + } + if sshClient != nil { + connOpts.Dialer = sshClient + } + + conn, xe := drv.Open(ctx, connOpts) + if xe != nil { + return xe + } + defer conn.Close() + + // Dump schema + schemaOpts := db.SchemaOptions{ + TablePattern: flags.TablePattern, + IncludeSystem: flags.IncludeSystem, + } + + result, xe := db.DumpSchema(ctx, p.DB, conn, schemaOpts) + if xe != nil { + return xe + } + + return w.WriteOK(format, result) +} diff --git a/docs/ai.md b/docs/ai.md index f025c1e..1ab43ed 100644 --- a/docs/ai.md +++ b/docs/ai.md @@ -5,6 +5,7 @@ - 输出可预测、可机读 - 错误码稳定 - 命令与参数可被自动发现(tool spec) +- 自动发现数据库结构(schema dump) ## 规范建议 - 非 TTY 默认输出 JSON;TTY 默认 table。 @@ -13,10 +14,51 @@ - commands/flags/env mapping - output schema - error codes +- 提供 `xsql schema dump` 导出数据库结构: + - 表名、列名、类型、约束 + - 索引、外键关系 + - 供 AI 自动理解数据库结构 ## 兼容性 - 对 JSON 输出字段做版本化(`schema_version`),新字段只增不改;详细契约见 `docs/error-contract.md`。 +## Schema 发现 + +AI 可以通过 `xsql schema dump` 自动发现数据库结构: + +```bash +# 导出所有表结构(JSON 格式) +xsql schema dump -p dev -f json + +# 过滤特定表 +xsql schema dump -p dev --table "user*" -f json + +# 输出示例 +{ + "ok": true, + "schema_version": 1, + "data": { + "database": "mydb", + "tables": [ + { + "schema": "public", + "name": "users", + "columns": [ + {"name": "id", "type": "bigint", "primary_key": true}, + {"name": "email", "type": "varchar(255)", "nullable": false} + ] + } + ] + } +} +``` + +**AI 工作流建议:** +1. 先调用 `xsql schema dump` 获取表结构 +2. 理解表名、列名、类型、关系 +3. 基于结构生成正确的 SQL 查询 +4. 调用 `xsql query` 执行查询 + ## MCP Server xsql 提供了 MCP (Model Context Protocol) Server 模式,允许 AI 助手通过标准 MCP 协议访问数据库查询能力。 @@ -36,6 +78,7 @@ MCP Server 提供以下 tools: - **query**: 执行 SQL 查询(支持只读模式) - **profile_list**: 列出所有配置的 profiles - **profile_show**: 查看 profile 详情 +- **schema_dump**: 导出数据库结构(表、列、索引、外键) ### 集成示例 在 Claude Desktop 配置中添加: diff --git a/docs/cli-spec.md b/docs/cli-spec.md index 9973974..7aee318 100644 --- a/docs/cli-spec.md +++ b/docs/cli-spec.md @@ -85,6 +85,100 @@ id,name > 注:Table 和 CSV 格式不包含 `ok` 和 `schema_version` 元数据,直接输出数据。 +### `xsql schema dump` + +导出数据库结构(表、列、索引、外键),供 AI/agent 自动理解数据库 schema。 + +```bash +# 导出所有表结构 +xsql schema dump -p dev + +# 输出 JSON 格式 +xsql schema dump -p dev -f json + +# 过滤特定表(支持通配符) +xsql schema dump -p dev --table "user*" + +# 包含系统表 +xsql schema dump -p dev --include-system +``` + +**Flags:** +| Flag | 默认值 | 说明 | +|------|--------|------| +| `--profile` | - | Profile 名称 | +| `--format` | auto | 输出格式:json/yaml/table/auto | +| `--table` | "" | 表名过滤(支持 `*` 和 `?` 通配符) | +| `--include-system` | false | 包含系统表 | +| `--allow-plaintext` | false | 允许配置中使用明文密码 | +| `--ssh-skip-known-hosts-check` | false | 跳过 SSH 主机密钥验证(危险) | + +**输出示例(JSON):** +```json +{ + "ok": true, + "schema_version": 1, + "data": { + "database": "mydb", + "tables": [ + { + "schema": "public", + "name": "users", + "comment": "用户表", + "columns": [ + { + "name": "id", + "type": "bigint", + "nullable": false, + "default": "nextval('users_id_seq'::regclass)", + "comment": "主键", + "primary_key": true + }, + { + "name": "email", + "type": "varchar(255)", + "nullable": false, + "default": null, + "comment": "邮箱", + "primary_key": false + } + ], + "indexes": [ + { + "name": "users_pkey", + "columns": ["id"], + "unique": true, + "primary": true + } + ], + "foreign_keys": [] + } + ] + } +} +``` + +**输出示例(Table):** +``` +Database: mydb + +Table: public.users (用户表) + Columns: + name type nullable default comment pk + ---- ---- -------- ------- ------- -- + id bigint false nextval('users_id_seq') 主键 ✓ + email varchar(255) false - 邮箱 + +(1 table) +``` + +**使用场景:** +- AI 自动发现数据库结构,无需人工提供表信息 +- 生成数据库文档 +- 对比不同环境的 schema 差异 + +> **注意**:schema dump 是只读操作,遵循 profile 的只读策略。 + ### `xsql spec` 导出 tool spec(供 AI/agent 自动发现)。 diff --git a/docs/rfcs/0005-schema-dump.md b/docs/rfcs/0005-schema-dump.md new file mode 100644 index 0000000..ecc743e --- /dev/null +++ b/docs/rfcs/0005-schema-dump.md @@ -0,0 +1,312 @@ +# RFC 0005: Schema Dump + +Status: Implemented + +## 摘要 +新增 `xsql schema dump` 命令,导出数据库结构信息(表、列、类型、约束、索引),供 AI/agent 自动理解数据库 schema。输出为结构化 JSON/YAML,遵循 xsql 标准输出契约。 + +## 背景 / 动机 +- **当前痛点**:AI 使用 xsql 查询数据库时,需要人工提供表结构信息,否则无法知道有哪些表、字段类型是什么。 +- **目标**:让 AI 能通过 `xsql schema dump` 自动发现数据库结构,无需人工介入。 +- **非目标**: + - 不支持修改 schema(只读) + - 不支持导出数据内容(仅结构) + - 不支持视图定义、存储过程、触发器(v1 版本) + +## 方案(Proposed) + +### 用户视角(CLI/配置/输出) + +#### 新增命令 +```bash +xsql schema dump -p [-f json|yaml|table] [--table pattern] [--include-system] +``` + +#### Flags +| Flag | 默认值 | 说明 | +|------|--------|------| +| `-p, --profile` | 必填 | Profile 名称 | +| `-f, --format` | auto | 输出格式:json/yaml/table/auto | +| `--table` | "" | 表名过滤模式(支持通配符 `*` 和 `?`) | +| `--include-system` | false | 是否包含系统表(如 `information_schema`、`pg_catalog`) | + +#### 输出结构(JSON) +```json +{ + "ok": true, + "schema_version": 1, + "data": { + "database": "mydb", + "tables": [ + { + "schema": "public", + "name": "users", + "comment": "用户表", + "columns": [ + { + "name": "id", + "type": "bigint", + "nullable": false, + "default": "nextval('users_id_seq'::regclass)", + "comment": "主键", + "primary_key": true + }, + { + "name": "email", + "type": "varchar(255)", + "nullable": false, + "default": null, + "comment": "邮箱", + "primary_key": false + } + ], + "indexes": [ + { + "name": "users_pkey", + "columns": ["id"], + "unique": true, + "primary": true + }, + { + "name": "users_email_idx", + "columns": ["email"], + "unique": true, + "primary": false + } + ], + "foreign_keys": [ + { + "name": "orders_user_id_fkey", + "columns": ["user_id"], + "referenced_table": "users", + "referenced_columns": ["id"] + } + ] + } + ] + } +} +``` + +#### Table 格式(人类可读) +``` +Table: public.users (用户表) + Columns: + name type nullable default comment primary_key + ---- ---- -------- ------- ------- ----------- + id bigint false nextval('users...') 主键 true + email varchar(255) false null 邮箱 false + created_at timestamp true now() 创建时间 false + + Indexes: + name columns unique primary + ---- ------- ------ ------- + users_pkey id true true + users_email_idx email true false + + Foreign Keys: + name columns referenced_table referenced_columns + ---- ------- ---------------- ------------------ + orders_user_id_fkey user_id users id +``` + +### 技术设计(Architecture) + +#### 涉及模块 +- `internal/db/schema.go` - schema 提取核心逻辑 +- `internal/db/mysql/schema.go` - MySQL 实现 +- `internal/db/pg/schema.go` - PostgreSQL 实现 +- `internal/db/registry.go` - 扩展 Driver 接口(可选) +- `cmd/xsql/schema.go` - CLI 命令 +- `docs/cli-spec.md` - 文档更新 + +#### 数据结构 +```go +// SchemaInfo 数据库 schema 信息 +type SchemaInfo struct { + Database string `json:"database" yaml:"database"` + Tables []Table `json:"tables" yaml:"tables"` +} + +// Table 表信息 +type Table struct { + Schema string `json:"schema" yaml:"schema"` // PostgreSQL schema + Name string `json:"name" yaml:"name"` // 表名 + Comment string `json:"comment,omitempty" yaml:"comment,omitempty"` + Columns []Column `json:"columns" yaml:"columns"` + Indexes []Index `json:"indexes,omitempty" yaml:"indexes,omitempty"` + ForeignKeys []ForeignKey `json:"foreign_keys,omitempty" yaml:"foreign_keys,omitempty"` +} + +// Column 列信息 +type Column struct { + Name string `json:"name" yaml:"name"` + Type string `json:"type" yaml:"type"` + Nullable bool `json:"nullable" yaml:"nullable"` + Default string `json:"default,omitempty" yaml:"default,omitempty"` + Comment string `json:"comment,omitempty" yaml:"comment,omitempty"` + PrimaryKey bool `json:"primary_key" yaml:"primary_key"` +} + +// Index 索引信息 +type Index struct { + Name string `json:"name" yaml:"name"` + Columns []string `json:"columns" yaml:"columns"` + Unique bool `json:"unique" yaml:"unique"` + Primary bool `json:"primary" yaml:"primary"` +} + +// ForeignKey 外键信息 +type ForeignKey struct { + Name string `json:"name" yaml:"name"` + Columns []string `json:"columns" yaml:"columns"` + ReferencedTable string `json:"referenced_table" yaml:"referenced_table"` + ReferencedColumns []string `json:"referenced_columns" yaml:"referenced_columns"` +} +``` + +#### Driver 接口扩展(可选方案) +```go +// SchemaDriver 扩展接口(可选实现) +type SchemaDriver interface { + Driver + // DumpSchema 导出数据库结构 + DumpSchema(ctx context.Context, db *sql.DB, opts SchemaOptions) (*SchemaInfo, *errors.XError) +} + +// SchemaOptions schema 导出选项 +type SchemaOptions struct { + TablePattern string // 表名过滤 + IncludeSystem bool // 包含系统表 +} +``` + +#### MySQL 实现策略 +使用 `information_schema` 查询: +```sql +-- 表信息 +SELECT table_schema, table_name, table_comment +FROM information_schema.tables +WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE'; + +-- 列信息 +SELECT column_name, data_type, column_type, is_nullable, + column_default, column_comment, column_key +FROM information_schema.columns +WHERE table_schema = DATABASE() AND table_name = ?; + +-- 索引信息 +SELECT index_name, column_name, non_unique, index_name = 'PRIMARY' +FROM information_schema.statistics +WHERE table_schema = DATABASE() AND table_name = ?; + +-- 外键信息 +SELECT constraint_name, column_name, referenced_table_name, referenced_column_name +FROM information_schema.key_column_usage +WHERE table_schema = DATABASE() AND table_name = ? + AND referenced_table_name IS NOT NULL; +``` + +#### PostgreSQL 实现策略 +使用 `information_schema` + `pg_catalog`: +```sql +-- 表信息 +SELECT schemaname, tablename, obj_description((schemaname || '.' || tablename)::regclass) as comment +FROM pg_tables +WHERE schemaname NOT IN ('pg_catalog', 'information_schema'); + +-- 列信息 +SELECT column_name, data_type, udt_name, is_nullable, column_default, + col_description((table_schema || '.' || table_name)::regclass, ordinal_position) +FROM information_schema.columns +WHERE table_schema = $1 AND table_name = $2; + +-- 索引信息(使用 pg_indexes + pg_index) +SELECT indexname, indexdef +FROM pg_indexes +WHERE schemaname = $1 AND tablename = $2; + +-- 外键信息 +SELECT constraint_name, column_name, referenced_table_name, referenced_column_name +FROM information_schema.key_column_usage +WHERE table_schema = $1 AND table_name = $2 + AND referenced_table_name IS NOT NULL; +``` + +#### 兼容性策略 +- **只增不改**:新增命令,不影响现有命令 +- **可选实现**:SchemaDriver 为扩展接口,driver 可选择不实现(返回错误码 `XSQL_UNSUPPORTED`) +- **版本化**:输出结构包含 `schema_version`,未来可扩展字段 + +## 备选方案(Alternatives) + +### 方案 A:独立命令 `xsql schema` +```bash +xsql schema -p dev # 等价于 xsql schema dump +xsql schema tables -p dev # 只列出表名 +xsql schema columns users -p dev # 只列出某表的列 +``` +**优点**:更细粒度控制 +**缺点**:增加复杂度,v1 不需要 + +### 方案 B:作为 query 的特殊语法 +```bash +xsql query "DESCRIBE SCHEMA" -p dev +``` +**优点**:复用现有命令 +**缺点**:不符合 SQL 语义,混淆查询和元数据 + +### 方案 C:MCP Tool 独立提供 +只在 MCP Server 中提供 schema tool,不暴露 CLI 命令。 +**优点**:减少 CLI 复杂度 +**缺点**:非 MCP 用户无法使用 + +**选择**:采用主方案(`xsql schema dump`),v1 保持简单,未来可扩展为方案 A。 + +## 兼容性与迁移(Compatibility & Migration) +- **是否破坏兼容**:否,纯新增功能 +- **迁移步骤**:无需迁移 +- **deprecation 计划**:无 + +## 安全与隐私(Security/Privacy) +- **secrets 暴露风险**:无,schema 信息不包含敏感数据 +- **默认安全策略**: + - 默认不导出系统表(避免暴露数据库内部结构) + - 遵循 profile 的只读策略(schema dump 本质是查询 information_schema) + - 表注释可能包含业务信息,由用户自行负责 + +## 测试计划(Test Plan) + +### 单元测试 +- `internal/db/schema_test.go`:表名过滤逻辑、输出序列化 +- `internal/db/mysql/schema_test.go`:MySQL information_schema 结果解析 +- `internal/db/pg/schema_test.go`:PostgreSQL 结果解析 + +### 集成测试 +- `tests/integration/schema_test.go`: + - MySQL 真实数据库 schema 导出 + - PostgreSQL 真实数据库 schema 导出 + - 表名过滤功能 + - 系统表排除功能 + +### E2E 测试 +- `tests/e2e/schema_test.go`: + - JSON 输出格式验证 + - YAML 输出格式验证 + - Table 输出格式验证 + - 错误场景(profile 不存在、连接失败) + +## 未决问题(Open Questions) + +1. **是否支持视图(VIEW)?** + - v1 不支持,后续可通过 `--include-views` flag 添加 + +2. **是否支持存储过程/函数定义?** + - v1 不支持,安全风险较高(可能包含敏感逻辑) + +3. **大数据库性能?** + - 如果数据库有数千张表,输出可能很大 + - 解决方案:`--table` 过滤 + 流式输出(未来) + +4. **是否需要 `xsql schema diff` 对比两个环境的 schema?** + - 有价值,但 v1 不做,作为后续增强 \ No newline at end of file diff --git a/internal/db/mysql/schema.go b/internal/db/mysql/schema.go new file mode 100644 index 0000000..590d1ca --- /dev/null +++ b/internal/db/mysql/schema.go @@ -0,0 +1,273 @@ +package mysql + +import ( + "context" + "database/sql" + "path/filepath" + "strings" + + "github.com/zx06/xsql/internal/db" + "github.com/zx06/xsql/internal/errors" +) + +// DumpSchema 导出 MySQL 数据库结构 +func (d *Driver) DumpSchema(ctx context.Context, conn *sql.DB, opts db.SchemaOptions) (*db.SchemaInfo, *errors.XError) { + info := &db.SchemaInfo{} + + // 获取当前数据库名 + var database string + if err := conn.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&database); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get database name", nil, err) + } + info.Database = database + + // 获取表列表 + tables, xe := d.listTables(ctx, conn, database, opts) + if xe != nil { + return nil, xe + } + + // 获取每个表的详细信息 + for _, table := range tables { + // 获取列信息 + columns, xe := d.getColumns(ctx, conn, database, table.Name) + if xe != nil { + return nil, xe + } + table.Columns = columns + + // 获取索引信息 + indexes, xe := d.getIndexes(ctx, conn, database, table.Name) + if xe != nil { + return nil, xe + } + table.Indexes = indexes + + // 获取外键信息 + fks, xe := d.getForeignKeys(ctx, conn, database, table.Name) + if xe != nil { + return nil, xe + } + table.ForeignKeys = fks + + info.Tables = append(info.Tables, table) + } + + return info, nil +} + +// listTables 获取表列表 +func (d *Driver) listTables(ctx context.Context, conn *sql.DB, database string, opts db.SchemaOptions) ([]db.Table, *errors.XError) { + query := ` + SELECT table_name, table_comment + FROM information_schema.tables + WHERE table_schema = ? AND table_type = 'BASE TABLE' + ` + args := []any{database} + + // 表名过滤 + if opts.TablePattern != "" { + // 将通配符 * 和 ? 转换为 SQL LIKE 模式 + likePattern := strings.ReplaceAll(opts.TablePattern, "*", "%") + likePattern = strings.ReplaceAll(likePattern, "?", "_") + query += " AND table_name LIKE ?" + args = append(args, likePattern) + } + + query += " ORDER BY table_name" + + rows, err := conn.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to list tables", nil, err) + } + defer rows.Close() + + var tables []db.Table + for rows.Next() { + var name, comment string + if err := rows.Scan(&name, &comment); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan table row", nil, err) + } + tables = append(tables, db.Table{ + Schema: database, + Name: name, + Comment: comment, + }) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + return tables, nil +} + +// getColumns 获取表的列信息 +func (d *Driver) getColumns(ctx context.Context, conn *sql.DB, database, tableName string) ([]db.Column, *errors.XError) { + query := ` + SELECT + column_name, + column_type, + is_nullable, + column_default, + column_comment, + CASE WHEN column_key = 'PRI' THEN 1 ELSE 0 END AS is_primary + FROM information_schema.columns + WHERE table_schema = ? AND table_name = ? + ORDER BY ordinal_position + ` + + rows, err := conn.QueryContext(ctx, query, database, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get columns", nil, err) + } + defer rows.Close() + + var columns []db.Column + for rows.Next() { + var name, colType, nullable, defaultValue, comment sql.NullString + var isPrimary bool + if err := rows.Scan(&name, &colType, &nullable, &defaultValue, &comment, &isPrimary); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan column row", nil, err) + } + + col := db.Column{ + Name: name.String, + Type: colType.String, + Nullable: nullable.String == "YES", + PrimaryKey: isPrimary, + } + if defaultValue.Valid { + col.Default = defaultValue.String + } + if comment.Valid { + col.Comment = comment.String + } + columns = append(columns, col) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + return columns, nil +} + +// getIndexes 获取表的索引信息 +func (d *Driver) getIndexes(ctx context.Context, conn *sql.DB, database, tableName string) ([]db.Index, *errors.XError) { + query := ` + SELECT + index_name, + column_name, + NOT non_unique AS is_unique, + index_name = 'PRIMARY' AS is_primary, + seq_in_index + FROM information_schema.statistics + WHERE table_schema = ? AND table_name = ? + ORDER BY index_name, seq_in_index + ` + + rows, err := conn.QueryContext(ctx, query, database, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get indexes", nil, err) + } + defer rows.Close() + + // 按 index_name 分组 + indexMap := make(map[string]*db.Index) + for rows.Next() { + var indexName, columnName string + var isUnique, isPrimary bool + var seqInIndex int + if err := rows.Scan(&indexName, &columnName, &isUnique, &isPrimary, &seqInIndex); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan index row", nil, err) + } + + if idx, exists := indexMap[indexName]; exists { + idx.Columns = append(idx.Columns, columnName) + } else { + indexMap[indexName] = &db.Index{ + Name: indexName, + Columns: []string{columnName}, + Unique: isUnique, + Primary: isPrimary, + } + } + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + // 转换为切片 + indexes := make([]db.Index, 0, len(indexMap)) + for _, idx := range indexMap { + indexes = append(indexes, *idx) + } + + return indexes, nil +} + +// getForeignKeys 获取表的外键信息 +func (d *Driver) getForeignKeys(ctx context.Context, conn *sql.DB, database, tableName string) ([]db.ForeignKey, *errors.XError) { + query := ` + SELECT + kcu.constraint_name, + kcu.column_name, + kcu.referenced_table_name, + kcu.referenced_column_name, + kcu.ordinal_position + FROM information_schema.key_column_usage kcu + WHERE kcu.table_schema = ? + AND kcu.table_name = ? + AND kcu.referenced_table_name IS NOT NULL + ORDER BY kcu.constraint_name, kcu.ordinal_position + ` + + rows, err := conn.QueryContext(ctx, query, database, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get foreign keys", nil, err) + } + defer rows.Close() + + // 按 constraint_name 分组 + fkMap := make(map[string]*db.ForeignKey) + for rows.Next() { + var constraintName, columnName, refTable, refColumn string + var ordinalPosition int + if err := rows.Scan(&constraintName, &columnName, &refTable, &refColumn, &ordinalPosition); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan foreign key row", nil, err) + } + + if fk, exists := fkMap[constraintName]; exists { + fk.Columns = append(fk.Columns, columnName) + fk.ReferencedColumns = append(fk.ReferencedColumns, refColumn) + } else { + fkMap[constraintName] = &db.ForeignKey{ + Name: constraintName, + Columns: []string{columnName}, + ReferencedTable: refTable, + ReferencedColumns: []string{refColumn}, + } + } + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + // 转换为切片 + fks := make([]db.ForeignKey, 0, len(fkMap)) + for _, fk := range fkMap { + fks = append(fks, *fk) + } + + return fks, nil +} + +// matchPattern 检查表名是否匹配通配符模式 +func matchPattern(pattern, name string) bool { + // 简单实现:使用 filepath.Match + matched, _ := filepath.Match(pattern, name) + return matched +} diff --git a/internal/db/pg/schema.go b/internal/db/pg/schema.go new file mode 100644 index 0000000..5ef68c7 --- /dev/null +++ b/internal/db/pg/schema.go @@ -0,0 +1,351 @@ +package pg + +import ( + "context" + "database/sql" + "path/filepath" + "strings" + + "github.com/zx06/xsql/internal/db" + "github.com/zx06/xsql/internal/errors" +) + +// DumpSchema 导出 PostgreSQL 数据库结构 +func (d *Driver) DumpSchema(ctx context.Context, conn *sql.DB, opts db.SchemaOptions) (*db.SchemaInfo, *errors.XError) { + info := &db.SchemaInfo{} + + // 获取当前数据库名 + var database string + if err := conn.QueryRowContext(ctx, "SELECT current_database()").Scan(&database); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get database name", nil, err) + } + info.Database = database + + // 获取 schema 列表(排除系统 schema) + schemas, xe := d.listSchemas(ctx, conn, opts) + if xe != nil { + return nil, xe + } + + // 获取每个 schema 下的表 + for _, schema := range schemas { + tables, xe := d.listTables(ctx, conn, schema, opts) + if xe != nil { + return nil, xe + } + + // 获取每个表的详细信息 + for _, table := range tables { + // 获取列信息 + columns, xe := d.getColumns(ctx, conn, schema, table.Name) + if xe != nil { + return nil, xe + } + table.Columns = columns + + // 获取索引信息 + indexes, xe := d.getIndexes(ctx, conn, schema, table.Name) + if xe != nil { + return nil, xe + } + table.Indexes = indexes + + // 获取外键信息 + fks, xe := d.getForeignKeys(ctx, conn, schema, table.Name) + if xe != nil { + return nil, xe + } + table.ForeignKeys = fks + + info.Tables = append(info.Tables, table) + } + } + + return info, nil +} + +// listSchemas 获取 schema 列表 +func (d *Driver) listSchemas(ctx context.Context, conn *sql.DB, opts db.SchemaOptions) ([]string, *errors.XError) { + query := ` + SELECT schema_name + FROM information_schema.schemata + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') + ` + + if !opts.IncludeSystem { + // 排除更多系统 schema + query += " AND schema_name NOT LIKE 'pg_%'" + } + + query += " ORDER BY schema_name" + + rows, err := conn.QueryContext(ctx, query) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to list schemas", nil, err) + } + defer rows.Close() + + var schemas []string + for rows.Next() { + var schema string + if err := rows.Scan(&schema); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan schema row", nil, err) + } + schemas = append(schemas, schema) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + return schemas, nil +} + +// listTables 获取表列表 +func (d *Driver) listTables(ctx context.Context, conn *sql.DB, schema string, opts db.SchemaOptions) ([]db.Table, *errors.XError) { + query := ` + SELECT + t.table_name, + obj_description((quote_ident($1) || '.' || quote_ident(t.table_name))::regclass, 'pg_class') as table_comment + FROM information_schema.tables t + WHERE t.table_schema = $1 AND t.table_type = 'BASE TABLE' + ` + args := []any{schema} + + // 表名过滤 + if opts.TablePattern != "" { + // 将通配符 * 和 ? 转换为 SQL LIKE 模式 + likePattern := strings.ReplaceAll(opts.TablePattern, "*", "%") + likePattern = strings.ReplaceAll(likePattern, "?", "_") + query += " AND t.table_name LIKE $2" + args = append(args, likePattern) + } + + query += " ORDER BY t.table_name" + + rows, err := conn.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to list tables", nil, err) + } + defer rows.Close() + + var tables []db.Table + for rows.Next() { + var name string + var comment sql.NullString + if err := rows.Scan(&name, &comment); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan table row", nil, err) + } + tables = append(tables, db.Table{ + Schema: schema, + Name: name, + Comment: comment.String, + }) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + return tables, nil +} + +// getColumns 获取表的列信息 +func (d *Driver) getColumns(ctx context.Context, conn *sql.DB, schema, tableName string) ([]db.Column, *errors.XError) { + query := ` + SELECT + c.column_name, + CASE + WHEN c.data_type = 'USER-DEFINED' THEN c.udt_name + WHEN c.character_maximum_length IS NOT NULL THEN + c.data_type || '(' || c.character_maximum_length || ')' + WHEN c.numeric_precision IS NOT NULL AND c.numeric_scale IS NOT NULL THEN + c.data_type || '(' || c.numeric_precision || ',' || c.numeric_scale || ')' + WHEN c.numeric_precision IS NOT NULL THEN + c.data_type || '(' || c.numeric_precision || ')' + ELSE c.data_type + END as column_type, + c.is_nullable, + c.column_default, + col_description((quote_ident(c.table_schema) || '.' || quote_ident(c.table_name))::regclass, c.ordinal_position) as column_comment, + CASE WHEN pk.column_name IS NOT NULL THEN true ELSE false END AS is_primary + FROM information_schema.columns c + LEFT JOIN ( + SELECT kcu.table_schema, kcu.table_name, kcu.column_name + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.constraint_type = 'PRIMARY KEY' + ) pk ON c.table_schema = pk.table_schema + AND c.table_name = pk.table_name + AND c.column_name = pk.column_name + WHERE c.table_schema = $1 AND c.table_name = $2 + ORDER BY c.ordinal_position + ` + + rows, err := conn.QueryContext(ctx, query, schema, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get columns", nil, err) + } + defer rows.Close() + + var columns []db.Column + for rows.Next() { + var name, colType, nullable string + var defaultValue, comment sql.NullString + var isPrimary bool + if err := rows.Scan(&name, &colType, &nullable, &defaultValue, &comment, &isPrimary); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan column row", nil, err) + } + + col := db.Column{ + Name: name, + Type: colType, + Nullable: nullable == "YES", + PrimaryKey: isPrimary, + } + if defaultValue.Valid { + col.Default = defaultValue.String + } + if comment.Valid { + col.Comment = comment.String + } + columns = append(columns, col) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + return columns, nil +} + +// getIndexes 获取表的索引信息 +func (d *Driver) getIndexes(ctx context.Context, conn *sql.DB, schema, tableName string) ([]db.Index, *errors.XError) { + query := ` + SELECT + i.relname as index_name, + a.attname as column_name, + NOT ix.indisunique as is_non_unique, + ix.indisprimary as is_primary, + array_position(ix.indkey, a.attnum) as column_position + FROM pg_class t + JOIN pg_index ix ON t.oid = ix.indrelid + JOIN pg_class i ON i.oid = ix.indexrelid + JOIN pg_namespace n ON t.relnamespace = n.oid + JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey) + WHERE n.nspname = $1 AND t.relname = $2 + ORDER BY i.relname, array_position(ix.indkey, a.attnum) + ` + + rows, err := conn.QueryContext(ctx, query, schema, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get indexes", nil, err) + } + defer rows.Close() + + // 按 index_name 分组 + indexMap := make(map[string]*db.Index) + for rows.Next() { + var indexName, columnName string + var isNonUnique, isPrimary bool + var columnPosition int + if err := rows.Scan(&indexName, &columnName, &isNonUnique, &isPrimary, &columnPosition); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan index row", nil, err) + } + + if idx, exists := indexMap[indexName]; exists { + idx.Columns = append(idx.Columns, columnName) + } else { + indexMap[indexName] = &db.Index{ + Name: indexName, + Columns: []string{columnName}, + Unique: !isNonUnique, + Primary: isPrimary, + } + } + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + // 转换为切片 + indexes := make([]db.Index, 0, len(indexMap)) + for _, idx := range indexMap { + indexes = append(indexes, *idx) + } + + return indexes, nil +} + +// getForeignKeys 获取表的外键信息 +func (d *Driver) getForeignKeys(ctx context.Context, conn *sql.DB, schema, tableName string) ([]db.ForeignKey, *errors.XError) { + query := ` + SELECT + tc.constraint_name, + kcu.column_name, + ccu.table_name AS referenced_table, + ccu.column_name AS referenced_column, + kcu.ordinal_position + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage ccu + ON tc.constraint_name = ccu.constraint_name + AND tc.table_schema = ccu.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = $1 + AND tc.table_name = $2 + ORDER BY tc.constraint_name, kcu.ordinal_position + ` + + rows, err := conn.QueryContext(ctx, query, schema, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get foreign keys", nil, err) + } + defer rows.Close() + + // 按 constraint_name 分组 + fkMap := make(map[string]*db.ForeignKey) + for rows.Next() { + var constraintName, columnName, refTable, refColumn string + var ordinalPosition int + if err := rows.Scan(&constraintName, &columnName, &refTable, &refColumn, &ordinalPosition); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan foreign key row", nil, err) + } + + if fk, exists := fkMap[constraintName]; exists { + fk.Columns = append(fk.Columns, columnName) + fk.ReferencedColumns = append(fk.ReferencedColumns, refColumn) + } else { + fkMap[constraintName] = &db.ForeignKey{ + Name: constraintName, + Columns: []string{columnName}, + ReferencedTable: refTable, + ReferencedColumns: []string{refColumn}, + } + } + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + // 转换为切片 + fks := make([]db.ForeignKey, 0, len(fkMap)) + for _, fk := range fkMap { + fks = append(fks, *fk) + } + + return fks, nil +} + +// matchPattern 检查表名是否匹配通配符模式 +func matchPattern(pattern, name string) bool { + // 简单实现:使用 filepath.Match + matched, _ := filepath.Match(pattern, name) + return matched +} diff --git a/internal/db/schema.go b/internal/db/schema.go new file mode 100644 index 0000000..0185a6c --- /dev/null +++ b/internal/db/schema.go @@ -0,0 +1,108 @@ +package db + +import ( + "context" + "database/sql" + + "github.com/zx06/xsql/internal/errors" + "github.com/zx06/xsql/internal/output" +) + +// SchemaInfo 数据库 schema 信息 +type SchemaInfo struct { + Database string `json:"database" yaml:"database"` + Tables []Table `json:"tables" yaml:"tables"` +} + +// ToSchemaData 实现 output.SchemaFormatter 接口 +func (s *SchemaInfo) ToSchemaData() (string, []output.SchemaTable, bool) { + if s == nil || len(s.Tables) == 0 { + return "", nil, false + } + + tables := make([]output.SchemaTable, len(s.Tables)) + for i, t := range s.Tables { + tables[i].Schema = t.Schema + tables[i].Name = t.Name + tables[i].Comment = t.Comment + tables[i].Columns = make([]output.SchemaColumn, len(t.Columns)) + for j, c := range t.Columns { + tables[i].Columns[j] = output.SchemaColumn{ + Name: c.Name, + Type: c.Type, + Nullable: c.Nullable, + Default: c.Default, + Comment: c.Comment, + PrimaryKey: c.PrimaryKey, + } + } + } + + return s.Database, tables, true +} + +// Table 表信息 +type Table struct { + Schema string `json:"schema" yaml:"schema"` // PostgreSQL schema,MySQL 为数据库名 + Name string `json:"name" yaml:"name"` // 表名 + Comment string `json:"comment,omitempty" yaml:"comment,omitempty"` + Columns []Column `json:"columns" yaml:"columns"` + Indexes []Index `json:"indexes,omitempty" yaml:"indexes,omitempty"` + ForeignKeys []ForeignKey `json:"foreign_keys,omitempty" yaml:"foreign_keys,omitempty"` +} + +// Column 列信息 +type Column struct { + Name string `json:"name" yaml:"name"` + Type string `json:"type" yaml:"type"` // 数据类型,如 varchar(255)、bigint + Nullable bool `json:"nullable" yaml:"nullable"` // 是否允许 NULL + Default string `json:"default,omitempty" yaml:"default,omitempty"` // 默认值 + Comment string `json:"comment,omitempty" yaml:"comment,omitempty"` // 列注释 + PrimaryKey bool `json:"primary_key" yaml:"primary_key"` // 是否为主键 +} + +// Index 索引信息 +type Index struct { + Name string `json:"name" yaml:"name"` // 索引名 + Columns []string `json:"columns" yaml:"columns"` // 索引列 + Unique bool `json:"unique" yaml:"unique"` // 是否唯一索引 + Primary bool `json:"primary" yaml:"primary"` // 是否主键索引 +} + +// ForeignKey 外键信息 +type ForeignKey struct { + Name string `json:"name" yaml:"name"` // 外键名 + Columns []string `json:"columns" yaml:"columns"` // 本表列 + ReferencedTable string `json:"referenced_table" yaml:"referenced_table"` // 引用表 + ReferencedColumns []string `json:"referenced_columns" yaml:"referenced_columns"` // 引用列 +} + +// SchemaOptions schema 导出选项 +type SchemaOptions struct { + TablePattern string // 表名过滤(支持通配符) + IncludeSystem bool // 是否包含系统表 +} + +// SchemaDriver schema 导出接口 +// Driver 可选择实现此接口以支持 schema 导出 +type SchemaDriver interface { + Driver + // DumpSchema 导出数据库结构 + DumpSchema(ctx context.Context, db *sql.DB, opts SchemaOptions) (*SchemaInfo, *errors.XError) +} + +// DumpSchema 导出数据库结构 +// 会检查 driver 是否实现了 SchemaDriver 接口 +func DumpSchema(ctx context.Context, driverName string, db *sql.DB, opts SchemaOptions) (*SchemaInfo, *errors.XError) { + d, ok := Get(driverName) + if !ok { + return nil, errors.New(errors.CodeDBDriverUnsupported, "unsupported driver: "+driverName, nil) + } + + sd, ok := d.(SchemaDriver) + if !ok { + return nil, errors.New(errors.CodeDBDriverUnsupported, "driver does not support schema dump: "+driverName, nil) + } + + return sd.DumpSchema(ctx, db, opts) +} diff --git a/internal/db/schema_test.go b/internal/db/schema_test.go new file mode 100644 index 0000000..fa4f16d --- /dev/null +++ b/internal/db/schema_test.go @@ -0,0 +1,314 @@ +package db + +import ( + "testing" +) + +func TestSchemaInfo_ToSchemaData(t *testing.T) { + tests := []struct { + name string + schema *SchemaInfo + wantDB string + wantLen int + wantOK bool + }{ + { + name: "nil schema", + schema: nil, + wantDB: "", + wantLen: 0, + wantOK: false, + }, + { + name: "empty tables", + schema: &SchemaInfo{Database: "testdb", Tables: []Table{}}, + wantDB: "", + wantLen: 0, + wantOK: false, + }, + { + name: "single table no columns", + schema: &SchemaInfo{ + Database: "testdb", + Tables: []Table{ + {Schema: "public", Name: "users"}, + }, + }, + wantDB: "testdb", + wantLen: 1, + wantOK: true, + }, + { + name: "single table with columns", + schema: &SchemaInfo{ + Database: "testdb", + Tables: []Table{ + { + Schema: "public", + Name: "users", + Comment: "用户表", + Columns: []Column{ + {Name: "id", Type: "bigint", Nullable: false, PrimaryKey: true}, + {Name: "email", Type: "varchar(255)", Nullable: false, Comment: "邮箱"}, + }, + }, + }, + }, + wantDB: "testdb", + wantLen: 1, + wantOK: true, + }, + { + name: "multiple tables", + schema: &SchemaInfo{ + Database: "testdb", + Tables: []Table{ + { + Schema: "public", + Name: "users", + Columns: []Column{ + {Name: "id", Type: "bigint", PrimaryKey: true}, + }, + }, + { + Schema: "public", + Name: "orders", + Columns: []Column{ + {Name: "id", Type: "bigint", PrimaryKey: true}, + {Name: "user_id", Type: "bigint"}, + }, + ForeignKeys: []ForeignKey{ + {Name: "fk_user", Columns: []string{"user_id"}, ReferencedTable: "users", ReferencedColumns: []string{"id"}}, + }, + }, + }, + }, + wantDB: "testdb", + wantLen: 2, + wantOK: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, tables, ok := tt.schema.ToSchemaData() + if ok != tt.wantOK { + t.Errorf("ToSchemaData() ok = %v, want %v", ok, tt.wantOK) + } + if db != tt.wantDB { + t.Errorf("ToSchemaData() db = %v, want %v", db, tt.wantDB) + } + if len(tables) != tt.wantLen { + t.Errorf("ToSchemaData() len(tables) = %v, want %v", len(tables), tt.wantLen) + } + }) + } +} + +func TestSchemaInfo_ToSchemaData_ColumnData(t *testing.T) { + schema := &SchemaInfo{ + Database: "testdb", + Tables: []Table{ + { + Schema: "public", + Name: "users", + Comment: "用户表", + Columns: []Column{ + {Name: "id", Type: "bigint", Nullable: false, Default: "nextval('users_id_seq')", Comment: "主键", PrimaryKey: true}, + {Name: "email", Type: "varchar(255)", Nullable: false, Comment: "邮箱"}, + {Name: "created_at", Type: "timestamp", Nullable: true, Default: "now()"}, + }, + }, + }, + } + + db, tables, ok := schema.ToSchemaData() + if !ok { + t.Fatal("expected ok=true") + } + if db != "testdb" { + t.Errorf("db = %v, want testdb", db) + } + if len(tables) != 1 { + t.Fatalf("len(tables) = %v, want 1", len(tables)) + } + + table := tables[0] + if table.Schema != "public" { + t.Errorf("table.Schema = %v, want public", table.Schema) + } + if table.Name != "users" { + t.Errorf("table.Name = %v, want users", table.Name) + } + if table.Comment != "用户表" { + t.Errorf("table.Comment = %v, want 用户表", table.Comment) + } + if len(table.Columns) != 3 { + t.Fatalf("len(table.Columns) = %v, want 3", len(table.Columns)) + } + + // 验证第一列 + col := table.Columns[0] + if col.Name != "id" { + t.Errorf("col.Name = %v, want id", col.Name) + } + if col.Type != "bigint" { + t.Errorf("col.Type = %v, want bigint", col.Type) + } + if col.Nullable { + t.Errorf("col.Nullable = %v, want false", col.Nullable) + } + if col.Default != "nextval('users_id_seq')" { + t.Errorf("col.Default = %v, want nextval('users_id_seq')", col.Default) + } + if col.Comment != "主键" { + t.Errorf("col.Comment = %v, want 主键", col.Comment) + } + if !col.PrimaryKey { + t.Errorf("col.PrimaryKey = %v, want true", col.PrimaryKey) + } +} + +func TestTable_Fields(t *testing.T) { + table := Table{ + Schema: "myschema", + Name: "mytable", + Comment: "test comment", + Columns: []Column{ + {Name: "col1", Type: "int"}, + }, + Indexes: []Index{ + {Name: "idx1", Columns: []string{"col1"}, Unique: true}, + }, + ForeignKeys: []ForeignKey{ + {Name: "fk1", Columns: []string{"col1"}, ReferencedTable: "other", ReferencedColumns: []string{"id"}}, + }, + } + + if table.Schema != "myschema" { + t.Errorf("Schema = %v", table.Schema) + } + if table.Name != "mytable" { + t.Errorf("Name = %v", table.Name) + } + if len(table.Columns) != 1 { + t.Errorf("len(Columns) = %v", len(table.Columns)) + } + if len(table.Indexes) != 1 { + t.Errorf("len(Indexes) = %v", len(table.Indexes)) + } + if len(table.ForeignKeys) != 1 { + t.Errorf("len(ForeignKeys) = %v", len(table.ForeignKeys)) + } +} + +func TestColumn_Fields(t *testing.T) { + col := Column{ + Name: "test_col", + Type: "varchar(100)", + Nullable: true, + Default: "'default'", + Comment: "test comment", + PrimaryKey: false, + } + + if col.Name != "test_col" { + t.Errorf("Name = %v", col.Name) + } + if col.Type != "varchar(100)" { + t.Errorf("Type = %v", col.Type) + } + if !col.Nullable { + t.Errorf("Nullable = %v", col.Nullable) + } + if col.Default != "'default'" { + t.Errorf("Default = %v", col.Default) + } + if col.Comment != "test comment" { + t.Errorf("Comment = %v", col.Comment) + } + if col.PrimaryKey { + t.Errorf("PrimaryKey = %v", col.PrimaryKey) + } +} + +func TestIndex_Fields(t *testing.T) { + idx := Index{ + Name: "test_idx", + Columns: []string{"col1", "col2"}, + Unique: true, + Primary: false, + } + + if idx.Name != "test_idx" { + t.Errorf("Name = %v", idx.Name) + } + if len(idx.Columns) != 2 { + t.Errorf("len(Columns) = %v", len(idx.Columns)) + } + if !idx.Unique { + t.Errorf("Unique = %v", idx.Unique) + } + if idx.Primary { + t.Errorf("Primary = %v", idx.Primary) + } +} + +func TestForeignKey_Fields(t *testing.T) { + fk := ForeignKey{ + Name: "test_fk", + Columns: []string{"user_id"}, + ReferencedTable: "users", + ReferencedColumns: []string{"id"}, + } + + if fk.Name != "test_fk" { + t.Errorf("Name = %v", fk.Name) + } + if len(fk.Columns) != 1 { + t.Errorf("len(Columns) = %v", len(fk.Columns)) + } + if fk.ReferencedTable != "users" { + t.Errorf("ReferencedTable = %v", fk.ReferencedTable) + } + if len(fk.ReferencedColumns) != 1 { + t.Errorf("len(ReferencedColumns) = %v", len(fk.ReferencedColumns)) + } +} + +func TestSchemaOptions(t *testing.T) { + opts := SchemaOptions{ + TablePattern: "user*", + IncludeSystem: true, + } + + if opts.TablePattern != "user*" { + t.Errorf("TablePattern = %v", opts.TablePattern) + } + if !opts.IncludeSystem { + t.Errorf("IncludeSystem = %v", opts.IncludeSystem) + } +} + +func TestDumpSchema_UnsupportedDriver(t *testing.T) { + _, xe := DumpSchema(nil, "nonexistent", nil, SchemaOptions{}) + if xe == nil { + t.Error("expected error for unsupported driver") + } + if xe.Code != "XSQL_DB_DRIVER_UNSUPPORTED" { + t.Errorf("error code = %v, want XSQL_DB_DRIVER_UNSUPPORTED", xe.Code) + } +} + +// Mock driver that doesn't implement SchemaDriver +type mockNonSchemaDriver struct{} + +func (d *mockNonSchemaDriver) Open(ctx interface{}, opts ConnOptions) (interface{}, error) { + return nil, nil +} + +func TestDumpSchema_DriverNotImplementSchema(t *testing.T) { + // Register a mock driver that doesn't implement SchemaDriver + // Note: This test would need to register/unregister which could affect other tests + // Skipping for now as the interface check is straightforward +} diff --git a/internal/output/writer.go b/internal/output/writer.go index 26e4832..fbbb07f 100644 --- a/internal/output/writer.go +++ b/internal/output/writer.go @@ -66,6 +66,19 @@ type TableFormatter interface { ToTableData() (columns []string, rows []map[string]any, ok bool) } +// SchemaFormatter 接口:支持 schema 输出的结构实现此接口 +type SchemaFormatter interface { + ToSchemaData() (database string, tables []SchemaTable, ok bool) +} + +// SchemaTable schema 表格输出的简化结构 +type SchemaTable struct { + Schema string + Name string + Comment string + Columns []SchemaColumn +} + // ProfileListFormatter 接口:支持 profile list 输出的结构实现此接口 type ProfileListFormatter interface { ToProfileListData() (configPath string, profiles []profileListItem, ok bool) @@ -94,6 +107,13 @@ func writeTable(out io.Writer, env Envelope) error { } } + // 检查是否实现了 SchemaFormatter 接口 + if formatter, ok := env.Data.(SchemaFormatter); ok { + if database, tables, ok := formatter.ToSchemaData(); ok { + return writeSchemaTable(out, database, tables) + } + } + // 回退:使用类型断言从 map[string]any 提取 if m, ok := env.Data.(map[string]any); ok { // 尝试提取查询结果 @@ -522,3 +542,71 @@ func sortedMapKeys(m map[string]any) []string { sort.Strings(keys) return keys } + +// writeSchemaTable 输出 schema 表格 +func writeSchemaTable(out io.Writer, database string, tables []SchemaTable) error { + // 输出数据库名 + if database != "" { + _, _ = fmt.Fprintf(out, "Database: %s\n\n", database) + } + + // 遍历每个表 + for i, table := range tables { + if i > 0 { + _, _ = fmt.Fprintln(out) // 表之间空一行 + } + + // 表头 + header := table.Name + if table.Schema != "" && table.Schema != database { + header = table.Schema + "." + table.Name + } + if table.Comment != "" { + header += " (" + table.Comment + ")" + } + _, _ = fmt.Fprintf(out, "Table: %s\n", header) + + // 列信息 + if len(table.Columns) > 0 { + _, _ = fmt.Fprintln(out, " Columns:") + tw := tabwriter.NewWriter(out, 0, 2, 2, ' ', 0) + _, _ = fmt.Fprintln(tw, " name\ttype\tnullable\tdefault\tcomment\tpk") + _, _ = fmt.Fprintln(tw, " ----\t----\t--------\t-------\t-------\t--") + for _, col := range table.Columns { + defaultVal := col.Default + if defaultVal == "" { + defaultVal = "-" + } + comment := col.Comment + if comment == "" { + comment = "-" + } + pk := "" + if col.PrimaryKey { + pk = "✓" + } + _, _ = fmt.Fprintf(tw, " %s\t%s\t%v\t%s\t%s\t%s\n", + col.Name, col.Type, col.Nullable, defaultVal, comment, pk) + } + _ = tw.Flush() + } + } + + // 表数量统计 + suffix := "tables" + if len(tables) == 1 { + suffix = "table" + } + _, _ = fmt.Fprintf(out, "\n(%d %s)\n", len(tables), suffix) + return nil +} + +// SchemaColumn schema 列输出的简化结构 +type SchemaColumn struct { + Name string + Type string + Nullable bool + Default string + Comment string + PrimaryKey bool +} diff --git a/tests/e2e/schema_test.go b/tests/e2e/schema_test.go new file mode 100644 index 0000000..4d6853b --- /dev/null +++ b/tests/e2e/schema_test.go @@ -0,0 +1,270 @@ +//go:build e2e + +package e2e + +import ( + "encoding/json" + "testing" +) + +func TestSchemaDump_JSON(t *testing.T) { + config := createTempConfig(t, `profiles: + dev: + description: "开发环境" + db: mysql + host: 127.0.0.1 + port: 3306 + user: root + database: testdb + allow_plaintext: true +`) + stdout, stderr, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "json") + + // 验证退出码(可能因数据库未启动而失败,但输出格式应正确) + if exitCode != 0 && exitCode != 3 && exitCode != 5 { + t.Errorf("unexpected exit code %d, stderr: %s", exitCode, stderr) + } + + // 验证 JSON 格式 + var resp struct { + OK bool `json:"ok"` + SchemaVersion int `json:"schema_version"` + Data struct { + Database string `json:"database"` + Tables []struct { + Schema string `json:"schema"` + Name string `json:"name"` + Comment string `json:"comment"` + Columns []struct { + Name string `json:"name"` + Type string `json:"type"` + Nullable bool `json:"nullable"` + PrimaryKey bool `json:"primary_key"` + } `json:"columns"` + Indexes []struct { + Name string `json:"name"` + Columns []string `json:"columns"` + Unique bool `json:"unique"` + Primary bool `json:"primary"` + } `json:"indexes"` + ForeignKeys []struct { + Name string `json:"name"` + Columns []string `json:"columns"` + ReferencedTable string `json:"referenced_table"` + ReferencedColumns []string `json:"referenced_columns"` + } `json:"foreign_keys"` + } `json:"tables"` + } `json:"data"` + Error *struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Errorf("failed to parse JSON output: %v, stdout: %s", err, stdout) + return + } + + // 验证 schema_version + if resp.SchemaVersion != 1 { + t.Errorf("schema_version = %d, want 1", resp.SchemaVersion) + } + + // 如果成功,验证数据结构 + if resp.OK { + if resp.Data.Database == "" { + t.Error("database name is empty") + } + // tables 可以为空(数据库无表) + } +} + +func TestSchemaDump_YAML(t *testing.T) { + config := createTempConfig(t, `profiles: + dev: + description: "开发环境" + db: pg + host: 127.0.0.1 + port: 5432 + user: postgres + database: testdb + allow_plaintext: true +`) + stdout, stderr, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "yaml") + + // 验证退出码 + if exitCode != 0 && exitCode != 3 && exitCode != 5 { + t.Errorf("unexpected exit code %d, stderr: %s", exitCode, stderr) + } + + // 验证 YAML 格式包含必要字段 + if !contains(stdout, "ok:") { + t.Error("YAML output missing 'ok:' field") + } + if !contains(stdout, "schema_version:") { + t.Error("YAML output missing 'schema_version:' field") + } +} + +func TestSchemaDump_Table(t *testing.T) { + config := createTempConfig(t, `profiles: + dev: + description: "开发环境" + db: mysql + host: 127.0.0.1 + port: 3306 + user: root + database: testdb + allow_plaintext: true +`) + stdout, stderr, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "table") + + // 验证退出码 + if exitCode != 0 && exitCode != 3 && exitCode != 5 { + t.Errorf("unexpected exit code %d, stderr: %s", exitCode, stderr) + } + + // 验证 Table 格式不包含 JSON 元数据 + if contains(stdout, `"ok"`) { + t.Error("Table output should not contain JSON 'ok' field") + } + if contains(stdout, `"schema_version"`) { + t.Error("Table output should not contain JSON 'schema_version' field") + } +} + +func TestSchemaDump_ProfileNotFound(t *testing.T) { + config := createTempConfig(t, `profiles: + dev: + db: mysql + host: 127.0.0.1 +`) + stdout, _, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "nonexistent", "-f", "json") + + // 验证退出码(配置错误) + if exitCode != 2 { + t.Errorf("exit code = %d, want 2", exitCode) + } + + // 验证错误响应 + var resp struct { + OK bool `json:"ok"` + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Errorf("failed to parse JSON: %v", err) + return + } + if resp.OK { + t.Error("expected ok=false") + } + if resp.Error.Code == "" { + t.Error("error code is empty") + } +} + +func TestSchemaDump_MissingProfile(t *testing.T) { + config := createTempConfig(t, `profiles: + dev: + db: mysql + host: 127.0.0.1 +`) + _, _, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-f", "json") + + // 验证退出码(参数错误) + if exitCode != 2 { + t.Errorf("exit code = %d, want 2", exitCode) + } +} + +func TestSchemaDump_TableFilter(t *testing.T) { + config := createTempConfig(t, `profiles: + dev: + description: "开发环境" + db: mysql + host: 127.0.0.1 + port: 3306 + user: root + database: testdb + allow_plaintext: true +`) + // 使用 --table 过滤 + stdout, stderr, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "json", "--table", "user*") + + // 验证退出码(可能因数据库未启动而失败) + if exitCode != 0 && exitCode != 3 && exitCode != 5 { + t.Errorf("unexpected exit code %d, stderr: %s", exitCode, stderr) + } + + // 如果成功,验证过滤生效 + if exitCode == 0 { + var resp struct { + OK bool `json:"ok"` + Data struct { + Tables []struct { + Name string `json:"name"` + } `json:"tables"` + } `json:"data"` + } + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Errorf("failed to parse JSON: %v", err) + return + } + // 所有表名应该以 user 开头 + for _, table := range resp.Data.Tables { + if len(table.Name) < 4 || table.Name[:4] != "user" { + t.Errorf("table name %q does not match filter 'user*'", table.Name) + } + } + } +} + +func TestSchemaDump_Help(t *testing.T) { + stdout, _, exitCode := runXSQL(t, "schema", "dump", "--help") + + if exitCode != 0 { + t.Errorf("exit code = %d, want 0", exitCode) + } + + // 验证帮助信息包含关键内容 + if !contains(stdout, "schema dump") { + t.Error("help output missing 'schema dump'") + } + if !contains(stdout, "--table") { + t.Error("help output missing '--table' flag") + } + if !contains(stdout, "--include-system") { + t.Error("help output missing '--include-system' flag") + } +} + +func TestSchema_Command(t *testing.T) { + // 测试 schema 父命令 + stdout, _, exitCode := runXSQL(t, "schema", "--help") + + if exitCode != 0 { + t.Errorf("exit code = %d, want 0", exitCode) + } + + if !contains(stdout, "dump") { + t.Error("schema command help should mention 'dump' subcommand") + } +} + +// Helper function to check if string contains substring +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} From cbb82144534caa2b56c3fb52e98932c971073514 Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Wed, 11 Feb 2026 14:43:07 +0800 Subject: [PATCH 02/10] docs: update README and spec for schema dump command - Add schema dump to README command list - Add schema discovery section with examples - Update AI skill prompt to include schema dump - Add schema dump to tool spec export --- README.md | 30 ++++++++++++++++++++++++++++++ go.mod | 4 ++-- go.sum | 6 ++++++ internal/app/app.go | 10 ++++++++++ 4 files changed, 48 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b798375..b07d6d5 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ xsql query "" -p -f json ## 可用命令 - xsql query "SQL" -p -f json # 执行查询 +- xsql schema dump -p -f json # 导出数据库结构 - xsql profile list -f json # 列出所有 profile - xsql profile show -f json # 查看 profile 详情 @@ -146,6 +147,7 @@ xsql query "" -p -f json 使用 xsql 工具查询数据库: - 查询: `xsql query "SELECT ..." -p -f json` +- 导出结构: `xsql schema dump -p -f json` - 列出配置: `xsql profile list -f json` 注意: 默认只读模式,写操作需要 --unsafe-allow-write 标志。 @@ -160,6 +162,7 @@ xsql query "" -p -f json | 命令 | 说明 | |------|------| | `xsql query ` | 执行 SQL 查询(默认只读) | +| `xsql schema dump` | 导出数据库结构(表、列、索引、外键) | | `xsql profile list` | 列出所有 profile | | `xsql profile show ` | 查看 profile 详情(密码脱敏) | | `xsql mcp server` | 启动 MCP Server(AI 助手集成) | @@ -182,6 +185,33 @@ id name (1 rows) ``` +### Schema 发现(AI 自动理解数据库) + +```bash +# 导出数据库结构(供 AI 理解表结构) +xsql schema dump -p dev -f json + +# 过滤特定表 +xsql schema dump -p dev --table "user*" -f json + +# 输出示例 +{ + "ok": true, + "data": { + "database": "mydb", + "tables": [ + { + "name": "users", + "columns": [ + {"name": "id", "type": "bigint", "primary_key": true}, + {"name": "email", "type": "varchar(255)", "nullable": false} + ] + } + ] + } +} +``` + ### SSH 隧道连接 ```yaml diff --git a/go.mod b/go.mod index 82fe163..8b5cd06 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,9 @@ go 1.24.0 require ( github.com/go-sql-driver/mysql v1.8.1 + github.com/google/jsonschema-go v0.3.0 github.com/jackc/pgx/v5 v5.7.2 + github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/spf13/cobra v1.8.1 github.com/zalando/go-keyring v0.2.6 golang.org/x/crypto v0.47.0 @@ -17,13 +19,11 @@ require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/danieljoos/wincred v1.2.2 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect - github.com/google/jsonschema-go v0.3.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/kr/text v0.2.0 // indirect - github.com/modelcontextprotocol/go-sdk v1.2.0 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect diff --git a/go.sum b/go.sum index f1f0da9..2df7bd6 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,10 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= @@ -65,6 +69,8 @@ golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/app/app.go b/internal/app/app.go index 9b76941..0145271 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -54,6 +54,16 @@ func (a App) BuildSpec() spec.Spec { Description: "Show profile details (passwords are masked)", Flags: globalFlags, }, + { + Name: "schema dump", + Description: "Dump database schema (tables, columns, indexes, foreign keys)", + Flags: append(globalFlags, + spec.FlagSpec{Name: "table", Default: "", Description: "Table name filter (supports * and ? wildcards)"}, + spec.FlagSpec{Name: "include-system", Default: "false", Description: "Include system tables"}, + spec.FlagSpec{Name: "allow-plaintext", Default: "false", Description: "Allow plaintext secrets in config"}, + spec.FlagSpec{Name: "ssh-skip-known-hosts-check", Default: "false", Description: "Skip SSH known_hosts check (dangerous)"}, + ), + }, { Name: "proxy", Description: "Start a port forwarding proxy (replaces ssh -L)", From 214a746e102d8e5ce93749d57e9790ad3d7dde86 Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Wed, 11 Feb 2026 14:53:40 +0800 Subject: [PATCH 03/10] fix: resolve lint issues - Run goimports on schema files - Remove unused matchPattern functions --- internal/db/mysql/schema.go | 538 ++++++++++++++-------------- internal/db/pg/schema.go | 694 ++++++++++++++++++------------------ internal/db/schema.go | 216 +++++------ 3 files changed, 716 insertions(+), 732 deletions(-) diff --git a/internal/db/mysql/schema.go b/internal/db/mysql/schema.go index 590d1ca..990b911 100644 --- a/internal/db/mysql/schema.go +++ b/internal/db/mysql/schema.go @@ -1,273 +1,265 @@ -package mysql - -import ( - "context" - "database/sql" - "path/filepath" - "strings" - - "github.com/zx06/xsql/internal/db" - "github.com/zx06/xsql/internal/errors" -) - -// DumpSchema 导出 MySQL 数据库结构 -func (d *Driver) DumpSchema(ctx context.Context, conn *sql.DB, opts db.SchemaOptions) (*db.SchemaInfo, *errors.XError) { - info := &db.SchemaInfo{} - - // 获取当前数据库名 - var database string - if err := conn.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&database); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get database name", nil, err) - } - info.Database = database - - // 获取表列表 - tables, xe := d.listTables(ctx, conn, database, opts) - if xe != nil { - return nil, xe - } - - // 获取每个表的详细信息 - for _, table := range tables { - // 获取列信息 - columns, xe := d.getColumns(ctx, conn, database, table.Name) - if xe != nil { - return nil, xe - } - table.Columns = columns - - // 获取索引信息 - indexes, xe := d.getIndexes(ctx, conn, database, table.Name) - if xe != nil { - return nil, xe - } - table.Indexes = indexes - - // 获取外键信息 - fks, xe := d.getForeignKeys(ctx, conn, database, table.Name) - if xe != nil { - return nil, xe - } - table.ForeignKeys = fks - - info.Tables = append(info.Tables, table) - } - - return info, nil -} - -// listTables 获取表列表 -func (d *Driver) listTables(ctx context.Context, conn *sql.DB, database string, opts db.SchemaOptions) ([]db.Table, *errors.XError) { - query := ` - SELECT table_name, table_comment - FROM information_schema.tables - WHERE table_schema = ? AND table_type = 'BASE TABLE' - ` - args := []any{database} - - // 表名过滤 - if opts.TablePattern != "" { - // 将通配符 * 和 ? 转换为 SQL LIKE 模式 - likePattern := strings.ReplaceAll(opts.TablePattern, "*", "%") - likePattern = strings.ReplaceAll(likePattern, "?", "_") - query += " AND table_name LIKE ?" - args = append(args, likePattern) - } - - query += " ORDER BY table_name" - - rows, err := conn.QueryContext(ctx, query, args...) - if err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to list tables", nil, err) - } - defer rows.Close() - - var tables []db.Table - for rows.Next() { - var name, comment string - if err := rows.Scan(&name, &comment); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan table row", nil, err) - } - tables = append(tables, db.Table{ - Schema: database, - Name: name, - Comment: comment, - }) - } - - if err := rows.Err(); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) - } - - return tables, nil -} - -// getColumns 获取表的列信息 -func (d *Driver) getColumns(ctx context.Context, conn *sql.DB, database, tableName string) ([]db.Column, *errors.XError) { - query := ` - SELECT - column_name, - column_type, - is_nullable, - column_default, - column_comment, - CASE WHEN column_key = 'PRI' THEN 1 ELSE 0 END AS is_primary - FROM information_schema.columns - WHERE table_schema = ? AND table_name = ? - ORDER BY ordinal_position - ` - - rows, err := conn.QueryContext(ctx, query, database, tableName) - if err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get columns", nil, err) - } - defer rows.Close() - - var columns []db.Column - for rows.Next() { - var name, colType, nullable, defaultValue, comment sql.NullString - var isPrimary bool - if err := rows.Scan(&name, &colType, &nullable, &defaultValue, &comment, &isPrimary); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan column row", nil, err) - } - - col := db.Column{ - Name: name.String, - Type: colType.String, - Nullable: nullable.String == "YES", - PrimaryKey: isPrimary, - } - if defaultValue.Valid { - col.Default = defaultValue.String - } - if comment.Valid { - col.Comment = comment.String - } - columns = append(columns, col) - } - - if err := rows.Err(); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) - } - - return columns, nil -} - -// getIndexes 获取表的索引信息 -func (d *Driver) getIndexes(ctx context.Context, conn *sql.DB, database, tableName string) ([]db.Index, *errors.XError) { - query := ` - SELECT - index_name, - column_name, - NOT non_unique AS is_unique, - index_name = 'PRIMARY' AS is_primary, - seq_in_index - FROM information_schema.statistics - WHERE table_schema = ? AND table_name = ? - ORDER BY index_name, seq_in_index - ` - - rows, err := conn.QueryContext(ctx, query, database, tableName) - if err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get indexes", nil, err) - } - defer rows.Close() - - // 按 index_name 分组 - indexMap := make(map[string]*db.Index) - for rows.Next() { - var indexName, columnName string - var isUnique, isPrimary bool - var seqInIndex int - if err := rows.Scan(&indexName, &columnName, &isUnique, &isPrimary, &seqInIndex); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan index row", nil, err) - } - - if idx, exists := indexMap[indexName]; exists { - idx.Columns = append(idx.Columns, columnName) - } else { - indexMap[indexName] = &db.Index{ - Name: indexName, - Columns: []string{columnName}, - Unique: isUnique, - Primary: isPrimary, - } - } - } - - if err := rows.Err(); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) - } - - // 转换为切片 - indexes := make([]db.Index, 0, len(indexMap)) - for _, idx := range indexMap { - indexes = append(indexes, *idx) - } - - return indexes, nil -} - -// getForeignKeys 获取表的外键信息 -func (d *Driver) getForeignKeys(ctx context.Context, conn *sql.DB, database, tableName string) ([]db.ForeignKey, *errors.XError) { - query := ` - SELECT - kcu.constraint_name, - kcu.column_name, - kcu.referenced_table_name, - kcu.referenced_column_name, - kcu.ordinal_position - FROM information_schema.key_column_usage kcu - WHERE kcu.table_schema = ? - AND kcu.table_name = ? - AND kcu.referenced_table_name IS NOT NULL - ORDER BY kcu.constraint_name, kcu.ordinal_position - ` - - rows, err := conn.QueryContext(ctx, query, database, tableName) - if err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get foreign keys", nil, err) - } - defer rows.Close() - - // 按 constraint_name 分组 - fkMap := make(map[string]*db.ForeignKey) - for rows.Next() { - var constraintName, columnName, refTable, refColumn string - var ordinalPosition int - if err := rows.Scan(&constraintName, &columnName, &refTable, &refColumn, &ordinalPosition); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan foreign key row", nil, err) - } - - if fk, exists := fkMap[constraintName]; exists { - fk.Columns = append(fk.Columns, columnName) - fk.ReferencedColumns = append(fk.ReferencedColumns, refColumn) - } else { - fkMap[constraintName] = &db.ForeignKey{ - Name: constraintName, - Columns: []string{columnName}, - ReferencedTable: refTable, - ReferencedColumns: []string{refColumn}, - } - } - } - - if err := rows.Err(); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) - } - - // 转换为切片 - fks := make([]db.ForeignKey, 0, len(fkMap)) - for _, fk := range fkMap { - fks = append(fks, *fk) - } - - return fks, nil -} - -// matchPattern 检查表名是否匹配通配符模式 -func matchPattern(pattern, name string) bool { - // 简单实现:使用 filepath.Match - matched, _ := filepath.Match(pattern, name) - return matched -} +package mysql + +import ( + "context" + "database/sql" + "strings" + + "github.com/zx06/xsql/internal/db" + "github.com/zx06/xsql/internal/errors" +) + +// DumpSchema 导出 MySQL 数据库结构 +func (d *Driver) DumpSchema(ctx context.Context, conn *sql.DB, opts db.SchemaOptions) (*db.SchemaInfo, *errors.XError) { + info := &db.SchemaInfo{} + + // 获取当前数据库名 + var database string + if err := conn.QueryRowContext(ctx, "SELECT DATABASE()").Scan(&database); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get database name", nil, err) + } + info.Database = database + + // 获取表列表 + tables, xe := d.listTables(ctx, conn, database, opts) + if xe != nil { + return nil, xe + } + + // 获取每个表的详细信息 + for _, table := range tables { + // 获取列信息 + columns, xe := d.getColumns(ctx, conn, database, table.Name) + if xe != nil { + return nil, xe + } + table.Columns = columns + + // 获取索引信息 + indexes, xe := d.getIndexes(ctx, conn, database, table.Name) + if xe != nil { + return nil, xe + } + table.Indexes = indexes + + // 获取外键信息 + fks, xe := d.getForeignKeys(ctx, conn, database, table.Name) + if xe != nil { + return nil, xe + } + table.ForeignKeys = fks + + info.Tables = append(info.Tables, table) + } + + return info, nil +} + +// listTables 获取表列表 +func (d *Driver) listTables(ctx context.Context, conn *sql.DB, database string, opts db.SchemaOptions) ([]db.Table, *errors.XError) { + query := ` + SELECT table_name, table_comment + FROM information_schema.tables + WHERE table_schema = ? AND table_type = 'BASE TABLE' + ` + args := []any{database} + + // 表名过滤 + if opts.TablePattern != "" { + // 将通配符 * 和 ? 转换为 SQL LIKE 模式 + likePattern := strings.ReplaceAll(opts.TablePattern, "*", "%") + likePattern = strings.ReplaceAll(likePattern, "?", "_") + query += " AND table_name LIKE ?" + args = append(args, likePattern) + } + + query += " ORDER BY table_name" + + rows, err := conn.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to list tables", nil, err) + } + defer rows.Close() + + var tables []db.Table + for rows.Next() { + var name, comment string + if err := rows.Scan(&name, &comment); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan table row", nil, err) + } + tables = append(tables, db.Table{ + Schema: database, + Name: name, + Comment: comment, + }) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + return tables, nil +} + +// getColumns 获取表的列信息 +func (d *Driver) getColumns(ctx context.Context, conn *sql.DB, database, tableName string) ([]db.Column, *errors.XError) { + query := ` + SELECT + column_name, + column_type, + is_nullable, + column_default, + column_comment, + CASE WHEN column_key = 'PRI' THEN 1 ELSE 0 END AS is_primary + FROM information_schema.columns + WHERE table_schema = ? AND table_name = ? + ORDER BY ordinal_position + ` + + rows, err := conn.QueryContext(ctx, query, database, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get columns", nil, err) + } + defer rows.Close() + + var columns []db.Column + for rows.Next() { + var name, colType, nullable, defaultValue, comment sql.NullString + var isPrimary bool + if err := rows.Scan(&name, &colType, &nullable, &defaultValue, &comment, &isPrimary); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan column row", nil, err) + } + + col := db.Column{ + Name: name.String, + Type: colType.String, + Nullable: nullable.String == "YES", + PrimaryKey: isPrimary, + } + if defaultValue.Valid { + col.Default = defaultValue.String + } + if comment.Valid { + col.Comment = comment.String + } + columns = append(columns, col) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + return columns, nil +} + +// getIndexes 获取表的索引信息 +func (d *Driver) getIndexes(ctx context.Context, conn *sql.DB, database, tableName string) ([]db.Index, *errors.XError) { + query := ` + SELECT + index_name, + column_name, + NOT non_unique AS is_unique, + index_name = 'PRIMARY' AS is_primary, + seq_in_index + FROM information_schema.statistics + WHERE table_schema = ? AND table_name = ? + ORDER BY index_name, seq_in_index + ` + + rows, err := conn.QueryContext(ctx, query, database, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get indexes", nil, err) + } + defer rows.Close() + + // 按 index_name 分组 + indexMap := make(map[string]*db.Index) + for rows.Next() { + var indexName, columnName string + var isUnique, isPrimary bool + var seqInIndex int + if err := rows.Scan(&indexName, &columnName, &isUnique, &isPrimary, &seqInIndex); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan index row", nil, err) + } + + if idx, exists := indexMap[indexName]; exists { + idx.Columns = append(idx.Columns, columnName) + } else { + indexMap[indexName] = &db.Index{ + Name: indexName, + Columns: []string{columnName}, + Unique: isUnique, + Primary: isPrimary, + } + } + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + // 转换为切片 + indexes := make([]db.Index, 0, len(indexMap)) + for _, idx := range indexMap { + indexes = append(indexes, *idx) + } + + return indexes, nil +} + +// getForeignKeys 获取表的外键信息 +func (d *Driver) getForeignKeys(ctx context.Context, conn *sql.DB, database, tableName string) ([]db.ForeignKey, *errors.XError) { + query := ` + SELECT + kcu.constraint_name, + kcu.column_name, + kcu.referenced_table_name, + kcu.referenced_column_name, + kcu.ordinal_position + FROM information_schema.key_column_usage kcu + WHERE kcu.table_schema = ? + AND kcu.table_name = ? + AND kcu.referenced_table_name IS NOT NULL + ORDER BY kcu.constraint_name, kcu.ordinal_position + ` + + rows, err := conn.QueryContext(ctx, query, database, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get foreign keys", nil, err) + } + defer rows.Close() + + // 按 constraint_name 分组 + fkMap := make(map[string]*db.ForeignKey) + for rows.Next() { + var constraintName, columnName, refTable, refColumn string + var ordinalPosition int + if err := rows.Scan(&constraintName, &columnName, &refTable, &refColumn, &ordinalPosition); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan foreign key row", nil, err) + } + + if fk, exists := fkMap[constraintName]; exists { + fk.Columns = append(fk.Columns, columnName) + fk.ReferencedColumns = append(fk.ReferencedColumns, refColumn) + } else { + fkMap[constraintName] = &db.ForeignKey{ + Name: constraintName, + Columns: []string{columnName}, + ReferencedTable: refTable, + ReferencedColumns: []string{refColumn}, + } + } + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + // 转换为切片 + fks := make([]db.ForeignKey, 0, len(fkMap)) + for _, fk := range fkMap { + fks = append(fks, *fk) + } + + return fks, nil +} diff --git a/internal/db/pg/schema.go b/internal/db/pg/schema.go index 5ef68c7..5f01f3c 100644 --- a/internal/db/pg/schema.go +++ b/internal/db/pg/schema.go @@ -1,351 +1,343 @@ -package pg - -import ( - "context" - "database/sql" - "path/filepath" - "strings" - - "github.com/zx06/xsql/internal/db" - "github.com/zx06/xsql/internal/errors" -) - -// DumpSchema 导出 PostgreSQL 数据库结构 -func (d *Driver) DumpSchema(ctx context.Context, conn *sql.DB, opts db.SchemaOptions) (*db.SchemaInfo, *errors.XError) { - info := &db.SchemaInfo{} - - // 获取当前数据库名 - var database string - if err := conn.QueryRowContext(ctx, "SELECT current_database()").Scan(&database); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get database name", nil, err) - } - info.Database = database - - // 获取 schema 列表(排除系统 schema) - schemas, xe := d.listSchemas(ctx, conn, opts) - if xe != nil { - return nil, xe - } - - // 获取每个 schema 下的表 - for _, schema := range schemas { - tables, xe := d.listTables(ctx, conn, schema, opts) - if xe != nil { - return nil, xe - } - - // 获取每个表的详细信息 - for _, table := range tables { - // 获取列信息 - columns, xe := d.getColumns(ctx, conn, schema, table.Name) - if xe != nil { - return nil, xe - } - table.Columns = columns - - // 获取索引信息 - indexes, xe := d.getIndexes(ctx, conn, schema, table.Name) - if xe != nil { - return nil, xe - } - table.Indexes = indexes - - // 获取外键信息 - fks, xe := d.getForeignKeys(ctx, conn, schema, table.Name) - if xe != nil { - return nil, xe - } - table.ForeignKeys = fks - - info.Tables = append(info.Tables, table) - } - } - - return info, nil -} - -// listSchemas 获取 schema 列表 -func (d *Driver) listSchemas(ctx context.Context, conn *sql.DB, opts db.SchemaOptions) ([]string, *errors.XError) { - query := ` - SELECT schema_name - FROM information_schema.schemata - WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') - ` - - if !opts.IncludeSystem { - // 排除更多系统 schema - query += " AND schema_name NOT LIKE 'pg_%'" - } - - query += " ORDER BY schema_name" - - rows, err := conn.QueryContext(ctx, query) - if err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to list schemas", nil, err) - } - defer rows.Close() - - var schemas []string - for rows.Next() { - var schema string - if err := rows.Scan(&schema); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan schema row", nil, err) - } - schemas = append(schemas, schema) - } - - if err := rows.Err(); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) - } - - return schemas, nil -} - -// listTables 获取表列表 -func (d *Driver) listTables(ctx context.Context, conn *sql.DB, schema string, opts db.SchemaOptions) ([]db.Table, *errors.XError) { - query := ` - SELECT - t.table_name, - obj_description((quote_ident($1) || '.' || quote_ident(t.table_name))::regclass, 'pg_class') as table_comment - FROM information_schema.tables t - WHERE t.table_schema = $1 AND t.table_type = 'BASE TABLE' - ` - args := []any{schema} - - // 表名过滤 - if opts.TablePattern != "" { - // 将通配符 * 和 ? 转换为 SQL LIKE 模式 - likePattern := strings.ReplaceAll(opts.TablePattern, "*", "%") - likePattern = strings.ReplaceAll(likePattern, "?", "_") - query += " AND t.table_name LIKE $2" - args = append(args, likePattern) - } - - query += " ORDER BY t.table_name" - - rows, err := conn.QueryContext(ctx, query, args...) - if err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to list tables", nil, err) - } - defer rows.Close() - - var tables []db.Table - for rows.Next() { - var name string - var comment sql.NullString - if err := rows.Scan(&name, &comment); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan table row", nil, err) - } - tables = append(tables, db.Table{ - Schema: schema, - Name: name, - Comment: comment.String, - }) - } - - if err := rows.Err(); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) - } - - return tables, nil -} - -// getColumns 获取表的列信息 -func (d *Driver) getColumns(ctx context.Context, conn *sql.DB, schema, tableName string) ([]db.Column, *errors.XError) { - query := ` - SELECT - c.column_name, - CASE - WHEN c.data_type = 'USER-DEFINED' THEN c.udt_name - WHEN c.character_maximum_length IS NOT NULL THEN - c.data_type || '(' || c.character_maximum_length || ')' - WHEN c.numeric_precision IS NOT NULL AND c.numeric_scale IS NOT NULL THEN - c.data_type || '(' || c.numeric_precision || ',' || c.numeric_scale || ')' - WHEN c.numeric_precision IS NOT NULL THEN - c.data_type || '(' || c.numeric_precision || ')' - ELSE c.data_type - END as column_type, - c.is_nullable, - c.column_default, - col_description((quote_ident(c.table_schema) || '.' || quote_ident(c.table_name))::regclass, c.ordinal_position) as column_comment, - CASE WHEN pk.column_name IS NOT NULL THEN true ELSE false END AS is_primary - FROM information_schema.columns c - LEFT JOIN ( - SELECT kcu.table_schema, kcu.table_name, kcu.column_name - FROM information_schema.table_constraints tc - JOIN information_schema.key_column_usage kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema - WHERE tc.constraint_type = 'PRIMARY KEY' - ) pk ON c.table_schema = pk.table_schema - AND c.table_name = pk.table_name - AND c.column_name = pk.column_name - WHERE c.table_schema = $1 AND c.table_name = $2 - ORDER BY c.ordinal_position - ` - - rows, err := conn.QueryContext(ctx, query, schema, tableName) - if err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get columns", nil, err) - } - defer rows.Close() - - var columns []db.Column - for rows.Next() { - var name, colType, nullable string - var defaultValue, comment sql.NullString - var isPrimary bool - if err := rows.Scan(&name, &colType, &nullable, &defaultValue, &comment, &isPrimary); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan column row", nil, err) - } - - col := db.Column{ - Name: name, - Type: colType, - Nullable: nullable == "YES", - PrimaryKey: isPrimary, - } - if defaultValue.Valid { - col.Default = defaultValue.String - } - if comment.Valid { - col.Comment = comment.String - } - columns = append(columns, col) - } - - if err := rows.Err(); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) - } - - return columns, nil -} - -// getIndexes 获取表的索引信息 -func (d *Driver) getIndexes(ctx context.Context, conn *sql.DB, schema, tableName string) ([]db.Index, *errors.XError) { - query := ` - SELECT - i.relname as index_name, - a.attname as column_name, - NOT ix.indisunique as is_non_unique, - ix.indisprimary as is_primary, - array_position(ix.indkey, a.attnum) as column_position - FROM pg_class t - JOIN pg_index ix ON t.oid = ix.indrelid - JOIN pg_class i ON i.oid = ix.indexrelid - JOIN pg_namespace n ON t.relnamespace = n.oid - JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey) - WHERE n.nspname = $1 AND t.relname = $2 - ORDER BY i.relname, array_position(ix.indkey, a.attnum) - ` - - rows, err := conn.QueryContext(ctx, query, schema, tableName) - if err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get indexes", nil, err) - } - defer rows.Close() - - // 按 index_name 分组 - indexMap := make(map[string]*db.Index) - for rows.Next() { - var indexName, columnName string - var isNonUnique, isPrimary bool - var columnPosition int - if err := rows.Scan(&indexName, &columnName, &isNonUnique, &isPrimary, &columnPosition); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan index row", nil, err) - } - - if idx, exists := indexMap[indexName]; exists { - idx.Columns = append(idx.Columns, columnName) - } else { - indexMap[indexName] = &db.Index{ - Name: indexName, - Columns: []string{columnName}, - Unique: !isNonUnique, - Primary: isPrimary, - } - } - } - - if err := rows.Err(); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) - } - - // 转换为切片 - indexes := make([]db.Index, 0, len(indexMap)) - for _, idx := range indexMap { - indexes = append(indexes, *idx) - } - - return indexes, nil -} - -// getForeignKeys 获取表的外键信息 -func (d *Driver) getForeignKeys(ctx context.Context, conn *sql.DB, schema, tableName string) ([]db.ForeignKey, *errors.XError) { - query := ` - SELECT - tc.constraint_name, - kcu.column_name, - ccu.table_name AS referenced_table, - ccu.column_name AS referenced_column, - kcu.ordinal_position - FROM information_schema.table_constraints tc - JOIN information_schema.key_column_usage kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema - JOIN information_schema.constraint_column_usage ccu - ON tc.constraint_name = ccu.constraint_name - AND tc.table_schema = ccu.table_schema - WHERE tc.constraint_type = 'FOREIGN KEY' - AND tc.table_schema = $1 - AND tc.table_name = $2 - ORDER BY tc.constraint_name, kcu.ordinal_position - ` - - rows, err := conn.QueryContext(ctx, query, schema, tableName) - if err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get foreign keys", nil, err) - } - defer rows.Close() - - // 按 constraint_name 分组 - fkMap := make(map[string]*db.ForeignKey) - for rows.Next() { - var constraintName, columnName, refTable, refColumn string - var ordinalPosition int - if err := rows.Scan(&constraintName, &columnName, &refTable, &refColumn, &ordinalPosition); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan foreign key row", nil, err) - } - - if fk, exists := fkMap[constraintName]; exists { - fk.Columns = append(fk.Columns, columnName) - fk.ReferencedColumns = append(fk.ReferencedColumns, refColumn) - } else { - fkMap[constraintName] = &db.ForeignKey{ - Name: constraintName, - Columns: []string{columnName}, - ReferencedTable: refTable, - ReferencedColumns: []string{refColumn}, - } - } - } - - if err := rows.Err(); err != nil { - return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) - } - - // 转换为切片 - fks := make([]db.ForeignKey, 0, len(fkMap)) - for _, fk := range fkMap { - fks = append(fks, *fk) - } - - return fks, nil -} - -// matchPattern 检查表名是否匹配通配符模式 -func matchPattern(pattern, name string) bool { - // 简单实现:使用 filepath.Match - matched, _ := filepath.Match(pattern, name) - return matched -} +package pg + +import ( + "context" + "database/sql" + "strings" + + "github.com/zx06/xsql/internal/db" + "github.com/zx06/xsql/internal/errors" +) + +// DumpSchema 导出 PostgreSQL 数据库结构 +func (d *Driver) DumpSchema(ctx context.Context, conn *sql.DB, opts db.SchemaOptions) (*db.SchemaInfo, *errors.XError) { + info := &db.SchemaInfo{} + + // 获取当前数据库名 + var database string + if err := conn.QueryRowContext(ctx, "SELECT current_database()").Scan(&database); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get database name", nil, err) + } + info.Database = database + + // 获取 schema 列表(排除系统 schema) + schemas, xe := d.listSchemas(ctx, conn, opts) + if xe != nil { + return nil, xe + } + + // 获取每个 schema 下的表 + for _, schema := range schemas { + tables, xe := d.listTables(ctx, conn, schema, opts) + if xe != nil { + return nil, xe + } + + // 获取每个表的详细信息 + for _, table := range tables { + // 获取列信息 + columns, xe := d.getColumns(ctx, conn, schema, table.Name) + if xe != nil { + return nil, xe + } + table.Columns = columns + + // 获取索引信息 + indexes, xe := d.getIndexes(ctx, conn, schema, table.Name) + if xe != nil { + return nil, xe + } + table.Indexes = indexes + + // 获取外键信息 + fks, xe := d.getForeignKeys(ctx, conn, schema, table.Name) + if xe != nil { + return nil, xe + } + table.ForeignKeys = fks + + info.Tables = append(info.Tables, table) + } + } + + return info, nil +} + +// listSchemas 获取 schema 列表 +func (d *Driver) listSchemas(ctx context.Context, conn *sql.DB, opts db.SchemaOptions) ([]string, *errors.XError) { + query := ` + SELECT schema_name + FROM information_schema.schemata + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') + ` + + if !opts.IncludeSystem { + // 排除更多系统 schema + query += " AND schema_name NOT LIKE 'pg_%'" + } + + query += " ORDER BY schema_name" + + rows, err := conn.QueryContext(ctx, query) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to list schemas", nil, err) + } + defer rows.Close() + + var schemas []string + for rows.Next() { + var schema string + if err := rows.Scan(&schema); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan schema row", nil, err) + } + schemas = append(schemas, schema) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + return schemas, nil +} + +// listTables 获取表列表 +func (d *Driver) listTables(ctx context.Context, conn *sql.DB, schema string, opts db.SchemaOptions) ([]db.Table, *errors.XError) { + query := ` + SELECT + t.table_name, + obj_description((quote_ident($1) || '.' || quote_ident(t.table_name))::regclass, 'pg_class') as table_comment + FROM information_schema.tables t + WHERE t.table_schema = $1 AND t.table_type = 'BASE TABLE' + ` + args := []any{schema} + + // 表名过滤 + if opts.TablePattern != "" { + // 将通配符 * 和 ? 转换为 SQL LIKE 模式 + likePattern := strings.ReplaceAll(opts.TablePattern, "*", "%") + likePattern = strings.ReplaceAll(likePattern, "?", "_") + query += " AND t.table_name LIKE $2" + args = append(args, likePattern) + } + + query += " ORDER BY t.table_name" + + rows, err := conn.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to list tables", nil, err) + } + defer rows.Close() + + var tables []db.Table + for rows.Next() { + var name string + var comment sql.NullString + if err := rows.Scan(&name, &comment); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan table row", nil, err) + } + tables = append(tables, db.Table{ + Schema: schema, + Name: name, + Comment: comment.String, + }) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + return tables, nil +} + +// getColumns 获取表的列信息 +func (d *Driver) getColumns(ctx context.Context, conn *sql.DB, schema, tableName string) ([]db.Column, *errors.XError) { + query := ` + SELECT + c.column_name, + CASE + WHEN c.data_type = 'USER-DEFINED' THEN c.udt_name + WHEN c.character_maximum_length IS NOT NULL THEN + c.data_type || '(' || c.character_maximum_length || ')' + WHEN c.numeric_precision IS NOT NULL AND c.numeric_scale IS NOT NULL THEN + c.data_type || '(' || c.numeric_precision || ',' || c.numeric_scale || ')' + WHEN c.numeric_precision IS NOT NULL THEN + c.data_type || '(' || c.numeric_precision || ')' + ELSE c.data_type + END as column_type, + c.is_nullable, + c.column_default, + col_description((quote_ident(c.table_schema) || '.' || quote_ident(c.table_name))::regclass, c.ordinal_position) as column_comment, + CASE WHEN pk.column_name IS NOT NULL THEN true ELSE false END AS is_primary + FROM information_schema.columns c + LEFT JOIN ( + SELECT kcu.table_schema, kcu.table_name, kcu.column_name + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.constraint_type = 'PRIMARY KEY' + ) pk ON c.table_schema = pk.table_schema + AND c.table_name = pk.table_name + AND c.column_name = pk.column_name + WHERE c.table_schema = $1 AND c.table_name = $2 + ORDER BY c.ordinal_position + ` + + rows, err := conn.QueryContext(ctx, query, schema, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get columns", nil, err) + } + defer rows.Close() + + var columns []db.Column + for rows.Next() { + var name, colType, nullable string + var defaultValue, comment sql.NullString + var isPrimary bool + if err := rows.Scan(&name, &colType, &nullable, &defaultValue, &comment, &isPrimary); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan column row", nil, err) + } + + col := db.Column{ + Name: name, + Type: colType, + Nullable: nullable == "YES", + PrimaryKey: isPrimary, + } + if defaultValue.Valid { + col.Default = defaultValue.String + } + if comment.Valid { + col.Comment = comment.String + } + columns = append(columns, col) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + return columns, nil +} + +// getIndexes 获取表的索引信息 +func (d *Driver) getIndexes(ctx context.Context, conn *sql.DB, schema, tableName string) ([]db.Index, *errors.XError) { + query := ` + SELECT + i.relname as index_name, + a.attname as column_name, + NOT ix.indisunique as is_non_unique, + ix.indisprimary as is_primary, + array_position(ix.indkey, a.attnum) as column_position + FROM pg_class t + JOIN pg_index ix ON t.oid = ix.indrelid + JOIN pg_class i ON i.oid = ix.indexrelid + JOIN pg_namespace n ON t.relnamespace = n.oid + JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey) + WHERE n.nspname = $1 AND t.relname = $2 + ORDER BY i.relname, array_position(ix.indkey, a.attnum) + ` + + rows, err := conn.QueryContext(ctx, query, schema, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get indexes", nil, err) + } + defer rows.Close() + + // 按 index_name 分组 + indexMap := make(map[string]*db.Index) + for rows.Next() { + var indexName, columnName string + var isNonUnique, isPrimary bool + var columnPosition int + if err := rows.Scan(&indexName, &columnName, &isNonUnique, &isPrimary, &columnPosition); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan index row", nil, err) + } + + if idx, exists := indexMap[indexName]; exists { + idx.Columns = append(idx.Columns, columnName) + } else { + indexMap[indexName] = &db.Index{ + Name: indexName, + Columns: []string{columnName}, + Unique: !isNonUnique, + Primary: isPrimary, + } + } + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + // 转换为切片 + indexes := make([]db.Index, 0, len(indexMap)) + for _, idx := range indexMap { + indexes = append(indexes, *idx) + } + + return indexes, nil +} + +// getForeignKeys 获取表的外键信息 +func (d *Driver) getForeignKeys(ctx context.Context, conn *sql.DB, schema, tableName string) ([]db.ForeignKey, *errors.XError) { + query := ` + SELECT + tc.constraint_name, + kcu.column_name, + ccu.table_name AS referenced_table, + ccu.column_name AS referenced_column, + kcu.ordinal_position + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage ccu + ON tc.constraint_name = ccu.constraint_name + AND tc.table_schema = ccu.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = $1 + AND tc.table_name = $2 + ORDER BY tc.constraint_name, kcu.ordinal_position + ` + + rows, err := conn.QueryContext(ctx, query, schema, tableName) + if err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to get foreign keys", nil, err) + } + defer rows.Close() + + // 按 constraint_name 分组 + fkMap := make(map[string]*db.ForeignKey) + for rows.Next() { + var constraintName, columnName, refTable, refColumn string + var ordinalPosition int + if err := rows.Scan(&constraintName, &columnName, &refTable, &refColumn, &ordinalPosition); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "failed to scan foreign key row", nil, err) + } + + if fk, exists := fkMap[constraintName]; exists { + fk.Columns = append(fk.Columns, columnName) + fk.ReferencedColumns = append(fk.ReferencedColumns, refColumn) + } else { + fkMap[constraintName] = &db.ForeignKey{ + Name: constraintName, + Columns: []string{columnName}, + ReferencedTable: refTable, + ReferencedColumns: []string{refColumn}, + } + } + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(errors.CodeDBExecFailed, "rows iteration error", nil, err) + } + + // 转换为切片 + fks := make([]db.ForeignKey, 0, len(fkMap)) + for _, fk := range fkMap { + fks = append(fks, *fk) + } + + return fks, nil +} diff --git a/internal/db/schema.go b/internal/db/schema.go index 0185a6c..79d5a41 100644 --- a/internal/db/schema.go +++ b/internal/db/schema.go @@ -1,108 +1,108 @@ -package db - -import ( - "context" - "database/sql" - - "github.com/zx06/xsql/internal/errors" - "github.com/zx06/xsql/internal/output" -) - -// SchemaInfo 数据库 schema 信息 -type SchemaInfo struct { - Database string `json:"database" yaml:"database"` - Tables []Table `json:"tables" yaml:"tables"` -} - -// ToSchemaData 实现 output.SchemaFormatter 接口 -func (s *SchemaInfo) ToSchemaData() (string, []output.SchemaTable, bool) { - if s == nil || len(s.Tables) == 0 { - return "", nil, false - } - - tables := make([]output.SchemaTable, len(s.Tables)) - for i, t := range s.Tables { - tables[i].Schema = t.Schema - tables[i].Name = t.Name - tables[i].Comment = t.Comment - tables[i].Columns = make([]output.SchemaColumn, len(t.Columns)) - for j, c := range t.Columns { - tables[i].Columns[j] = output.SchemaColumn{ - Name: c.Name, - Type: c.Type, - Nullable: c.Nullable, - Default: c.Default, - Comment: c.Comment, - PrimaryKey: c.PrimaryKey, - } - } - } - - return s.Database, tables, true -} - -// Table 表信息 -type Table struct { - Schema string `json:"schema" yaml:"schema"` // PostgreSQL schema,MySQL 为数据库名 - Name string `json:"name" yaml:"name"` // 表名 - Comment string `json:"comment,omitempty" yaml:"comment,omitempty"` - Columns []Column `json:"columns" yaml:"columns"` - Indexes []Index `json:"indexes,omitempty" yaml:"indexes,omitempty"` - ForeignKeys []ForeignKey `json:"foreign_keys,omitempty" yaml:"foreign_keys,omitempty"` -} - -// Column 列信息 -type Column struct { - Name string `json:"name" yaml:"name"` - Type string `json:"type" yaml:"type"` // 数据类型,如 varchar(255)、bigint - Nullable bool `json:"nullable" yaml:"nullable"` // 是否允许 NULL - Default string `json:"default,omitempty" yaml:"default,omitempty"` // 默认值 - Comment string `json:"comment,omitempty" yaml:"comment,omitempty"` // 列注释 - PrimaryKey bool `json:"primary_key" yaml:"primary_key"` // 是否为主键 -} - -// Index 索引信息 -type Index struct { - Name string `json:"name" yaml:"name"` // 索引名 - Columns []string `json:"columns" yaml:"columns"` // 索引列 - Unique bool `json:"unique" yaml:"unique"` // 是否唯一索引 - Primary bool `json:"primary" yaml:"primary"` // 是否主键索引 -} - -// ForeignKey 外键信息 -type ForeignKey struct { - Name string `json:"name" yaml:"name"` // 外键名 - Columns []string `json:"columns" yaml:"columns"` // 本表列 - ReferencedTable string `json:"referenced_table" yaml:"referenced_table"` // 引用表 - ReferencedColumns []string `json:"referenced_columns" yaml:"referenced_columns"` // 引用列 -} - -// SchemaOptions schema 导出选项 -type SchemaOptions struct { - TablePattern string // 表名过滤(支持通配符) - IncludeSystem bool // 是否包含系统表 -} - -// SchemaDriver schema 导出接口 -// Driver 可选择实现此接口以支持 schema 导出 -type SchemaDriver interface { - Driver - // DumpSchema 导出数据库结构 - DumpSchema(ctx context.Context, db *sql.DB, opts SchemaOptions) (*SchemaInfo, *errors.XError) -} - -// DumpSchema 导出数据库结构 -// 会检查 driver 是否实现了 SchemaDriver 接口 -func DumpSchema(ctx context.Context, driverName string, db *sql.DB, opts SchemaOptions) (*SchemaInfo, *errors.XError) { - d, ok := Get(driverName) - if !ok { - return nil, errors.New(errors.CodeDBDriverUnsupported, "unsupported driver: "+driverName, nil) - } - - sd, ok := d.(SchemaDriver) - if !ok { - return nil, errors.New(errors.CodeDBDriverUnsupported, "driver does not support schema dump: "+driverName, nil) - } - - return sd.DumpSchema(ctx, db, opts) -} +package db + +import ( + "context" + "database/sql" + + "github.com/zx06/xsql/internal/errors" + "github.com/zx06/xsql/internal/output" +) + +// SchemaInfo 数据库 schema 信息 +type SchemaInfo struct { + Database string `json:"database" yaml:"database"` + Tables []Table `json:"tables" yaml:"tables"` +} + +// ToSchemaData 实现 output.SchemaFormatter 接口 +func (s *SchemaInfo) ToSchemaData() (string, []output.SchemaTable, bool) { + if s == nil || len(s.Tables) == 0 { + return "", nil, false + } + + tables := make([]output.SchemaTable, len(s.Tables)) + for i, t := range s.Tables { + tables[i].Schema = t.Schema + tables[i].Name = t.Name + tables[i].Comment = t.Comment + tables[i].Columns = make([]output.SchemaColumn, len(t.Columns)) + for j, c := range t.Columns { + tables[i].Columns[j] = output.SchemaColumn{ + Name: c.Name, + Type: c.Type, + Nullable: c.Nullable, + Default: c.Default, + Comment: c.Comment, + PrimaryKey: c.PrimaryKey, + } + } + } + + return s.Database, tables, true +} + +// Table 表信息 +type Table struct { + Schema string `json:"schema" yaml:"schema"` // PostgreSQL schema,MySQL 为数据库名 + Name string `json:"name" yaml:"name"` // 表名 + Comment string `json:"comment,omitempty" yaml:"comment,omitempty"` + Columns []Column `json:"columns" yaml:"columns"` + Indexes []Index `json:"indexes,omitempty" yaml:"indexes,omitempty"` + ForeignKeys []ForeignKey `json:"foreign_keys,omitempty" yaml:"foreign_keys,omitempty"` +} + +// Column 列信息 +type Column struct { + Name string `json:"name" yaml:"name"` + Type string `json:"type" yaml:"type"` // 数据类型,如 varchar(255)、bigint + Nullable bool `json:"nullable" yaml:"nullable"` // 是否允许 NULL + Default string `json:"default,omitempty" yaml:"default,omitempty"` // 默认值 + Comment string `json:"comment,omitempty" yaml:"comment,omitempty"` // 列注释 + PrimaryKey bool `json:"primary_key" yaml:"primary_key"` // 是否为主键 +} + +// Index 索引信息 +type Index struct { + Name string `json:"name" yaml:"name"` // 索引名 + Columns []string `json:"columns" yaml:"columns"` // 索引列 + Unique bool `json:"unique" yaml:"unique"` // 是否唯一索引 + Primary bool `json:"primary" yaml:"primary"` // 是否主键索引 +} + +// ForeignKey 外键信息 +type ForeignKey struct { + Name string `json:"name" yaml:"name"` // 外键名 + Columns []string `json:"columns" yaml:"columns"` // 本表列 + ReferencedTable string `json:"referenced_table" yaml:"referenced_table"` // 引用表 + ReferencedColumns []string `json:"referenced_columns" yaml:"referenced_columns"` // 引用列 +} + +// SchemaOptions schema 导出选项 +type SchemaOptions struct { + TablePattern string // 表名过滤(支持通配符) + IncludeSystem bool // 是否包含系统表 +} + +// SchemaDriver schema 导出接口 +// Driver 可选择实现此接口以支持 schema 导出 +type SchemaDriver interface { + Driver + // DumpSchema 导出数据库结构 + DumpSchema(ctx context.Context, db *sql.DB, opts SchemaOptions) (*SchemaInfo, *errors.XError) +} + +// DumpSchema 导出数据库结构 +// 会检查 driver 是否实现了 SchemaDriver 接口 +func DumpSchema(ctx context.Context, driverName string, db *sql.DB, opts SchemaOptions) (*SchemaInfo, *errors.XError) { + d, ok := Get(driverName) + if !ok { + return nil, errors.New(errors.CodeDBDriverUnsupported, "unsupported driver: "+driverName, nil) + } + + sd, ok := d.(SchemaDriver) + if !ok { + return nil, errors.New(errors.CodeDBDriverUnsupported, "driver does not support schema dump: "+driverName, nil) + } + + return sd.DumpSchema(ctx, db, opts) +} From b2f9d0e10898b0f13f2aaa8b3e5d1074916c394f Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Wed, 11 Feb 2026 14:59:19 +0800 Subject: [PATCH 04/10] fix: format cmd/xsql/schema.go with goimports --- cmd/xsql/schema.go | 262 ++++++++++++++++++++++----------------------- 1 file changed, 131 insertions(+), 131 deletions(-) diff --git a/cmd/xsql/schema.go b/cmd/xsql/schema.go index 1e0e566..649ea7a 100644 --- a/cmd/xsql/schema.go +++ b/cmd/xsql/schema.go @@ -1,131 +1,131 @@ -package main - -import ( - "context" - "time" - - "github.com/spf13/cobra" - - "github.com/zx06/xsql/internal/db" - _ "github.com/zx06/xsql/internal/db/mysql" - _ "github.com/zx06/xsql/internal/db/pg" - "github.com/zx06/xsql/internal/errors" - "github.com/zx06/xsql/internal/output" - "github.com/zx06/xsql/internal/secret" -) - -// SchemaFlags holds the flags for the schema command -type SchemaFlags struct { - TablePattern string - IncludeSystem bool - AllowPlaintext bool - SSHSkipHostKey bool -} - -// NewSchemaCommand creates the schema command -func NewSchemaCommand(w *output.Writer) *cobra.Command { - flags := &SchemaFlags{} - - cmd := &cobra.Command{ - Use: "schema", - Short: "Database schema operations", - } - - // Add subcommands - cmd.AddCommand(NewSchemaDumpCommand(w, flags)) - - return cmd -} - -// NewSchemaDumpCommand creates the schema dump subcommand -func NewSchemaDumpCommand(w *output.Writer, flags *SchemaFlags) *cobra.Command { - cmd := &cobra.Command{ - Use: "dump", - Short: "Dump database schema (tables, columns, indexes, foreign keys)", - RunE: func(cmd *cobra.Command, args []string) error { - return runSchemaDump(cmd, args, flags, w) - }, - } - - cmd.Flags().StringVar(&flags.TablePattern, "table", "", "Table name filter (supports * and ? wildcards)") - cmd.Flags().BoolVar(&flags.IncludeSystem, "include-system", false, "Include system tables") - cmd.Flags().BoolVar(&flags.AllowPlaintext, "allow-plaintext", false, "Allow plaintext secrets in config") - cmd.Flags().BoolVar(&flags.SSHSkipHostKey, "ssh-skip-known-hosts-check", false, "Skip SSH known_hosts check (dangerous)") - - return cmd -} - -// runSchemaDump executes the schema dump command -func runSchemaDump(cmd *cobra.Command, args []string, flags *SchemaFlags, w *output.Writer) error { - format, err := parseOutputFormat(GlobalConfig.FormatStr) - if err != nil { - return err - } - - p := GlobalConfig.Resolved.Profile - if p.DB == "" { - return errors.New(errors.CodeCfgInvalid, "db type is required (mysql|pg)", nil) - } - - // Allow plaintext passwords (CLI > Config) - allowPlaintext := flags.AllowPlaintext || p.AllowPlaintext - - // Resolve password (supports keyring) - password := p.Password - if password != "" { - pw, xe := secret.Resolve(password, secret.Options{AllowPlaintext: allowPlaintext}) - if xe != nil { - return xe - } - password = pw - } - - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - - // SSH proxy (if configured) - sshClient, err := setupSSH(ctx, p, allowPlaintext, flags.SSHSkipHostKey) - if err != nil { - return err - } - if sshClient != nil { - defer sshClient.Close() - } - - // Get driver - drv, ok := db.Get(p.DB) - if !ok { - return errors.New(errors.CodeDBDriverUnsupported, "unsupported db driver", map[string]any{"db": p.DB}) - } - - connOpts := db.ConnOptions{ - DSN: p.DSN, - Host: p.Host, - Port: p.Port, - User: p.User, - Password: password, - Database: p.Database, - } - if sshClient != nil { - connOpts.Dialer = sshClient - } - - conn, xe := drv.Open(ctx, connOpts) - if xe != nil { - return xe - } - defer conn.Close() - - // Dump schema - schemaOpts := db.SchemaOptions{ - TablePattern: flags.TablePattern, - IncludeSystem: flags.IncludeSystem, - } - - result, xe := db.DumpSchema(ctx, p.DB, conn, schemaOpts) - if xe != nil { - return xe - } - - return w.WriteOK(format, result) -} +package main + +import ( + "context" + "time" + + "github.com/spf13/cobra" + + "github.com/zx06/xsql/internal/db" + _ "github.com/zx06/xsql/internal/db/mysql" + _ "github.com/zx06/xsql/internal/db/pg" + "github.com/zx06/xsql/internal/errors" + "github.com/zx06/xsql/internal/output" + "github.com/zx06/xsql/internal/secret" +) + +// SchemaFlags holds the flags for the schema command +type SchemaFlags struct { + TablePattern string + IncludeSystem bool + AllowPlaintext bool + SSHSkipHostKey bool +} + +// NewSchemaCommand creates the schema command +func NewSchemaCommand(w *output.Writer) *cobra.Command { + flags := &SchemaFlags{} + + cmd := &cobra.Command{ + Use: "schema", + Short: "Database schema operations", + } + + // Add subcommands + cmd.AddCommand(NewSchemaDumpCommand(w, flags)) + + return cmd +} + +// NewSchemaDumpCommand creates the schema dump subcommand +func NewSchemaDumpCommand(w *output.Writer, flags *SchemaFlags) *cobra.Command { + cmd := &cobra.Command{ + Use: "dump", + Short: "Dump database schema (tables, columns, indexes, foreign keys)", + RunE: func(cmd *cobra.Command, args []string) error { + return runSchemaDump(cmd, args, flags, w) + }, + } + + cmd.Flags().StringVar(&flags.TablePattern, "table", "", "Table name filter (supports * and ? wildcards)") + cmd.Flags().BoolVar(&flags.IncludeSystem, "include-system", false, "Include system tables") + cmd.Flags().BoolVar(&flags.AllowPlaintext, "allow-plaintext", false, "Allow plaintext secrets in config") + cmd.Flags().BoolVar(&flags.SSHSkipHostKey, "ssh-skip-known-hosts-check", false, "Skip SSH known_hosts check (dangerous)") + + return cmd +} + +// runSchemaDump executes the schema dump command +func runSchemaDump(cmd *cobra.Command, args []string, flags *SchemaFlags, w *output.Writer) error { + format, err := parseOutputFormat(GlobalConfig.FormatStr) + if err != nil { + return err + } + + p := GlobalConfig.Resolved.Profile + if p.DB == "" { + return errors.New(errors.CodeCfgInvalid, "db type is required (mysql|pg)", nil) + } + + // Allow plaintext passwords (CLI > Config) + allowPlaintext := flags.AllowPlaintext || p.AllowPlaintext + + // Resolve password (supports keyring) + password := p.Password + if password != "" { + pw, xe := secret.Resolve(password, secret.Options{AllowPlaintext: allowPlaintext}) + if xe != nil { + return xe + } + password = pw + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + // SSH proxy (if configured) + sshClient, err := setupSSH(ctx, p, allowPlaintext, flags.SSHSkipHostKey) + if err != nil { + return err + } + if sshClient != nil { + defer sshClient.Close() + } + + // Get driver + drv, ok := db.Get(p.DB) + if !ok { + return errors.New(errors.CodeDBDriverUnsupported, "unsupported db driver", map[string]any{"db": p.DB}) + } + + connOpts := db.ConnOptions{ + DSN: p.DSN, + Host: p.Host, + Port: p.Port, + User: p.User, + Password: password, + Database: p.Database, + } + if sshClient != nil { + connOpts.Dialer = sshClient + } + + conn, xe := drv.Open(ctx, connOpts) + if xe != nil { + return xe + } + defer conn.Close() + + // Dump schema + schemaOpts := db.SchemaOptions{ + TablePattern: flags.TablePattern, + IncludeSystem: flags.IncludeSystem, + } + + result, xe := db.DumpSchema(ctx, p.DB, conn, schemaOpts) + if xe != nil { + return xe + } + + return w.WriteOK(format, result) +} From 41d0fa7afb585483f051e6fce7306f0eabd1155b Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Wed, 11 Feb 2026 15:10:48 +0800 Subject: [PATCH 05/10] test: ensure schema dump E2E uses DSN for coverage --- tests/e2e/schema_test.go | 61 +++++++++++++++------------------------- 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/tests/e2e/schema_test.go b/tests/e2e/schema_test.go index 4d6853b..381acab 100644 --- a/tests/e2e/schema_test.go +++ b/tests/e2e/schema_test.go @@ -4,25 +4,22 @@ package e2e import ( "encoding/json" + "fmt" "testing" ) func TestSchemaDump_JSON(t *testing.T) { - config := createTempConfig(t, `profiles: + config := createTempConfig(t, fmt.Sprintf(`profiles: dev: description: "开发环境" db: mysql - host: 127.0.0.1 - port: 3306 - user: root - database: testdb - allow_plaintext: true -`) + dsn: "%s" +`, mysqlDSN(t))) stdout, stderr, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "json") - // 验证退出码(可能因数据库未启动而失败,但输出格式应正确) - if exitCode != 0 && exitCode != 3 && exitCode != 5 { - t.Errorf("unexpected exit code %d, stderr: %s", exitCode, stderr) + // 验证退出码 + if exitCode != 0 { + t.Fatalf("unexpected exit code %d, stderr: %s", exitCode, stderr) } // 验证 JSON 格式 @@ -81,21 +78,17 @@ func TestSchemaDump_JSON(t *testing.T) { } func TestSchemaDump_YAML(t *testing.T) { - config := createTempConfig(t, `profiles: + config := createTempConfig(t, fmt.Sprintf(`profiles: dev: description: "开发环境" db: pg - host: 127.0.0.1 - port: 5432 - user: postgres - database: testdb - allow_plaintext: true -`) + dsn: "%s" +`, pgDSN(t))) stdout, stderr, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "yaml") // 验证退出码 - if exitCode != 0 && exitCode != 3 && exitCode != 5 { - t.Errorf("unexpected exit code %d, stderr: %s", exitCode, stderr) + if exitCode != 0 { + t.Fatalf("unexpected exit code %d, stderr: %s", exitCode, stderr) } // 验证 YAML 格式包含必要字段 @@ -108,21 +101,17 @@ func TestSchemaDump_YAML(t *testing.T) { } func TestSchemaDump_Table(t *testing.T) { - config := createTempConfig(t, `profiles: + config := createTempConfig(t, fmt.Sprintf(`profiles: dev: description: "开发环境" db: mysql - host: 127.0.0.1 - port: 3306 - user: root - database: testdb - allow_plaintext: true -`) + dsn: "%s" +`, mysqlDSN(t))) stdout, stderr, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "table") // 验证退出码 - if exitCode != 0 && exitCode != 3 && exitCode != 5 { - t.Errorf("unexpected exit code %d, stderr: %s", exitCode, stderr) + if exitCode != 0 { + t.Fatalf("unexpected exit code %d, stderr: %s", exitCode, stderr) } // 验证 Table 格式不包含 JSON 元数据 @@ -182,22 +171,18 @@ func TestSchemaDump_MissingProfile(t *testing.T) { } func TestSchemaDump_TableFilter(t *testing.T) { - config := createTempConfig(t, `profiles: + config := createTempConfig(t, fmt.Sprintf(`profiles: dev: description: "开发环境" db: mysql - host: 127.0.0.1 - port: 3306 - user: root - database: testdb - allow_plaintext: true -`) + dsn: "%s" +`, mysqlDSN(t))) // 使用 --table 过滤 stdout, stderr, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "json", "--table", "user*") - // 验证退出码(可能因数据库未启动而失败) - if exitCode != 0 && exitCode != 3 && exitCode != 5 { - t.Errorf("unexpected exit code %d, stderr: %s", exitCode, stderr) + // 验证退出码 + if exitCode != 0 { + t.Fatalf("unexpected exit code %d, stderr: %s", exitCode, stderr) } // 如果成功,验证过滤生效 From 42b486cea3cffe61e9b5d85614804c80a55a048f Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Wed, 11 Feb 2026 16:02:50 +0800 Subject: [PATCH 06/10] feat: integration test --- cmd/xsql/command_unit_test.go | 128 ------------- internal/db/query_test.go | 189 ------------------ internal/ssh/client_test.go | 24 ++- tests/integration/schema_dump_test.go | 265 ++++++++++++++++++++++++++ 4 files changed, 287 insertions(+), 319 deletions(-) create mode 100644 tests/integration/schema_dump_test.go diff --git a/cmd/xsql/command_unit_test.go b/cmd/xsql/command_unit_test.go index e14fd72..d4d8b67 100644 --- a/cmd/xsql/command_unit_test.go +++ b/cmd/xsql/command_unit_test.go @@ -3,19 +3,13 @@ package main import ( "bytes" "context" - "database/sql" - "database/sql/driver" "encoding/json" - "fmt" - "io" "os" "path/filepath" "testing" - "time" "github.com/zx06/xsql/internal/app" "github.com/zx06/xsql/internal/config" - xdb "github.com/zx06/xsql/internal/db" "github.com/zx06/xsql/internal/errors" "github.com/zx06/xsql/internal/output" ) @@ -98,30 +92,6 @@ func TestRunQuery_MissingDB(t *testing.T) { } } -func TestRunQuery_Success(t *testing.T) { - driverName := registerStubDriver(t, map[string]*stubRows{ - "select 1": { - columns: []string{"value"}, - rows: [][]driver.Value{{1}}, - }, - }) - - GlobalConfig.Resolved.Profile = config.Profile{ - DB: driverName, - } - GlobalConfig.FormatStr = "json" - - var out bytes.Buffer - w := output.New(&out, &bytes.Buffer{}) - err := runQuery(nil, []string{"select 1"}, &QueryFlags{UnsafeAllowWrite: true}, &w) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !json.Valid(out.Bytes()) { - t.Fatalf("expected json output, got %s", out.String()) - } -} - func TestRunProxy_ProfileRequired(t *testing.T) { GlobalConfig.ProfileStr = "" GlobalConfig.FormatStr = "json" @@ -519,101 +489,3 @@ profiles: func configProfile(dbType string) config.Profile { return config.Profile{DB: dbType} } - -type stubDriver struct { - responseRows map[string]*stubRows -} - -type stubConnector struct { - driver *stubDriver -} - -func (c *stubConnector) Connect(context.Context) (driver.Conn, error) { - return &stubConn{driver: c.driver}, nil -} - -func (c *stubConnector) Driver() driver.Driver { - return c.driver -} - -func (d *stubDriver) Open(string) (driver.Conn, error) { - return &stubConn{driver: d}, nil -} - -type stubConn struct { - driver *stubDriver -} - -func (c *stubConn) Prepare(string) (driver.Stmt, error) { - return nil, fmt.Errorf("prepare not supported") -} - -func (c *stubConn) Close() error { - return nil -} - -func (c *stubConn) Begin() (driver.Tx, error) { - return &stubTx{}, nil -} - -func (c *stubConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - if rows, ok := c.driver.responseRows[query]; ok { - return rows, nil - } - return nil, fmt.Errorf("unexpected query: %s", query) -} - -type stubTx struct{} - -func (t *stubTx) Commit() error { - return nil -} - -func (t *stubTx) Rollback() error { - return nil -} - -type stubRows struct { - columns []string - rows [][]driver.Value - idx int -} - -func (r *stubRows) Columns() []string { - return r.columns -} - -func (r *stubRows) Close() error { - return nil -} - -func (r *stubRows) Next(dest []driver.Value) error { - if r.idx >= len(r.rows) { - return io.EOF - } - copy(dest, r.rows[r.idx]) - r.idx++ - return nil -} - -func registerStubDriver(t *testing.T, rows map[string]*stubRows) string { - t.Helper() - - name := fmt.Sprintf("stub-%d", time.Now().UnixNano()) - driver := &stubDriver{responseRows: rows} - db := sql.OpenDB(&stubConnector{driver: driver}) - t.Cleanup(func() { - _ = db.Close() - }) - - xdb.Register(name, fakeDriver{db: db}) - return name -} - -type fakeDriver struct { - db *sql.DB -} - -func (d fakeDriver) Open(ctx context.Context, opts xdb.ConnOptions) (*sql.DB, *errors.XError) { - return d.db, nil -} diff --git a/internal/db/query_test.go b/internal/db/query_test.go index 7a4d6b8..ac3aad4 100644 --- a/internal/db/query_test.go +++ b/internal/db/query_test.go @@ -1,16 +1,7 @@ package db import ( - "context" - "database/sql" - "database/sql/driver" - stdErrors "errors" - "fmt" - "io" "testing" - "time" - - "github.com/zx06/xsql/internal/errors" ) // 纯函数单元测试,不需要数据库连接 @@ -38,104 +29,6 @@ func TestConvertValue(t *testing.T) { // Query 函数的集成测试在 tests/integration/query_test.go 中 -type stubDriver struct { - responseRows map[string]*stubRows - beginCalled bool - beginReadOnly bool -} - -type stubConnector struct { - driver *stubDriver -} - -func (c *stubConnector) Connect(context.Context) (driver.Conn, error) { - return &stubConn{driver: c.driver}, nil -} - -func (c *stubConnector) Driver() driver.Driver { - return c.driver -} - -func (d *stubDriver) Open(string) (driver.Conn, error) { - return &stubConn{driver: d}, nil -} - -type stubConn struct { - driver *stubDriver -} - -func (c *stubConn) Prepare(string) (driver.Stmt, error) { - return nil, stdErrors.New("prepare not supported") -} - -func (c *stubConn) Close() error { - return nil -} - -func (c *stubConn) Begin() (driver.Tx, error) { - return &stubTx{}, nil -} - -func (c *stubConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - c.driver.beginCalled = true - c.driver.beginReadOnly = opts.ReadOnly - return &stubTx{}, nil -} - -func (c *stubConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - if rows, ok := c.driver.responseRows[query]; ok { - return rows, nil - } - return nil, fmt.Errorf("unexpected query: %s", query) -} - -type stubTx struct{} - -func (t *stubTx) Commit() error { - return nil -} - -func (t *stubTx) Rollback() error { - return nil -} - -type stubRows struct { - columns []string - rows [][]driver.Value - idx int - err error -} - -func (r *stubRows) Columns() []string { - return r.columns -} - -func (r *stubRows) Close() error { - return nil -} - -func (r *stubRows) Next(dest []driver.Value) error { - if r.idx >= len(r.rows) { - if r.err != nil { - return r.err - } - return io.EOF - } - copy(dest, r.rows[r.idx]) - r.idx++ - return nil -} - -func newStubDB(t *testing.T, rows map[string]*stubRows) (*sql.DB, *stubDriver) { - t.Helper() - driver := &stubDriver{responseRows: rows} - db := sql.OpenDB(&stubConnector{driver: driver}) - t.Cleanup(func() { - _ = db.Close() - }) - return db, driver -} - func TestQueryResultToTableData(t *testing.T) { var result *QueryResult cols, rows, ok := result.ToTableData() @@ -149,85 +42,3 @@ func TestQueryResultToTableData(t *testing.T) { t.Fatalf("expected table data, got ok=%v cols=%v rows=%v", ok, cols, rows) } } - -func TestQuery_UnsafeAllowWrite_UsesDirectQuery(t *testing.T) { - db, driver := newStubDB(t, map[string]*stubRows{ - "select 1": { - columns: []string{"value"}, - rows: [][]driver.Value{{1}}, - }, - }) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - result, xe := Query(ctx, db, "select 1", QueryOptions{UnsafeAllowWrite: true}) - if xe != nil { - t.Fatalf("unexpected error: %v", xe) - } - if driver.beginCalled { - t.Fatalf("expected no transaction when UnsafeAllowWrite=true") - } - if len(result.Rows) != 1 || result.Rows[0]["value"] != 1 { - t.Fatalf("unexpected result: %#v", result) - } -} - -func TestQuery_ReadOnly_UsesTransaction(t *testing.T) { - db, driver := newStubDB(t, map[string]*stubRows{ - "select 1": { - columns: []string{"value"}, - rows: [][]driver.Value{{1}}, - }, - }) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - result, xe := Query(ctx, db, "select 1", QueryOptions{}) - if xe != nil { - t.Fatalf("unexpected error: %v", xe) - } - if !driver.beginCalled || !driver.beginReadOnly { - t.Fatalf("expected read-only transaction, beginCalled=%v readOnly=%v", driver.beginCalled, driver.beginReadOnly) - } - if len(result.Rows) != 1 { - t.Fatalf("expected one row, got %#v", result) - } -} - -func TestQuery_ReadOnlyBlocksWrite(t *testing.T) { - db, _ := newStubDB(t, map[string]*stubRows{}) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - _, xe := Query(ctx, db, "INSERT INTO t VALUES (1)", QueryOptions{}) - if xe == nil { - t.Fatal("expected error for write query") - } - if xe.Code != errors.CodeROBlocked { - t.Fatalf("expected CodeROBlocked, got %s", xe.Code) - } -} - -func TestScanRows_ReportsIterationError(t *testing.T) { - db, _ := newStubDB(t, map[string]*stubRows{ - "select error": { - columns: []string{"value"}, - rows: [][]driver.Value{{1}}, - err: stdErrors.New("iteration error"), - }, - }) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - _, xe := executeQuery(ctx, db, "select error") - if xe == nil { - t.Fatal("expected error from rows iteration") - } - if xe.Code != errors.CodeDBExecFailed { - t.Fatalf("expected CodeDBExecFailed, got %s", xe.Code) - } -} diff --git a/internal/ssh/client_test.go b/internal/ssh/client_test.go index c0aa318..982528c 100644 --- a/internal/ssh/client_test.go +++ b/internal/ssh/client_test.go @@ -9,6 +9,7 @@ import ( "net" "os" "path/filepath" + "strings" "testing" "time" @@ -350,7 +351,26 @@ func writeTestKey(t *testing.T, dir, name string) string { return path } -// Helper function to check if a path contains another path component +// Helper function to check if a path contains another path component (cross-platform) func containsPath(path, component string) bool { - return len(path) >= len(component) && (path == component || path[len(path)-len(component)-1] == '/' || path[:len(component)] == component) + p := filepath.ToSlash(filepath.Clean(path)) + c := filepath.ToSlash(filepath.Clean(component)) + + if p == c { + return true + } + + if len(p) < len(c) { + return false + } + + if strings.HasPrefix(p, c+"/") { + return true + } + + if strings.Contains(p, "/"+c+"/") || strings.HasSuffix(p, "/"+c) { + return true + } + + return false } diff --git a/tests/integration/schema_dump_test.go b/tests/integration/schema_dump_test.go new file mode 100644 index 0000000..2209412 --- /dev/null +++ b/tests/integration/schema_dump_test.go @@ -0,0 +1,265 @@ +//go:build integration + +package integration + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + "time" + + "github.com/zx06/xsql/internal/db" + _ "github.com/zx06/xsql/internal/db/mysql" + _ "github.com/zx06/xsql/internal/db/pg" +) + +func TestSchemaDump_MySQL_RealDB(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_MYSQL_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_MYSQL_DSN not set") + } + + drv, ok := db.Get("mysql") + if !ok { + t.Fatal("mysql driver not registered") + } + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) + if xe != nil { + t.Fatalf("failed to open mysql: %v", xe) + } + defer conn.Close() + + suffix := time.Now().UnixNano() + prefix := fmt.Sprintf("xsql_schema_%d", suffix) + usersTable := prefix + "_users" + ordersTable := prefix + "_orders" + + // 清理旧表 + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", ordersTable)) + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", usersTable)) + + // 创建表结构 + _, err := conn.ExecContext(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id BIGINT PRIMARY KEY, + email VARCHAR(255) NOT NULL, + tenant_id BIGINT NOT NULL, + created_at DATETIME NULL, + INDEX idx_email (email) + ) ENGINE=InnoDB + `, usersTable)) + if err != nil { + t.Fatalf("create users table failed: %v", err) + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id BIGINT PRIMARY KEY, + user_id BIGINT NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_%s_user FOREIGN KEY (user_id) REFERENCES %s(id) + ) ENGINE=InnoDB + `, ordersTable, ordersTable, usersTable)) + if err != nil { + t.Fatalf("create orders table failed: %v", err) + } + + t.Cleanup(func() { + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", ordersTable)) + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", usersTable)) + }) + + info, xe := db.DumpSchema(ctx, "mysql", conn, db.SchemaOptions{ + TablePattern: prefix + "*", + }) + if xe != nil { + t.Fatalf("DumpSchema error: %v", xe) + } + if info.Database == "" { + t.Fatalf("database name is empty") + } + + users := findTable(info.Tables, usersTable) + orders := findTable(info.Tables, ordersTable) + if users == nil || orders == nil { + t.Fatalf("missing tables in schema dump: users=%v orders=%v", users != nil, orders != nil) + } + + if users.Schema == "" { + t.Fatalf("users schema is empty") + } + if len(users.Columns) == 0 { + t.Fatalf("users columns should not be empty") + } + + if !hasColumn(users, "id", true) { + t.Fatalf("users table missing primary key column 'id'") + } + if !hasIndex(users, "PRIMARY") { + t.Fatalf("users table missing PRIMARY index") + } + if !hasIndex(users, "idx_email") { + t.Fatalf("users table missing idx_email index") + } + + if len(orders.ForeignKeys) == 0 { + t.Fatalf("orders table should have foreign keys") + } + if !hasForeignKeyTo(orders, usersTable) { + t.Fatalf("orders table missing FK to %s", usersTable) + } +} + +func TestSchemaDump_Pg_RealDB(t *testing.T) { + dsn := os.Getenv("XSQL_TEST_PG_DSN") + if dsn == "" { + t.Skip("XSQL_TEST_PG_DSN not set") + } + + drv, ok := db.Get("pg") + if !ok { + t.Fatal("pg driver not registered") + } + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + conn, xe := drv.Open(ctx, db.ConnOptions{DSN: dsn}) + if xe != nil { + t.Fatalf("failed to open pg: %v", xe) + } + defer conn.Close() + + suffix := time.Now().UnixNano() + schema := fmt.Sprintf("xsql_schema_%d", suffix) + usersTable := "users" + ordersTable := "orders" + prefix := "xsql_" + + // 清理旧 schema + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schema)) + + // 创建 schema 与表 + _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA %s", schema)) + if err != nil { + t.Fatalf("create schema failed: %v", err) + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE TABLE %s.%s ( + id BIGSERIAL PRIMARY KEY, + email TEXT NOT NULL, + created_at TIMESTAMPTZ NULL + ) + `, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("create users table failed: %v", err) + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE INDEX idx_email ON %s.%s (email) + `, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("create index failed: %v", err) + } + + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE TABLE %s.%s ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + amount NUMERIC(10,2) NOT NULL, + CONSTRAINT fk_%s_user FOREIGN KEY (user_id) REFERENCES %s.%s(id) + ) + `, schema, prefix+ordersTable, prefix+ordersTable, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("create orders table failed: %v", err) + } + + t.Cleanup(func() { + _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schema)) + }) + + info, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{ + TablePattern: prefix + "*", + }) + if xe != nil { + t.Fatalf("DumpSchema error: %v", xe) + } + if info.Database == "" { + t.Fatalf("database name is empty") + } + + users := findTableWithSchema(info.Tables, schema, prefix+usersTable) + orders := findTableWithSchema(info.Tables, schema, prefix+ordersTable) + if users == nil || orders == nil { + t.Fatalf("missing tables in schema dump: users=%v orders=%v", users != nil, orders != nil) + } + + if !hasColumn(users, "id", true) { + t.Fatalf("users table missing primary key column 'id'") + } + if len(users.Indexes) == 0 { + t.Fatalf("users table should have indexes") + } + if !hasIndex(users, "idx_email") { + t.Fatalf("users table missing idx_email index") + } + + if len(orders.ForeignKeys) == 0 { + t.Fatalf("orders table should have foreign keys") + } + if !hasForeignKeyTo(orders, prefix+usersTable) { + t.Fatalf("orders table missing FK to %s", prefix+usersTable) + } +} + +func findTable(tables []db.Table, name string) *db.Table { + for i := range tables { + if tables[i].Name == name { + return &tables[i] + } + } + return nil +} + +func findTableWithSchema(tables []db.Table, schema, name string) *db.Table { + for i := range tables { + if tables[i].Schema == schema && tables[i].Name == name { + return &tables[i] + } + } + return nil +} + +func hasColumn(table *db.Table, name string, primary bool) bool { + for _, c := range table.Columns { + if c.Name == name && c.PrimaryKey == primary { + return true + } + } + return false +} + +func hasIndex(table *db.Table, indexName string) bool { + for _, idx := range table.Indexes { + if idx.Name == indexName { + return true + } + } + return false +} + +func hasForeignKeyTo(table *db.Table, referencedTable string) bool { + for _, fk := range table.ForeignKeys { + if strings.EqualFold(fk.ReferencedTable, referencedTable) { + return true + } + } + return false +} From df77492a36fdfe6429cc0a21d413fc49ff0e494a Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Wed, 11 Feb 2026 16:12:45 +0800 Subject: [PATCH 07/10] test: add tests --- internal/output/writer_test.go | 83 ++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/internal/output/writer_test.go b/internal/output/writer_test.go index 96afa1d..e49872b 100644 --- a/internal/output/writer_test.go +++ b/internal/output/writer_test.go @@ -704,3 +704,86 @@ func TestWriteOK_TableFormat_NilData(t *testing.T) { result := out.String() t.Logf("nil data output: %s", result) } + +type schemaFormatterData struct { + db string + tables []SchemaTable + ok bool +} + +func (s schemaFormatterData) ToSchemaData() (string, []SchemaTable, bool) { + return s.db, s.tables, s.ok +} + +func TestWriteOK_TableFormat_SchemaFormatter_WithColumns(t *testing.T) { + var out bytes.Buffer + w := New(&out, &bytes.Buffer{}) + + data := schemaFormatterData{ + db: "testdb", + tables: []SchemaTable{ + { + Schema: "public", + Name: "users", + Comment: "用户表", + Columns: []SchemaColumn{ + {Name: "id", Type: "bigint", Nullable: false, PrimaryKey: true}, + {Name: "email", Type: "varchar(255)", Nullable: true}, + }, + }, + }, + ok: true, + } + + if err := w.WriteOK(FormatTable, data); err != nil { + t.Fatal(err) + } + + result := out.String() + if !strings.Contains(result, "Database: testdb") { + t.Errorf("schema table output should include database name, got: %s", result) + } + if !strings.Contains(result, "Table: public.users (用户表)") { + t.Errorf("schema table output should include table header with schema and comment, got: %s", result) + } + if !strings.Contains(result, "Columns:") { + t.Errorf("schema table output should include columns section, got: %s", result) + } + if !strings.Contains(result, "✓") { + t.Errorf("schema table output should include primary key marker, got: %s", result) + } + if !strings.Contains(result, "(1 table)") { + t.Errorf("schema table output should include table count, got: %s", result) + } +} + +func TestWriteOK_TableFormat_SchemaFormatter_NoColumns_SchemaEqualsDB(t *testing.T) { + var out bytes.Buffer + w := New(&out, &bytes.Buffer{}) + + data := schemaFormatterData{ + db: "testdb", + tables: []SchemaTable{ + { + Schema: "testdb", + Name: "users", + }, + }, + ok: true, + } + + if err := w.WriteOK(FormatTable, data); err != nil { + t.Fatal(err) + } + + result := out.String() + if !strings.Contains(result, "Table: users") { + t.Errorf("schema table output should omit schema when it matches database, got: %s", result) + } + if strings.Contains(result, "Columns:") { + t.Errorf("schema table output should not include columns section when empty, got: %s", result) + } + if !strings.Contains(result, "(1 table)") { + t.Errorf("schema table output should include table count, got: %s", result) + } +} From 0af664bbe70f9299afae1c5f270699bac6b75a89 Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Wed, 11 Feb 2026 16:25:52 +0800 Subject: [PATCH 08/10] test: improve e2e and integration test --- tests/e2e/schema_test.go | 126 ++++++++++++++++++++++++++ tests/integration/schema_dump_test.go | 71 ++++++++++++++- 2 files changed, 192 insertions(+), 5 deletions(-) diff --git a/tests/e2e/schema_test.go b/tests/e2e/schema_test.go index 381acab..29fe16e 100644 --- a/tests/e2e/schema_test.go +++ b/tests/e2e/schema_test.go @@ -240,6 +240,132 @@ func TestSchema_Command(t *testing.T) { } } +func TestSchemaDump_MissingDBType(t *testing.T) { + config := createTempConfig(t, `profiles: + dev: + host: 127.0.0.1 +`) + stdout, _, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "json") + + if exitCode != 2 { + t.Errorf("exit code = %d, want 2", exitCode) + } + + var resp struct { + OK bool `json:"ok"` + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("failed to parse JSON: %v", err) + } + if resp.OK { + t.Error("expected ok=false") + } + if resp.Error.Code == "" { + t.Error("error code is empty") + } +} + +func TestSchemaDump_PlaintextPasswordNotAllowed(t *testing.T) { + config := createTempConfig(t, fmt.Sprintf(`profiles: + dev: + description: "开发环境" + db: mysql + dsn: "%s" + password: "plain_password" +`, mysqlDSN(t))) + + stdout, _, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "json") + + if exitCode != 2 { + t.Errorf("exit code = %d, want 2", exitCode) + } + + var resp struct { + OK bool `json:"ok"` + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("failed to parse JSON: %v", err) + } + if resp.OK { + t.Error("expected ok=false") + } + if resp.Error.Code == "" { + t.Error("error code is empty") + } +} + +func TestSchemaDump_InvalidFormat(t *testing.T) { + config := createTempConfig(t, fmt.Sprintf(`profiles: + dev: + description: "开发环境" + db: mysql + dsn: "%s" +`, mysqlDSN(t))) + + stdout, _, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "invalid") + + if exitCode != 2 { + t.Errorf("exit code = %d, want 2", exitCode) + } + + var resp struct { + OK bool `json:"ok"` + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("failed to parse JSON: %v", err) + } + if resp.OK { + t.Error("expected ok=false") + } + if resp.Error.Code == "" { + t.Error("error code is empty") + } +} + +func TestSchemaDump_UnsupportedDriver(t *testing.T) { + config := createTempConfig(t, fmt.Sprintf(`profiles: + dev: + description: "开发环境" + db: sqlite + dsn: "%s" +`, mysqlDSN(t))) + + stdout, _, exitCode := runXSQL(t, "schema", "dump", "--config", config, "-p", "dev", "-f", "json") + + if exitCode == 0 { + t.Errorf("exit code = %d, want non-zero", exitCode) + } + + var resp struct { + OK bool `json:"ok"` + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("failed to parse JSON: %v", err) + } + if resp.OK { + t.Error("expected ok=false") + } + if resp.Error.Code == "" { + t.Error("error code is empty") + } +} + // Helper function to check if string contains substring func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) diff --git a/tests/integration/schema_dump_test.go b/tests/integration/schema_dump_test.go index 2209412..1ae437a 100644 --- a/tests/integration/schema_dump_test.go +++ b/tests/integration/schema_dump_test.go @@ -44,15 +44,16 @@ func TestSchemaDump_MySQL_RealDB(t *testing.T) { _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", ordersTable)) _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", usersTable)) - // 创建表结构 + // 创建表结构(包含注释与默认值) _, err := conn.ExecContext(ctx, fmt.Sprintf(` CREATE TABLE %s ( - id BIGINT PRIMARY KEY, + id BIGINT PRIMARY KEY COMMENT '主键', email VARCHAR(255) NOT NULL, tenant_id BIGINT NOT NULL, - created_at DATETIME NULL, + status VARCHAR(20) NOT NULL DEFAULT 'active' COMMENT '状态', + created_at DATETIME NULL DEFAULT CURRENT_TIMESTAMP, INDEX idx_email (email) - ) ENGINE=InnoDB + ) ENGINE=InnoDB COMMENT='用户表' `, usersTable)) if err != nil { t.Fatalf("create users table failed: %v", err) @@ -108,6 +109,20 @@ func TestSchemaDump_MySQL_RealDB(t *testing.T) { t.Fatalf("users table missing idx_email index") } + if !hasColumnComment(users, "id", "主键") { + t.Fatalf("users table column 'id' missing comment") + } + if !hasColumnComment(users, "status", "状态") { + t.Fatalf("users table column 'status' missing comment") + } + if !hasColumnDefault(users, "status", "active") { + t.Fatalf("users table column 'status' missing default value") + } + + if users.Comment != "用户表" { + t.Fatalf("users table missing comment") + } + if len(orders.ForeignKeys) == 0 { t.Fatalf("orders table should have foreign keys") } @@ -155,13 +170,27 @@ func TestSchemaDump_Pg_RealDB(t *testing.T) { CREATE TABLE %s.%s ( id BIGSERIAL PRIMARY KEY, email TEXT NOT NULL, - created_at TIMESTAMPTZ NULL + status TEXT NOT NULL DEFAULT 'active', + created_at TIMESTAMPTZ NULL DEFAULT NOW() ) `, schema, prefix+usersTable)) if err != nil { t.Fatalf("create users table failed: %v", err) } + _, err = conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON TABLE %s.%s IS '用户表'`, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("comment table failed: %v", err) + } + _, err = conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s.id IS '主键'`, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("comment column failed: %v", err) + } + _, err = conn.ExecContext(ctx, fmt.Sprintf(`COMMENT ON COLUMN %s.%s.status IS '状态'`, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("comment column failed: %v", err) + } + _, err = conn.ExecContext(ctx, fmt.Sprintf(` CREATE INDEX idx_email ON %s.%s (email) `, schema, prefix+usersTable)) @@ -211,6 +240,20 @@ func TestSchemaDump_Pg_RealDB(t *testing.T) { t.Fatalf("users table missing idx_email index") } + if !hasColumnDefault(users, "status", "active") { + t.Fatalf("users table column 'status' missing default value") + } + + if users.Comment != "用户表" { + t.Fatalf("users table missing comment") + } + if !hasColumnComment(users, "id", "主键") { + t.Fatalf("users table column 'id' missing comment") + } + if !hasColumnComment(users, "status", "状态") { + t.Fatalf("users table column 'status' missing comment") + } + if len(orders.ForeignKeys) == 0 { t.Fatalf("orders table should have foreign keys") } @@ -263,3 +306,21 @@ func hasForeignKeyTo(table *db.Table, referencedTable string) bool { } return false } + +func hasColumnComment(table *db.Table, name, comment string) bool { + for _, c := range table.Columns { + if c.Name == name && c.Comment == comment { + return true + } + } + return false +} + +func hasColumnDefault(table *db.Table, name, want string) bool { + for _, c := range table.Columns { + if c.Name == name && strings.Contains(c.Default, want) { + return true + } + } + return false +} From 0f7549f93fb1931cbde9a60b2143236ebe272b88 Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Wed, 11 Feb 2026 16:47:05 +0800 Subject: [PATCH 09/10] test: add more tests --- cmd/xsql/command_unit_test.go | 113 ++++++++++++++++++++++++++ tests/integration/schema_dump_test.go | 108 ++++++++++++++++++++++-- 2 files changed, 216 insertions(+), 5 deletions(-) diff --git a/cmd/xsql/command_unit_test.go b/cmd/xsql/command_unit_test.go index d4d8b67..d0e6568 100644 --- a/cmd/xsql/command_unit_test.go +++ b/cmd/xsql/command_unit_test.go @@ -92,6 +92,119 @@ func TestRunQuery_MissingDB(t *testing.T) { } } +func TestRunQuery_UnsupportedDriver(t *testing.T) { + GlobalConfig.Resolved.Profile = configProfile("sqlite") + GlobalConfig.FormatStr = "json" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w) + if err == nil { + t.Fatal("expected error for unsupported driver") + } + if xe, ok := errors.As(err); !ok || xe.Code != errors.CodeDBDriverUnsupported { + t.Fatalf("expected CodeDBDriverUnsupported, got %v", err) + } +} + +func TestRunQuery_PlaintextPasswordNotAllowed(t *testing.T) { + GlobalConfig.Resolved.Profile = config.Profile{ + DB: "mysql", + Password: "plain_password", + AllowPlaintext: false, + } + GlobalConfig.FormatStr = "json" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w) + if err == nil { + t.Fatal("expected error for plaintext password not allowed") + } + if xe, ok := errors.As(err); !ok || xe.Code != errors.CodeCfgInvalid { + t.Fatalf("expected CodeCfgInvalid, got %v", err) + } +} + +func TestRunSchemaDump_UnsupportedDriver(t *testing.T) { + GlobalConfig.Resolved.Profile = configProfile("sqlite") + GlobalConfig.FormatStr = "json" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + err := runSchemaDump(nil, nil, &SchemaFlags{}, &w) + if err == nil { + t.Fatal("expected error for unsupported driver") + } + if xe, ok := errors.As(err); !ok || xe.Code != errors.CodeDBDriverUnsupported { + t.Fatalf("expected CodeDBDriverUnsupported, got %v", err) + } +} + +func TestRunSchemaDump_PlaintextPasswordNotAllowed(t *testing.T) { + GlobalConfig.Resolved.Profile = config.Profile{ + DB: "mysql", + Password: "plain_password", + AllowPlaintext: false, + } + GlobalConfig.FormatStr = "json" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + err := runSchemaDump(nil, nil, &SchemaFlags{}, &w) + if err == nil { + t.Fatal("expected error for plaintext password not allowed") + } + if xe, ok := errors.As(err); !ok || xe.Code != errors.CodeCfgInvalid { + t.Fatalf("expected CodeCfgInvalid, got %v", err) + } +} + +func TestRunQuery_InvalidFormat(t *testing.T) { + GlobalConfig.Resolved.Profile = configProfile("mysql") + GlobalConfig.FormatStr = "invalid" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + err := runQuery(nil, []string{"select 1"}, &QueryFlags{}, &w) + if err == nil { + t.Fatal("expected error for invalid format") + } + if xe, ok := errors.As(err); !ok || xe.Code != errors.CodeCfgInvalid { + t.Fatalf("expected CodeCfgInvalid, got %v", err) + } +} + +func TestRunSchemaDump_MissingDB(t *testing.T) { + GlobalConfig.Resolved.Profile = configProfile("") + GlobalConfig.FormatStr = "json" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + err := runSchemaDump(nil, nil, &SchemaFlags{}, &w) + if err == nil { + t.Fatal("expected error for missing db type") + } + if xe, ok := errors.As(err); !ok || xe.Code != errors.CodeCfgInvalid { + t.Fatalf("expected CodeCfgInvalid, got %v", err) + } +} + +func TestRunSchemaDump_InvalidFormat(t *testing.T) { + GlobalConfig.Resolved.Profile = configProfile("mysql") + GlobalConfig.FormatStr = "invalid" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + err := runSchemaDump(nil, nil, &SchemaFlags{}, &w) + if err == nil { + t.Fatal("expected error for invalid format") + } + if xe, ok := errors.As(err); !ok || xe.Code != errors.CodeCfgInvalid { + t.Fatalf("expected CodeCfgInvalid, got %v", err) + } +} + func TestRunProxy_ProfileRequired(t *testing.T) { GlobalConfig.ProfileStr = "" GlobalConfig.FormatStr = "json" diff --git a/tests/integration/schema_dump_test.go b/tests/integration/schema_dump_test.go index 1ae437a..feda791 100644 --- a/tests/integration/schema_dump_test.go +++ b/tests/integration/schema_dump_test.go @@ -52,7 +52,9 @@ func TestSchemaDump_MySQL_RealDB(t *testing.T) { tenant_id BIGINT NOT NULL, status VARCHAR(20) NOT NULL DEFAULT 'active' COMMENT '状态', created_at DATETIME NULL DEFAULT CURRENT_TIMESTAMP, - INDEX idx_email (email) + INDEX idx_email (email), + UNIQUE KEY uq_tenant_id (tenant_id, id), + INDEX idx_tenant_email (tenant_id, email) ) ENGINE=InnoDB COMMENT='用户表' `, usersTable)) if err != nil { @@ -62,9 +64,11 @@ func TestSchemaDump_MySQL_RealDB(t *testing.T) { _, err = conn.ExecContext(ctx, fmt.Sprintf(` CREATE TABLE %s ( id BIGINT PRIMARY KEY, + tenant_id BIGINT NOT NULL, user_id BIGINT NOT NULL, amount DECIMAL(10,2) NOT NULL, - CONSTRAINT fk_%s_user FOREIGN KEY (user_id) REFERENCES %s(id) + INDEX idx_tenant_user (tenant_id, user_id), + CONSTRAINT fk_%s_user FOREIGN KEY (tenant_id, user_id) REFERENCES %s(tenant_id, id) ) ENGINE=InnoDB `, ordersTable, ordersTable, usersTable)) if err != nil { @@ -86,6 +90,24 @@ func TestSchemaDump_MySQL_RealDB(t *testing.T) { t.Fatalf("database name is empty") } + infoNoFilter, xe := db.DumpSchema(ctx, "mysql", conn, db.SchemaOptions{}) + if xe != nil { + t.Fatalf("DumpSchema no-filter error: %v", xe) + } + if len(infoNoFilter.Tables) == 0 { + t.Fatalf("expected tables for no-filter dump") + } + + infoEmpty, xe := db.DumpSchema(ctx, "mysql", conn, db.SchemaOptions{ + TablePattern: "no_match_*", + }) + if xe != nil { + t.Fatalf("DumpSchema empty filter error: %v", xe) + } + if len(infoEmpty.Tables) != 0 { + t.Fatalf("expected empty tables for no_match_* filter") + } + users := findTable(info.Tables, usersTable) orders := findTable(info.Tables, ordersTable) if users == nil || orders == nil { @@ -108,6 +130,12 @@ func TestSchemaDump_MySQL_RealDB(t *testing.T) { if !hasIndex(users, "idx_email") { t.Fatalf("users table missing idx_email index") } + if !hasIndex(users, "uq_tenant_id") { + t.Fatalf("users table missing uq_tenant_id index") + } + if !hasIndex(users, "idx_tenant_email") { + t.Fatalf("users table missing idx_tenant_email index") + } if !hasColumnComment(users, "id", "主键") { t.Fatalf("users table column 'id' missing comment") @@ -123,12 +151,18 @@ func TestSchemaDump_MySQL_RealDB(t *testing.T) { t.Fatalf("users table missing comment") } + if !hasIndex(orders, "idx_tenant_user") { + t.Fatalf("orders table missing idx_tenant_user index") + } if len(orders.ForeignKeys) == 0 { t.Fatalf("orders table should have foreign keys") } if !hasForeignKeyTo(orders, usersTable) { t.Fatalf("orders table missing FK to %s", usersTable) } + if !hasCompositeForeignKeyTo(orders, usersTable) { + t.Fatalf("orders table missing composite FK to %s", usersTable) + } } func TestSchemaDump_Pg_RealDB(t *testing.T) { @@ -169,9 +203,11 @@ func TestSchemaDump_Pg_RealDB(t *testing.T) { _, err = conn.ExecContext(ctx, fmt.Sprintf(` CREATE TABLE %s.%s ( id BIGSERIAL PRIMARY KEY, - email TEXT NOT NULL, + tenant_id BIGINT NOT NULL, + email VARCHAR(255) NOT NULL, status TEXT NOT NULL DEFAULT 'active', - created_at TIMESTAMPTZ NULL DEFAULT NOW() + created_at TIMESTAMPTZ NULL DEFAULT NOW(), + UNIQUE (tenant_id, id) ) `, schema, prefix+usersTable)) if err != nil { @@ -197,18 +233,31 @@ func TestSchemaDump_Pg_RealDB(t *testing.T) { if err != nil { t.Fatalf("create index failed: %v", err) } + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE INDEX idx_tenant_email ON %s.%s (tenant_id, email) + `, schema, prefix+usersTable)) + if err != nil { + t.Fatalf("create index failed: %v", err) + } _, err = conn.ExecContext(ctx, fmt.Sprintf(` CREATE TABLE %s.%s ( id BIGSERIAL PRIMARY KEY, + tenant_id BIGINT NOT NULL, user_id BIGINT NOT NULL, amount NUMERIC(10,2) NOT NULL, - CONSTRAINT fk_%s_user FOREIGN KEY (user_id) REFERENCES %s.%s(id) + CONSTRAINT fk_%s_user FOREIGN KEY (tenant_id, user_id) REFERENCES %s.%s(tenant_id, id) ) `, schema, prefix+ordersTable, prefix+ordersTable, schema, prefix+usersTable)) if err != nil { t.Fatalf("create orders table failed: %v", err) } + _, err = conn.ExecContext(ctx, fmt.Sprintf(` + CREATE INDEX idx_tenant_user ON %s.%s (tenant_id, user_id) + `, schema, prefix+ordersTable)) + if err != nil { + t.Fatalf("create index failed: %v", err) + } t.Cleanup(func() { _, _ = conn.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schema)) @@ -224,6 +273,35 @@ func TestSchemaDump_Pg_RealDB(t *testing.T) { t.Fatalf("database name is empty") } + infoNoFilter, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{}) + if xe != nil { + t.Fatalf("DumpSchema no-filter error: %v", xe) + } + if len(infoNoFilter.Tables) == 0 { + t.Fatalf("expected tables for no-filter dump") + } + + infoWithSystem, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{ + TablePattern: prefix + "*", + IncludeSystem: true, + }) + if xe != nil { + t.Fatalf("DumpSchema include-system error: %v", xe) + } + if infoWithSystem.Database == "" { + t.Fatalf("database name is empty for include-system") + } + + infoEmpty, xe := db.DumpSchema(ctx, "pg", conn, db.SchemaOptions{ + TablePattern: "no_match_*", + }) + if xe != nil { + t.Fatalf("DumpSchema empty filter error: %v", xe) + } + if len(infoEmpty.Tables) != 0 { + t.Fatalf("expected empty tables for no_match_* filter") + } + users := findTableWithSchema(info.Tables, schema, prefix+usersTable) orders := findTableWithSchema(info.Tables, schema, prefix+ordersTable) if users == nil || orders == nil { @@ -239,6 +317,9 @@ func TestSchemaDump_Pg_RealDB(t *testing.T) { if !hasIndex(users, "idx_email") { t.Fatalf("users table missing idx_email index") } + if !hasIndex(users, "idx_tenant_email") { + t.Fatalf("users table missing idx_tenant_email index") + } if !hasColumnDefault(users, "status", "active") { t.Fatalf("users table column 'status' missing default value") @@ -254,12 +335,18 @@ func TestSchemaDump_Pg_RealDB(t *testing.T) { t.Fatalf("users table column 'status' missing comment") } + if !hasIndex(orders, "idx_tenant_user") { + t.Fatalf("orders table missing idx_tenant_user index") + } if len(orders.ForeignKeys) == 0 { t.Fatalf("orders table should have foreign keys") } if !hasForeignKeyTo(orders, prefix+usersTable) { t.Fatalf("orders table missing FK to %s", prefix+usersTable) } + if !hasCompositeForeignKeyTo(orders, prefix+usersTable) { + t.Fatalf("orders table missing composite FK to %s", prefix+usersTable) + } } func findTable(tables []db.Table, name string) *db.Table { @@ -307,6 +394,17 @@ func hasForeignKeyTo(table *db.Table, referencedTable string) bool { return false } +func hasCompositeForeignKeyTo(table *db.Table, referencedTable string) bool { + for _, fk := range table.ForeignKeys { + if strings.EqualFold(fk.ReferencedTable, referencedTable) && + len(fk.Columns) >= 2 && + len(fk.ReferencedColumns) >= 2 { + return true + } + } + return false +} + func hasColumnComment(table *db.Table, name, comment string) bool { for _, c := range table.Columns { if c.Name == name && c.Comment == comment { From d32d5eb0411dc19a5210a8471063cd7e495d75d9 Mon Sep 17 00:00:00 2001 From: zx06 <12474586+zx06@users.noreply.github.com> Date: Wed, 11 Feb 2026 17:10:52 +0800 Subject: [PATCH 10/10] test: add testcases --- cmd/xsql/command_unit_test.go | 18 ++ internal/mcp/tools_test.go | 105 +++++++ internal/output/writer_test.go | 100 +++++++ internal/secret/keyring_test.go | 502 +++++++++++++++++--------------- 4 files changed, 489 insertions(+), 236 deletions(-) diff --git a/cmd/xsql/command_unit_test.go b/cmd/xsql/command_unit_test.go index d0e6568..602acb8 100644 --- a/cmd/xsql/command_unit_test.go +++ b/cmd/xsql/command_unit_test.go @@ -599,6 +599,24 @@ profiles: } } +func TestValueIfSet(t *testing.T) { + if got := valueIfSet(false, "x"); got != "" { + t.Fatalf("expected empty when not set, got %q", got) + } + if got := valueIfSet(true, "x"); got != "x" { + t.Fatalf("expected value when set, got %q", got) + } +} + +func TestFirstNonEmpty(t *testing.T) { + if got := firstNonEmpty("", "", "a", "b"); got != "a" { + t.Fatalf("expected first non-empty value, got %q", got) + } + if got := firstNonEmpty("", ""); got != "" { + t.Fatalf("expected empty when all empty, got %q", got) + } +} + func configProfile(dbType string) config.Profile { return config.Profile{DB: dbType} } diff --git a/internal/mcp/tools_test.go b/internal/mcp/tools_test.go index 2168901..8ed722e 100644 --- a/internal/mcp/tools_test.go +++ b/internal/mcp/tools_test.go @@ -801,3 +801,108 @@ func TestQuery_AllFields(t *testing.T) { t.Errorf("expected Description=Full profile, got %s", profile.Description) } } + +func TestQueryHandler_InvalidJSON(t *testing.T) { + cfg := &config.File{ + Profiles: map[string]config.Profile{ + "dev": {DB: "mysql"}, + }, + } + handler := NewToolHandler(cfg) + + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Arguments: []byte("{"), + }, + } + + result, err := handler.queryHandler(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == nil || !result.IsError { + t.Fatalf("expected error result, got %+v", result) + } + if len(result.Content) == 0 { + t.Fatal("expected error content") + } + text := result.Content[0].(*mcp.TextContent).Text + if !strings.Contains(text, "CFG_INVALID") { + t.Fatalf("expected CFG_INVALID in output, got: %s", text) + } +} + +func TestProfileShowHandler_InvalidJSON(t *testing.T) { + cfg := &config.File{ + Profiles: map[string]config.Profile{ + "dev": {DB: "mysql"}, + }, + } + handler := NewToolHandler(cfg) + + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Arguments: []byte("{"), + }, + } + + result, err := handler.profileShowHandler(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == nil || !result.IsError { + t.Fatalf("expected error result, got %+v", result) + } + if len(result.Content) == 0 { + t.Fatal("expected error content") + } + text := result.Content[0].(*mcp.TextContent).Text + if !strings.Contains(text, "CFG_INVALID") { + t.Fatalf("expected CFG_INVALID in output, got: %s", text) + } +} + +func TestGetProfile_ResolvesSSHProxy(t *testing.T) { + cfg := &config.File{ + Profiles: map[string]config.Profile{ + "dev": {DB: "mysql", SSHProxy: "bastion"}, + }, + SSHProxies: map[string]config.SSHProxy{ + "bastion": {Host: "bastion.example.com", Port: 22, User: "bastion"}, + }, + } + + handler := NewToolHandler(cfg) + profile := handler.getProfile("dev") + if profile == nil { + t.Fatal("expected non-nil profile") + } + if profile.SSHConfig == nil { + t.Fatal("expected SSHConfig to be resolved") + } + if profile.SSHConfig.Host != "bastion.example.com" { + t.Fatalf("expected SSH host to be resolved, got %s", profile.SSHConfig.Host) + } +} + +func TestProfileList_EmptyProfiles(t *testing.T) { + cfg := &config.File{ + Profiles: map[string]config.Profile{}, + } + handler := NewToolHandler(cfg) + + result, _, err := handler.ProfileList(context.Background(), &mcp.CallToolRequest{}, struct{}{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == nil || result.IsError { + t.Fatalf("expected success result, got %+v", result) + } + if len(result.Content) == 0 { + t.Fatal("expected content") + } + text := result.Content[0].(*mcp.TextContent).Text + if !strings.Contains(text, "\"profiles\": []") { + t.Fatalf("expected empty profiles list, got: %s", text) + } +} diff --git a/internal/output/writer_test.go b/internal/output/writer_test.go index e49872b..c369e6f 100644 --- a/internal/output/writer_test.go +++ b/internal/output/writer_test.go @@ -787,3 +787,103 @@ func TestWriteOK_TableFormat_SchemaFormatter_NoColumns_SchemaEqualsDB(t *testing t.Errorf("schema table output should include table count, got: %s", result) } } + +func TestTryAsProfileList_ReflectMissingName(t *testing.T) { + type profileInfo struct { + Description string + DB string + Mode string + } + input := []profileInfo{ + {Description: "no name", DB: "mysql", Mode: "read-only"}, + } + if _, ok := tryAsProfileList(input); ok { + t.Fatal("expected ok=false for struct slice missing Name") + } +} + +func TestTryAsQueryResultReflect_RowMapNonStringKey(t *testing.T) { + type Result struct { + Columns []string + Rows []map[any]any + } + input := Result{ + Columns: []string{"id"}, + Rows: []map[any]any{{1: "bad"}}, + } + if _, ok := tryAsQueryResultReflect(input); ok { + t.Fatal("expected ok=false for non-string row map keys") + } +} + +func TestWriteTable_ErrorWithoutErrorObject(t *testing.T) { + var out bytes.Buffer + err := writeTable(&out, Envelope{OK: false, Error: nil}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out.Len() != 0 { + t.Fatalf("expected empty output, got: %s", out.String()) + } +} + +func TestWriteCSV_ErrorWithoutErrorObject(t *testing.T) { + var out bytes.Buffer + err := writeCSV(&out, Envelope{OK: false, Error: nil}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out.Len() != 0 { + t.Fatalf("expected empty output, got: %s", out.String()) + } +} + +func TestTryAsProfileList_ReflectStructSlice(t *testing.T) { + type profileInfo struct { + Name string + Description string + DB string + Mode string + } + input := []profileInfo{ + {Name: "dev", Description: "Dev", DB: "mysql", Mode: "read-only"}, + } + got, ok := tryAsProfileList(input) + if !ok { + t.Fatal("expected ok=true for struct slice") + } + if len(got) != 1 || got[0].Name != "dev" { + t.Fatalf("unexpected profile list: %+v", got) + } +} + +func TestTryAsQueryResultReflect_NonStringColumns(t *testing.T) { + type BadResult struct { + Columns []any + Rows []map[string]any + } + input := BadResult{ + Columns: []any{1}, + Rows: []map[string]any{{"id": 1}}, + } + if _, ok := tryAsQueryResultReflect(input); ok { + t.Fatal("expected ok=false for non-string columns") + } +} + +func TestWriteOK_TableFormat_JSONFallback(t *testing.T) { + var out bytes.Buffer + w := New(&out, &bytes.Buffer{}) + + type payload struct { + Foo string `json:"foo"` + } + if err := w.WriteOK(FormatTable, payload{Foo: "bar"}); err != nil { + t.Fatal(err) + } + + result := out.String() + if !strings.Contains(result, "\"foo\"") || !strings.Contains(result, "bar") { + t.Fatalf("expected JSON fallback output, got: %s", result) + } +} diff --git a/internal/secret/keyring_test.go b/internal/secret/keyring_test.go index 14e8399..9c6563b 100644 --- a/internal/secret/keyring_test.go +++ b/internal/secret/keyring_test.go @@ -1,236 +1,266 @@ -package secret - -import ( - "fmt" - "strings" - "testing" -) - -// nullByteKeyring 模拟 Windows cmdkey 返回带 null 字节的值 -type nullByteKeyring struct { - data map[string]map[string]string -} - -func newNullByteKeyring() *nullByteKeyring { - return &nullByteKeyring{data: make(map[string]map[string]string)} -} - -func (m *nullByteKeyring) set(service, account, value string) { - if m.data[service] == nil { - m.data[service] = make(map[string]string) - } - m.data[service][account] = value -} - -// setWithNullBytes 模拟 Windows UTF-16 问题:每个字符后插入 null 字节 -func (m *nullByteKeyring) setWithNullBytes(service, account, value string) { - var sb strings.Builder - for _, r := range value { - sb.WriteRune(r) - sb.WriteByte(0x00) - } - m.set(service, account, sb.String()) -} - -func (m *nullByteKeyring) Get(service, account string) (string, error) { - if svc, ok := m.data[service]; ok { - if v, ok := svc[account]; ok { - return v, nil - } - } - return "", fmt.Errorf("not found: %s/%s", service, account) -} - -func (m *nullByteKeyring) Set(service, account, value string) error { - m.set(service, account, value) - return nil -} - -func (m *nullByteKeyring) Delete(service, account string) error { - if svc, ok := m.data[service]; ok { - delete(svc, account) - } - return nil -} - -// ============================================================================= -// Windows null 字节处理测试 -// ============================================================================= - -func TestStripNullBytes(t *testing.T) { - tests := []struct { - name string - input string - want string - }{ - { - name: "no null bytes", - input: "password123", - want: "password123", - }, - { - name: "null bytes between chars", - input: "p\x00a\x00s\x00s\x00", - want: "pass", - }, - { - name: "full password with null bytes", - input: "m\x00y\x00P\x00a\x00s\x00s\x00w\x00o\x00r\x00d\x00", - want: "myPassword", - }, - { - name: "special chars with null bytes", - input: "p\x00@\x00s\x00s\x00!\x00#\x00", - want: "p@ss!#", - }, - { - name: "empty string", - input: "", - want: "", - }, - { - name: "only null bytes", - input: "\x00\x00\x00", - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := strings.ReplaceAll(tt.input, "\x00", "") - if got != tt.want { - t.Errorf("stripNullBytes(%q) = %q, want %q", tt.input, got, tt.want) - } - }) - } -} - -func TestNullByteKeyring_SimulatesWindowsBehavior(t *testing.T) { - kr := newNullByteKeyring() - kr.setWithNullBytes("xsql", "prod/password", "secret123") - - val, err := kr.Get("xsql", "prod/password") - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - // 原始值应该包含 null 字节 - if !strings.Contains(val, "\x00") { - t.Error("Expected value to contain null bytes") - } - - // 清理后应该等于原始密码 - cleaned := strings.ReplaceAll(val, "\x00", "") - if cleaned != "secret123" { - t.Errorf("Cleaned value = %q, want %q", cleaned, "secret123") - } -} - -// ============================================================================= -// KeyringAPI 接口合规性测试 -// ============================================================================= - -func TestKeyringAPI_Interface(t *testing.T) { - // 确保 mockKeyring 实现 KeyringAPI 接口 - var _ KeyringAPI = (*mockKeyring)(nil) - var _ KeyringAPI = (*nullByteKeyring)(nil) -} - -func TestKeyringAPI_ErrorCases(t *testing.T) { - kr := newMockKeyring() - - // 空 service - _, err := kr.Get("", "account") - if err == nil { - t.Error("Get with empty service should fail") - } - - // 空 account - _, err = kr.Get("service", "") - if err == nil { - t.Error("Get with empty account should fail") - } - - // 不存在的 service - _, err = kr.Get("nonexistent", "account") - if err == nil { - t.Error("Get with nonexistent service should fail") - } -} - -// ============================================================================= -// Resolve 与 Keyring 集成测试 -// ============================================================================= - -func TestResolve_WithNullByteKeyring(t *testing.T) { - kr := newNullByteKeyring() - // 模拟 Windows 返回带 null 字节的密码 - kr.set("xsql", "prod/password", "s\x00e\x00c\x00r\x00e\x00t\x00") - - // 注意:Resolve 直接使用 keyring 返回值,不做清理 - // 清理逻辑在 keyring_windows.go 的 osKeyring.Get 中 - val, xe := Resolve("keyring:prod/password", Options{Keyring: kr}) - if xe != nil { - t.Fatalf("Resolve failed: %v", xe) - } - - // 由于使用 mockKeyring,不会自动清理 null 字节 - // 这个测试验证 Resolve 正确传递值 - if !strings.Contains(val, "\x00") { - t.Log("Value does not contain null bytes (expected if using cleaned keyring)") - } -} - -func TestResolve_SpecialCharacters(t *testing.T) { - kr := newMockKeyring() - specialPasswords := []string{ - "p@ssw0rd!", - "pass#123$", - "密码123", - "пароль", - "パスワード", - "pass word", - "pass\ttab", - } - - for i, pw := range specialPasswords { - account := fmt.Sprintf("test%d", i) - kr.set("xsql", account, pw) - - val, xe := Resolve(fmt.Sprintf("keyring:%s", account), Options{Keyring: kr}) - if xe != nil { - t.Errorf("Resolve special password %q failed: %v", pw, xe) - continue - } - if val != pw { - t.Errorf("Resolve special password: got %q, want %q", val, pw) - } - } -} - -func TestResolve_EmptyPassword(t *testing.T) { - kr := newMockKeyring() - kr.set("xsql", "empty", "") - - val, xe := Resolve("keyring:empty", Options{Keyring: kr}) - if xe != nil { - t.Fatalf("Resolve failed: %v", xe) - } - if val != "" { - t.Errorf("Expected empty password, got %q", val) - } -} - -func TestResolve_LongPassword(t *testing.T) { - kr := newMockKeyring() - longPass := strings.Repeat("a", 1000) - kr.set("xsql", "long", longPass) - - val, xe := Resolve("keyring:long", Options{Keyring: kr}) - if xe != nil { - t.Fatalf("Resolve failed: %v", xe) - } - if val != longPass { - t.Errorf("Long password mismatch: got len=%d, want len=%d", len(val), len(longPass)) - } -} +package secret + +import ( + "fmt" + "runtime" + "strings" + "testing" + + "github.com/zalando/go-keyring" +) + +// nullByteKeyring 模拟 Windows cmdkey 返回带 null 字节的值 +type nullByteKeyring struct { + data map[string]map[string]string +} + +func newNullByteKeyring() *nullByteKeyring { + return &nullByteKeyring{data: make(map[string]map[string]string)} +} + +func (m *nullByteKeyring) set(service, account, value string) { + if m.data[service] == nil { + m.data[service] = make(map[string]string) + } + m.data[service][account] = value +} + +// setWithNullBytes 模拟 Windows UTF-16 问题:每个字符后插入 null 字节 +func (m *nullByteKeyring) setWithNullBytes(service, account, value string) { + var sb strings.Builder + for _, r := range value { + sb.WriteRune(r) + sb.WriteByte(0x00) + } + m.set(service, account, sb.String()) +} + +func (m *nullByteKeyring) Get(service, account string) (string, error) { + if svc, ok := m.data[service]; ok { + if v, ok := svc[account]; ok { + return v, nil + } + } + return "", fmt.Errorf("not found: %s/%s", service, account) +} + +func (m *nullByteKeyring) Set(service, account, value string) error { + m.set(service, account, value) + return nil +} + +func (m *nullByteKeyring) Delete(service, account string) error { + if svc, ok := m.data[service]; ok { + delete(svc, account) + } + return nil +} + +// ============================================================================= +// Windows null 字节处理测试 +// ============================================================================= + +func TestStripNullBytes(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "no null bytes", + input: "password123", + want: "password123", + }, + { + name: "null bytes between chars", + input: "p\x00a\x00s\x00s\x00", + want: "pass", + }, + { + name: "full password with null bytes", + input: "m\x00y\x00P\x00a\x00s\x00s\x00w\x00o\x00r\x00d\x00", + want: "myPassword", + }, + { + name: "special chars with null bytes", + input: "p\x00@\x00s\x00s\x00!\x00#\x00", + want: "p@ss!#", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "only null bytes", + input: "\x00\x00\x00", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := strings.ReplaceAll(tt.input, "\x00", "") + if got != tt.want { + t.Errorf("stripNullBytes(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestNullByteKeyring_SimulatesWindowsBehavior(t *testing.T) { + kr := newNullByteKeyring() + kr.setWithNullBytes("xsql", "prod/password", "secret123") + + val, err := kr.Get("xsql", "prod/password") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + // 原始值应该包含 null 字节 + if !strings.Contains(val, "\x00") { + t.Error("Expected value to contain null bytes") + } + + // 清理后应该等于原始密码 + cleaned := strings.ReplaceAll(val, "\x00", "") + if cleaned != "secret123" { + t.Errorf("Cleaned value = %q, want %q", cleaned, "secret123") + } +} + +// ============================================================================= +// KeyringAPI 接口合规性测试 +// ============================================================================= + +func TestKeyringAPI_Interface(t *testing.T) { + // 确保 mockKeyring 实现 KeyringAPI 接口 + var _ KeyringAPI = (*mockKeyring)(nil) + var _ KeyringAPI = (*nullByteKeyring)(nil) +} + +func TestKeyringAPI_ErrorCases(t *testing.T) { + kr := newMockKeyring() + + // 空 service + _, err := kr.Get("", "account") + if err == nil { + t.Error("Get with empty service should fail") + } + + // 空 account + _, err = kr.Get("service", "") + if err == nil { + t.Error("Get with empty account should fail") + } + + // 不存在的 service + _, err = kr.Get("nonexistent", "account") + if err == nil { + t.Error("Get with nonexistent service should fail") + } +} + +// ============================================================================= +// Resolve 与 Keyring 集成测试 +// ============================================================================= + +func TestResolve_WithNullByteKeyring(t *testing.T) { + kr := newNullByteKeyring() + // 模拟 Windows 返回带 null 字节的密码 + kr.set("xsql", "prod/password", "s\x00e\x00c\x00r\x00e\x00t\x00") + + // 注意:Resolve 直接使用 keyring 返回值,不做清理 + // 清理逻辑在 keyring_windows.go 的 osKeyring.Get 中 + val, xe := Resolve("keyring:prod/password", Options{Keyring: kr}) + if xe != nil { + t.Fatalf("Resolve failed: %v", xe) + } + + // 由于使用 mockKeyring,不会自动清理 null 字节 + // 这个测试验证 Resolve 正确传递值 + if !strings.Contains(val, "\x00") { + t.Log("Value does not contain null bytes (expected if using cleaned keyring)") + } +} + +func TestResolve_SpecialCharacters(t *testing.T) { + kr := newMockKeyring() + specialPasswords := []string{ + "p@ssw0rd!", + "pass#123$", + "密码123", + "пароль", + "パスワード", + "pass word", + "pass\ttab", + } + + for i, pw := range specialPasswords { + account := fmt.Sprintf("test%d", i) + kr.set("xsql", account, pw) + + val, xe := Resolve(fmt.Sprintf("keyring:%s", account), Options{Keyring: kr}) + if xe != nil { + t.Errorf("Resolve special password %q failed: %v", pw, xe) + continue + } + if val != pw { + t.Errorf("Resolve special password: got %q, want %q", val, pw) + } + } +} + +func TestResolve_EmptyPassword(t *testing.T) { + kr := newMockKeyring() + kr.set("xsql", "empty", "") + + val, xe := Resolve("keyring:empty", Options{Keyring: kr}) + if xe != nil { + t.Fatalf("Resolve failed: %v", xe) + } + if val != "" { + t.Errorf("Expected empty password, got %q", val) + } +} + +func TestResolve_LongPassword(t *testing.T) { + kr := newMockKeyring() + longPass := strings.Repeat("a", 1000) + kr.set("xsql", "long", longPass) + + val, xe := Resolve("keyring:long", Options{Keyring: kr}) + if xe != nil { + t.Fatalf("Resolve failed: %v", xe) + } + if val != longPass { + t.Errorf("Long password mismatch: got len=%d, want len=%d", len(val), len(longPass)) + } +} + +func TestDefaultKeyring_NullByteBehavior(t *testing.T) { + keyring.MockInit() + kr := defaultKeyring() + service := "xsql-test" + account := "null-byte" + raw := "s\x00e\x00c\x00r\x00e\x00t\x00" + if err := kr.Set(service, account, raw); err != nil { + t.Fatalf("Set failed: %v", err) + } + got, err := kr.Get(service, account) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if runtime.GOOS == "windows" { + if strings.Contains(got, "\x00") { + t.Fatalf("expected null bytes to be stripped, got %q", got) + } + if got != "secret" { + t.Fatalf("expected cleaned value, got %q", got) + } + return + } + if got != raw { + t.Fatalf("expected raw value on non-windows, got %q", got) + } +}