diff --git a/docker/scripts/mod_world_0.sql b/docker/scripts/mod_world_0.sql index 2b78d55c..a2509025 100644 --- a/docker/scripts/mod_world_0.sql +++ b/docker/scripts/mod_world_0.sql @@ -3441,6 +3441,10 @@ COMMIT; SET FOREIGN_KEY_CHECKS = 1; +DROP DATABASE IF EXISTS meta; +CREATE DATABASE IF NOT EXISTS meta; +USE meta; + CREATE TABLE `undo_log` ( `id` bigint NOT NULL AUTO_INCREMENT, `branch_id` bigint NOT NULL, diff --git a/docker/scripts/mod_world_1.sql b/docker/scripts/mod_world_1.sql index 167f20d5..930d6acb 100644 --- a/docker/scripts/mod_world_1.sql +++ b/docker/scripts/mod_world_1.sql @@ -3432,6 +3432,10 @@ COMMIT; SET FOREIGN_KEY_CHECKS = 1; +DROP DATABASE IF EXISTS meta; +CREATE DATABASE IF NOT EXISTS meta; +USE meta; + CREATE TABLE `undo_log` ( `id` bigint NOT NULL AUTO_INCREMENT, `branch_id` bigint NOT NULL, diff --git a/docker/scripts/range_world_0.sql b/docker/scripts/range_world_0.sql index 692f3468..2c047841 100644 --- a/docker/scripts/range_world_0.sql +++ b/docker/scripts/range_world_0.sql @@ -3901,6 +3901,10 @@ COMMIT; SET FOREIGN_KEY_CHECKS = 1; +DROP DATABASE IF EXISTS meta; +CREATE DATABASE IF NOT EXISTS meta; +USE meta; + CREATE TABLE `undo_log` ( `id` bigint NOT NULL AUTO_INCREMENT, `branch_id` bigint NOT NULL, diff --git a/docker/scripts/range_world_1.sql b/docker/scripts/range_world_1.sql index 6cff38ea..b1162584 100644 --- a/docker/scripts/range_world_1.sql +++ b/docker/scripts/range_world_1.sql @@ -2966,6 +2966,10 @@ COMMIT; SET FOREIGN_KEY_CHECKS = 1; +DROP DATABASE IF EXISTS meta; +CREATE DATABASE IF NOT EXISTS meta; +USE meta; + CREATE TABLE `undo_log` ( `id` bigint NOT NULL AUTO_INCREMENT, `branch_id` bigint NOT NULL, diff --git a/pkg/executor/sharding.go b/pkg/executor/sharding.go index 2f1d3b85..0b3ec704 100644 --- a/pkg/executor/sharding.go +++ b/pkg/executor/sharding.go @@ -38,6 +38,7 @@ import ( "github.com/cectc/dbpack/pkg/topo" "github.com/cectc/dbpack/pkg/tracing" "github.com/cectc/dbpack/third_party/parser/ast" + "github.com/cectc/dbpack/third_party/parser/format" ) type ShardingExecutor struct { @@ -207,15 +208,23 @@ func (executor *ShardingExecutor) ExecutorComQuery(ctx context.Context, sql stri } }() - var plan proto.Plan + var ( + plan proto.Plan + sb strings.Builder + ) - log.Debugf("query: %s", sql) connectionID := proto.ConnectionID(spanCtx) queryStmt := proto.QueryStmt(spanCtx) if queryStmt == nil { return nil, 0, errors.New("query stmt should not be nil") } + if err := queryStmt.Restore(format.NewRestoreCtx(constant.DBPackRestoreFormat, &sb)); err != nil { + return nil, 0, err + } + newSql := sb.String() + spanCtx = proto.WithSqlText(spanCtx, newSql) + log.Debugf("connectionID: %d, query: %s", connectionID, newSql) switch stmt := queryStmt.(type) { case *ast.SetStmt: if shouldStartTransaction(stmt) { @@ -278,7 +287,7 @@ func (executor *ShardingExecutor) ExecutorComQuery(ctx context.Context, sql stri case *ast.SelectStmt: if stmt.Fields != nil && len(stmt.Fields.Fields) > 0 { if _, ok := stmt.Fields.Fields[0].Expr.(*ast.VariableExpr); ok { - return executor.executors[0].Query(spanCtx, sql) + return executor.executors[0].Query(spanCtx, newSql) } } txi, ok := executor.localTransactionMap.Load(connectionID) diff --git a/pkg/listener/mysql.go b/pkg/listener/mysql.go index 02c3a4d9..401dc89d 100644 --- a/pkg/listener/mysql.go +++ b/pkg/listener/mysql.go @@ -518,7 +518,6 @@ func (l *MysqlListener) ExecuteCommand(ctx context.Context, c *mysql.Conn, data connectionID := proto.ConnectionID(ctx) l.executor.ConnectionClose(proto.WithConnectionID(ctx, connectionID)) log.Debugf("connection closed, id: %d", connectionID) - return errors.New("ComQuit") case constant.ComInitDB: db := string(data[1:]) c.RecycleReadPacket() @@ -549,8 +548,14 @@ func (l *MysqlListener) ExecuteCommand(ctx context.Context, c *mysql.Conn, data return nil } - if showStmt, ok := stmt.(*ast.ShowStmt); ok && showStmt.Tp == ast.ShowTables { - showStmt.DBName = c.Database() + if showStmt, ok := stmt.(*ast.ShowStmt); ok { + switch showStmt.Tp { + case ast.ShowTables, ast.ShowTableStatus, ast.ShowColumns, ast.ShowIndex, ast.ShowTriggers: + if misc.IsBlank(showStmt.DBName) { + showStmt.DBName = c.Database() + } + default: + } } if !misc.IsBlank(c.Database()) { diff --git a/pkg/plan/delete.go b/pkg/plan/delete.go index d485b4ae..9c9a7932 100644 --- a/pkg/plan/delete.go +++ b/pkg/plan/delete.go @@ -66,9 +66,10 @@ func (p *DeletePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi return nil, 0, errors.WithStack(err) } } + schema := proto.Schema(ctx) for _, table := range p.Tables { sb.Reset() - if err = p.generate(&sb, table, hints...); err != nil { + if err = p.generate(&sb, schema, table, hints...); err != nil { return nil, 0, errors.Wrap(err, "failed to generate sql for delete") } sql := sb.String() @@ -114,7 +115,7 @@ func (p *DeletePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi return mysqlResult, warnings, nil } -func (p *DeletePlan) generate(sb *strings.Builder, table string, hints ...*ast.TableOptimizerHint) error { +func (p *DeletePlan) generate(sb *strings.Builder, schema, table string, hints ...*ast.TableOptimizerHint) error { ctx := format.NewRestoreCtx(constant.DBPackRestoreFormat, sb) ctx.WriteKeyWord("DELETE ") @@ -133,7 +134,7 @@ func (p *DeletePlan) generate(sb *strings.Builder, table string, hints ...*ast.T } ctx.WriteKeyWord("FROM ") - ctx.WritePlain(table) + ctx.WritePlainf("%s.%s", schema, table) if p.Stmt.Where != nil { ctx.WriteKeyWord(" WHERE ") if err := p.Stmt.Where.Restore(ctx); err != nil { diff --git a/pkg/plan/delete_test.go b/pkg/plan/delete_test.go index b3818383..0a45203a 100644 --- a/pkg/plan/delete_test.go +++ b/pkg/plan/delete_test.go @@ -37,15 +37,15 @@ func TestDeleteOnSingleDBPlan(t *testing.T) { deleteSql: "delete from student where id in (?,?)", tables: []string{"student_1", "student_5"}, expectedGenerateSqls: []string{ - "DELETE FROM student_1 WHERE `id` IN (?,?)", - "DELETE FROM student_5 WHERE `id` IN (?,?)", + "DELETE FROM school.student_1 WHERE `id` IN (?,?)", + "DELETE FROM school.student_5 WHERE `id` IN (?,?)", }, }, { deleteSql: "delete from student where id = 9", tables: []string{"student_9"}, expectedGenerateSqls: []string{ - "DELETE FROM student_9 WHERE `id`=9", + "DELETE FROM school.student_9 WHERE `id`=9", }, }, } @@ -68,7 +68,7 @@ func TestDeleteOnSingleDBPlan(t *testing.T) { } for i, table := range plan.Tables { var sb strings.Builder - err := plan.generate(&sb, table) + err := plan.generate(&sb, "school", table) assert.Nil(t, err) assert.Equal(t, c.expectedGenerateSqls[i], sb.String()) } diff --git a/pkg/plan/insert.go b/pkg/plan/insert.go index 427c1380..756746ca 100644 --- a/pkg/plan/insert.go +++ b/pkg/plan/insert.go @@ -44,7 +44,8 @@ func (p *InsertPlan) Execute(ctx context.Context, _ ...*ast.TableOptimizerHint) tx proto.Tx err error ) - if err = p.generate(&sb); err != nil { + schema := proto.Schema(ctx) + if err = p.generate(&sb, schema); err != nil { return nil, 0, errors.WithStack(err) } sql := sb.String() @@ -77,13 +78,13 @@ func (p *InsertPlan) Execute(ctx context.Context, _ ...*ast.TableOptimizerHint) } } -func (p *InsertPlan) generate(sb *strings.Builder) (err error) { +func (p *InsertPlan) generate(sb *strings.Builder, schema string) (err error) { ctx := format.NewRestoreCtx(constant.DBPackRestoreFormat, sb) ctx.WriteKeyWord("INSERT ") ctx.WriteKeyWord("INTO ") - ctx.WritePlain(p.Table) + ctx.WritePlainf("%s.%s", schema, p.Table) ctx.WritePlain("(") columnLen := len(p.Columns) diff --git a/pkg/plan/insert_test.go b/pkg/plan/insert_test.go index e4664d16..58aeddb2 100644 --- a/pkg/plan/insert_test.go +++ b/pkg/plan/insert_test.go @@ -36,7 +36,7 @@ func TestInsertPlan(t *testing.T) { { insertSql: "insert into student(id, name, gender, age) values(?,?,?,?)", table: "student_5", - expectedGenerateSql: "INSERT INTO student_5(id,name,gender,age) VALUES (?,?,?,?)", + expectedGenerateSql: "INSERT INTO school.student_5(id,name,gender,age) VALUES (?,?,?,?)", }, } @@ -59,7 +59,7 @@ func TestInsertPlan(t *testing.T) { Executor: nil, } var sb strings.Builder - err = plan.generate(&sb) + err = plan.generate(&sb, "school") assert.Nil(t, err) assert.Equal(t, c.expectedGenerateSql, sb.String()) }) diff --git a/pkg/plan/query.go b/pkg/plan/query.go index 53ce73b3..d3942ddf 100644 --- a/pkg/plan/query.go +++ b/pkg/plan/query.go @@ -101,21 +101,22 @@ func (p *QueryOnSingleDBPlan) Execute(ctx context.Context, hints ...*ast.TableOp } func (p *QueryOnSingleDBPlan) generate(ctx context.Context, sb *strings.Builder, args *[]interface{}) (err error) { + schema := proto.Schema(ctx) stmtVal := deepcopy.Copy(p.Stmt) stmt := stmtVal.(*ast.SelectStmt) switch len(p.Tables) { case 0: - err = p.generateSelect("", stmt, sb, p.Limit) + err = p.generateSelect(schema, "", stmt, sb, p.Limit) p.appendArgs(args) case 1: // single shard table - err = p.generateSelect(p.Tables[0], stmt, sb, p.Limit) + err = p.generateSelect(schema, p.Tables[0], stmt, sb, p.Limit) p.appendArgs(args) default: sb.WriteString("SELECT * FROM (") sb.WriteByte('(') - if err = p.generateSelect(p.Tables[0], stmt, sb, p.Limit); err != nil { + if err = p.generateSelect(schema, p.Tables[0], stmt, sb, p.Limit); err != nil { return } sb.WriteByte(')') @@ -127,7 +128,7 @@ func (p *QueryOnSingleDBPlan) generate(ctx context.Context, sb *strings.Builder, sb.WriteString(" UNION ALL ") sb.WriteByte('(') - if err = p.generateSelect(p.Tables[i], stmt, sb, p.Limit); err != nil { + if err = p.generateSelect(schema, p.Tables[i], stmt, sb, p.Limit); err != nil { return } sb.WriteByte(')') @@ -242,11 +243,12 @@ func (p *QueryOnMultiDBPlan) Execute(ctx context.Context, _ ...*ast.TableOptimiz return result, warn, nil } -func (p *QueryOnSingleDBPlan) generateSelect(table string, stmt *ast.SelectStmt, sb *strings.Builder, limit *Limit) error { +func (p *QueryOnSingleDBPlan) generateSelect(schema, table string, stmt *ast.SelectStmt, sb *strings.Builder, limit *Limit) error { vi := &JoinVisitor{ fieldList: stmt.Fields, where: stmt.Where, orderBy: stmt.OrderBy, + schema: schema, table: table, algorithms: p.Algorithms, globalTables: p.GlobalTables, @@ -339,6 +341,7 @@ type JoinVisitor struct { where ast.ExprNode orderBy *ast.OrderByClause + schema string table string algorithms map[string]cond.ShardingAlgorithm globalTables map[string]bool @@ -410,11 +413,13 @@ func (s *JoinVisitor) Leave(n ast.Node) (node ast.Node, ok bool) { s.orderBy.Accept(visitor2) } } + secondTable.Schema = model.NewCIStr(s.schema) secondTable.Name = model.NewCIStr(joinTable) } } } } + firstTable.Schema = model.NewCIStr(s.schema) firstTable.Name = model.NewCIStr(s.table) return n, true } diff --git a/pkg/plan/query_test.go b/pkg/plan/query_test.go index 940534fd..4fb8125a 100644 --- a/pkg/plan/query_test.go +++ b/pkg/plan/query_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/cectc/dbpack/pkg/cond" + "github.com/cectc/dbpack/pkg/proto" "github.com/cectc/dbpack/pkg/topo" "github.com/cectc/dbpack/pkg/visitor" "github.com/cectc/dbpack/third_party/parser" @@ -43,24 +44,24 @@ func TestQueryOnSingleDBPlan(t *testing.T) { tables: []string{"student_1", "student_5"}, pk: "id", args: []interface{}{1, 5}, - expectedGenerateSql: "SELECT * FROM ((SELECT * FROM `student_1` WHERE `id` IN (?,?)) UNION ALL (SELECT * " + - "FROM `student_5` WHERE `id` IN (?,?))) t ORDER BY `t`.`id` ASC", + expectedGenerateSql: "SELECT * FROM ((SELECT * FROM `school`.`student_1` WHERE `id` IN (?,?)) UNION ALL " + + "(SELECT * FROM `school`.`student_5` WHERE `id` IN (?,?))) t ORDER BY `t`.`id` ASC", }, { selectSql: "select * from student where id in (?,?) order by id desc", tables: []string{"student_1", "student_5"}, pk: "id", args: []interface{}{1, 5}, - expectedGenerateSql: "SELECT * FROM ((SELECT * FROM `student_1` WHERE `id` IN (?,?) ORDER BY `id` DESC) " + - "UNION ALL (SELECT * FROM `student_5` WHERE `id` IN (?,?) ORDER BY `id` DESC)) t ORDER BY `t`.`id` DESC", + expectedGenerateSql: "SELECT * FROM ((SELECT * FROM `school`.`student_1` WHERE `id` IN (?,?) ORDER BY `id` DESC) " + + "UNION ALL (SELECT * FROM `school`.`student_5` WHERE `id` IN (?,?) ORDER BY `id` DESC)) t ORDER BY `t`.`id` DESC", }, { selectSql: "select * from student where id in (?,?) order by id desc limit ?, ?", tables: []string{"student_1", "student_5"}, pk: "id", args: []interface{}{1, 5, 1000, 20}, - expectedGenerateSql: "SELECT * FROM ((SELECT * FROM `student_1` WHERE `id` IN (?,?) ORDER BY `id` DESC " + - "LIMIT 1020) UNION ALL (SELECT * FROM `student_5` WHERE `id` IN (?,?) ORDER BY `id` DESC LIMIT 1020)) t ORDER BY `t`.`id` DESC", + expectedGenerateSql: "SELECT * FROM ((SELECT * FROM `school`.`student_1` WHERE `id` IN (?,?) ORDER BY `id` DESC " + + "LIMIT 1020) UNION ALL (SELECT * FROM `school`.`student_5` WHERE `id` IN (?,?) ORDER BY `id` DESC LIMIT 1020)) t ORDER BY `t`.`id` DESC", }, { selectSql: "select student.id, student.name, city.province from student left join city on city.name = " + @@ -68,10 +69,10 @@ func TestQueryOnSingleDBPlan(t *testing.T) { tables: []string{"student_1", "student_5"}, pk: "id", args: []interface{}{1, 5, 1000, 20}, - expectedGenerateSql: "SELECT * FROM ((SELECT `student_1`.`id`,`student_1`.`name`,`city`.`province` FROM `student_1` " + + expectedGenerateSql: "SELECT * FROM ((SELECT `student_1`.`id`,`student_1`.`name`,`city`.`province` FROM `school`.`student_1` " + "LEFT JOIN `city` ON `city`.`name`=`student_1`.`native_place` WHERE `student_1`.`id` IN (?,?) " + "ORDER BY `student_1`.`id` DESC LIMIT 1020) UNION ALL (SELECT `student_5`.`id`,`student_5`.`name`,`city`.`province` " + - "FROM `student_5` LEFT JOIN `city` ON `city`.`name`=`student_5`.`native_place` WHERE `student_5`.`id` IN (?,?) " + + "FROM `school`.`student_5` LEFT JOIN `city` ON `city`.`name`=`student_5`.`native_place` WHERE `student_5`.`id` IN (?,?) " + "ORDER BY `student_5`.`id` DESC LIMIT 1020)) t ORDER BY `t`.`id` DESC", }, { @@ -80,9 +81,9 @@ func TestQueryOnSingleDBPlan(t *testing.T) { tables: []string{"student_1", "student_5"}, pk: "id", args: []interface{}{1, 5, 1000, 20}, - expectedGenerateSql: "SELECT * FROM ((SELECT `s`.`id`,`s`.`name`,`city`.`province` FROM `student_1` AS `s` " + + expectedGenerateSql: "SELECT * FROM ((SELECT `s`.`id`,`s`.`name`,`city`.`province` FROM `school`.`student_1` AS `s` " + "LEFT JOIN `city` ON `city`.`name`=`s`.`native_place` WHERE `s`.`id` IN (?,?) ORDER BY `s`.`id` DESC LIMIT 1020) " + - "UNION ALL (SELECT `s`.`id`,`s`.`name`,`city`.`province` FROM `student_5` AS `s` LEFT JOIN `city` " + + "UNION ALL (SELECT `s`.`id`,`s`.`name`,`city`.`province` FROM `school`.`student_5` AS `s` LEFT JOIN `city` " + "ON `city`.`name`=`s`.`native_place` WHERE `s`.`id` IN (?,?) ORDER BY `s`.`id` DESC LIMIT 1020)) t ORDER BY `t`.`id` DESC", }, { @@ -91,10 +92,10 @@ func TestQueryOnSingleDBPlan(t *testing.T) { tables: []string{"student_1", "student_5"}, pk: "id", args: []interface{}{1, 5, 1000, 20}, - expectedGenerateSql: "SELECT * FROM ((SELECT `student_1`.`id`,`student_1`.`name`,`exam_1`.`grade` FROM `student_1` " + - "LEFT JOIN `exam_1` ON `exam_1`.`student_id`=`student_1`.`id` WHERE `student_1`.`id` IN (?,?) " + + expectedGenerateSql: "SELECT * FROM ((SELECT `student_1`.`id`,`student_1`.`name`,`exam_1`.`grade` FROM `school`.`student_1` " + + "LEFT JOIN `school`.`exam_1` ON `exam_1`.`student_id`=`student_1`.`id` WHERE `student_1`.`id` IN (?,?) " + "ORDER BY `student_1`.`id` DESC LIMIT 1020) UNION ALL (SELECT `student_5`.`id`,`student_5`.`name`,`exam_5`.`grade` " + - "FROM `student_5` LEFT JOIN `exam_5` ON `exam_5`.`student_id`=`student_5`.`id` WHERE `student_5`.`id` IN (?,?) " + + "FROM `school`.`student_5` LEFT JOIN `school`.`exam_5` ON `exam_5`.`student_id`=`student_5`.`id` WHERE `student_5`.`id` IN (?,?) " + "ORDER BY `student_5`.`id` DESC LIMIT 1020)) t ORDER BY `t`.`id` DESC", }, { @@ -103,9 +104,9 @@ func TestQueryOnSingleDBPlan(t *testing.T) { tables: []string{"student_1", "student_5"}, pk: "id", args: []interface{}{1, 5, 1000, 20}, - expectedGenerateSql: "SELECT * FROM ((SELECT `s`.`id`,`s`.`name`,`e`.`grade` FROM `student_1` AS `s` " + - "LEFT JOIN `exam_1` AS `e` ON `e`.`student_id`=`s`.`id` WHERE `s`.`id` IN (?,?) ORDER BY `s`.`id` DESC LIMIT 1020) " + - "UNION ALL (SELECT `s`.`id`,`s`.`name`,`e`.`grade` FROM `student_5` AS `s` LEFT JOIN `exam_5` AS `e` " + + expectedGenerateSql: "SELECT * FROM ((SELECT `s`.`id`,`s`.`name`,`e`.`grade` FROM `school`.`student_1` AS `s` " + + "LEFT JOIN `school`.`exam_1` AS `e` ON `e`.`student_id`=`s`.`id` WHERE `s`.`id` IN (?,?) ORDER BY `s`.`id` DESC LIMIT 1020) " + + "UNION ALL (SELECT `s`.`id`,`s`.`name`,`e`.`grade` FROM `school`.`student_5` AS `s` LEFT JOIN `school`.`exam_5` AS `e` " + "ON `e`.`student_id`=`s`.`id` WHERE `s`.`id` IN (?,?) ORDER BY `s`.`id` DESC LIMIT 1020)) t ORDER BY `t`.`id` DESC", }, } @@ -137,7 +138,8 @@ func TestQueryOnSingleDBPlan(t *testing.T) { args []interface{} ) plan.castLimit() - err = plan.generate(context.Background(), &sb, &args) + ctx := proto.WithSchema(context.Background(), "school") + err = plan.generate(ctx, &sb, &args) assert.Nil(t, err) assert.Equal(t, c.expectedGenerateSql, sb.String()) }) diff --git a/pkg/plan/update.go b/pkg/plan/update.go index ef564ae6..49ff1044 100644 --- a/pkg/plan/update.go +++ b/pkg/plan/update.go @@ -66,9 +66,10 @@ func (p *UpdatePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi return nil, 0, errors.WithStack(err) } } + schema := proto.Schema(ctx) for _, table := range p.Tables { sb.Reset() - if err = p.generate(&sb, table, hints...); err != nil { + if err = p.generate(&sb, schema, table, hints...); err != nil { return nil, 0, errors.Wrap(err, "failed to generate sql") } sql := sb.String() @@ -114,7 +115,7 @@ func (p *UpdatePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi return mysqlResult, warnings, nil } -func (p *UpdatePlan) generate(sb *strings.Builder, table string, hints ...*ast.TableOptimizerHint) error { +func (p *UpdatePlan) generate(sb *strings.Builder, schema, table string, hints ...*ast.TableOptimizerHint) error { ctx := format.NewRestoreCtx(constant.DBPackRestoreFormat, sb) ctx.WriteKeyWord("UPDATE ") @@ -132,7 +133,7 @@ func (p *UpdatePlan) generate(sb *strings.Builder, table string, hints ...*ast.T ctx.WritePlain("*/ ") } - ctx.WritePlain(table) + ctx.WritePlainf("%s.%s", schema, table) ctx.WriteKeyWord(" SET ") for i, assignment := range p.Stmt.List { if i != 0 { diff --git a/pkg/plan/update_test.go b/pkg/plan/update_test.go index 6146cd5b..38acb087 100644 --- a/pkg/plan/update_test.go +++ b/pkg/plan/update_test.go @@ -37,15 +37,15 @@ func TestUpdateOnSingleDBPlan(t *testing.T) { deleteSql: "update student set name = ?, age = ? where id in (?,?)", tables: []string{"student_1", "student_5"}, expectedGenerateSqls: []string{ - "UPDATE student_1 SET `name`=?, `age`=? WHERE `id` IN (?,?)", - "UPDATE student_5 SET `name`=?, `age`=? WHERE `id` IN (?,?)", + "UPDATE school.student_1 SET `name`=?, `age`=? WHERE `id` IN (?,?)", + "UPDATE school.student_5 SET `name`=?, `age`=? WHERE `id` IN (?,?)", }, }, { deleteSql: "update student set name = ?, age = ? where id = 9", tables: []string{"student_9"}, expectedGenerateSqls: []string{ - "UPDATE student_9 SET `name`=?, `age`=? WHERE `id`=9", + "UPDATE school.student_9 SET `name`=?, `age`=? WHERE `id`=9", }, }, } @@ -68,7 +68,7 @@ func TestUpdateOnSingleDBPlan(t *testing.T) { } for i, table := range plan.Tables { var sb strings.Builder - err := plan.generate(&sb, table) + err := plan.generate(&sb, "school", table) assert.Nil(t, err) assert.Equal(t, c.expectedGenerateSqls[i], sb.String()) } diff --git a/pkg/sql/db.go b/pkg/sql/db.go index 20bd9a19..c3e8161a 100644 --- a/pkg/sql/db.go +++ b/pkg/sql/db.go @@ -19,6 +19,7 @@ package sql import ( "context" "fmt" + "strings" "time" "github.com/pkg/errors" @@ -31,6 +32,7 @@ import ( "github.com/cectc/dbpack/pkg/misc" "github.com/cectc/dbpack/pkg/proto" "github.com/cectc/dbpack/pkg/tracing" + "github.com/cectc/dbpack/third_party/parser/format" "github.com/cectc/dbpack/third_party/pools" ) @@ -363,7 +365,11 @@ func (db *DB) QueryDirectly(query string) (proto.Result, uint16, error) { } func (db *DB) ExecuteStmt(ctx context.Context, stmt *proto.Stmt) (proto.Result, uint16, error) { - query := stmt.StmtNode.Text() + var sb strings.Builder + if err := stmt.StmtNode.Restore(format.NewRestoreCtx(constant.DBPackRestoreFormat, &sb)); err != nil { + return nil, 0, err + } + query := sb.String() spanCtx, span := tracing.GetTraceSpan(ctx, tracing.DBExecStmt) span.SetAttributes(attribute.KeyValue{Key: "db", Value: attribute.StringValue(db.name)}, attribute.KeyValue{Key: "sql", Value: attribute.StringValue(query)}) diff --git a/pkg/sql/tx.go b/pkg/sql/tx.go index 427c654c..2902d130 100644 --- a/pkg/sql/tx.go +++ b/pkg/sql/tx.go @@ -19,6 +19,7 @@ package sql import ( "context" "fmt" + "strings" "github.com/uber-go/atomic" "go.opentelemetry.io/otel/attribute" @@ -29,6 +30,7 @@ import ( "github.com/cectc/dbpack/pkg/proto" "github.com/cectc/dbpack/pkg/tracing" "github.com/cectc/dbpack/third_party/parser/ast" + "github.com/cectc/dbpack/third_party/parser/format" ) type Tx struct { @@ -69,7 +71,11 @@ func (tx *Tx) QueryDirectly(query string) (proto.Result, uint16, error) { } func (tx *Tx) ExecuteStmt(ctx context.Context, stmt *proto.Stmt) (proto.Result, uint16, error) { - query := stmt.StmtNode.Text() + var sb strings.Builder + if err := stmt.StmtNode.Restore(format.NewRestoreCtx(constant.DBPackRestoreFormat, &sb)); err != nil { + return nil, 0, err + } + query := sb.String() spanCtx, span := tracing.GetTraceSpan(ctx, tracing.TxExecStmt) span.SetAttributes(attribute.KeyValue{Key: "db", Value: attribute.StringValue(tx.db.name)}, attribute.KeyValue{Key: "sql", Value: attribute.StringValue(query)}) diff --git a/test/rws/read_write_splitting_test.go b/test/rws/read_write_splitting_test.go index 9fe3f6ce..99ae0511 100644 --- a/test/rws/read_write_splitting_test.go +++ b/test/rws/read_write_splitting_test.go @@ -110,6 +110,8 @@ func (suite *_ReadWriteSplittingSuite) TestDelete() { } suite.Equal(0, exists) } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ReadWriteSplittingSuite) TestInsert() { @@ -134,6 +136,8 @@ func (suite *_ReadWriteSplittingSuite) TestInsert() { } suite.Equal("master", firstName) } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ReadWriteSplittingSuite) TestSelect1() { @@ -151,6 +155,8 @@ func (suite *_ReadWriteSplittingSuite) TestSelect1() { } suite.Equal("slave", firstName) } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ReadWriteSplittingSuite) TestInsertEncryption() { @@ -174,6 +180,8 @@ func (suite *_ReadWriteSplittingSuite) TestInsertEncryption() { suite.T().Logf("id: %d, dept name: %s", id, deptName) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ReadWriteSplittingSuite) TestSelect2() { @@ -191,6 +199,8 @@ func (suite *_ReadWriteSplittingSuite) TestSelect2() { } suite.Equal("master", firstName) } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ReadWriteSplittingSuite) TestUpdate() { @@ -215,6 +225,8 @@ func (suite *_ReadWriteSplittingSuite) TestUpdate() { } suite.Equal("louis", lastName) } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ReadWriteSplittingSuite) TestUpdateEncryption() { @@ -238,6 +250,8 @@ func (suite *_ReadWriteSplittingSuite) TestUpdateEncryption() { suite.T().Logf("id: %d, dept name: %s", id, deptName) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ReadWriteSplittingSuite) TestXATransaction() { @@ -259,7 +273,10 @@ func (suite *_ReadWriteSplittingSuite) TestXATransaction() { assert.Nil(suite.T(), err) _, err = conn.ExecContext(ctx, "XA COMMIT 'abc'") assert.Nil(suite.T(), err) + err = conn.Close() + assert.Nil(suite.T(), err) } func (suite *_ReadWriteSplittingSuite) TearDownSuite() { + suite.db.Close() } diff --git a/test/sdb/crud_test.go b/test/sdb/crud_test.go index 7ec1b0bd..bc9b2ed7 100644 --- a/test/sdb/crud_test.go +++ b/test/sdb/crud_test.go @@ -114,6 +114,8 @@ func (suite *_CRUDSuite) TestInsertEncryption() { suite.T().Logf("id: %d, dept name: %s", id, deptName) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_CRUDSuite) TestSelect() { @@ -131,6 +133,8 @@ func (suite *_CRUDSuite) TestSelect() { } suite.Equal("scott", firstName) } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_CRUDSuite) TestUpdate() { @@ -164,6 +168,8 @@ func (suite *_CRUDSuite) TestUpdateEncryption() { suite.T().Logf("id: %d, dept name: %s", id, deptName) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_CRUDSuite) TestXATransaction() { @@ -185,6 +191,8 @@ func (suite *_CRUDSuite) TestXATransaction() { assert.Nil(suite.T(), err) _, err = conn.ExecContext(ctx, "XA COMMIT 'abc'") assert.Nil(suite.T(), err) + err = conn.Close() + assert.Nil(suite.T(), err) } func (suite *_CRUDSuite) TearDownSuite() { @@ -195,4 +203,5 @@ func (suite *_CRUDSuite) TearDownSuite() { suite.Equal(int64(1), affected) } } + suite.db.Close() } diff --git a/test/sdb/distributed_transaction_test.go b/test/sdb/distributed_transaction_test.go index 6c3c6e79..67622e2d 100644 --- a/test/sdb/distributed_transaction_test.go +++ b/test/sdb/distributed_transaction_test.go @@ -243,6 +243,8 @@ func (suite *_DistributedTransactionSuite) TearDownSuite() { suite.db.Exec(deleteDeptEmpForDT, 1) suite.db.Exec(deleteSalariesForDT, 1) suite.db.Exec(deleteDeptManagerForDT, 1) - suite.db2.Exec(deleteSalariesForDT, 2) + + suite.db.Close() + suite.db2.Close() } diff --git a/test/shd_mod/sharding_test.go b/test/shd_mod/sharding_test.go index 2d991722..d1f68f74 100644 --- a/test/shd_mod/sharding_test.go +++ b/test/shd_mod/sharding_test.go @@ -86,6 +86,8 @@ func (suite *_ShardingSuite) TestSelect() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectLimit() { @@ -105,6 +107,8 @@ func (suite *_ShardingSuite) TestSelectLimit() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectOrderBy() { @@ -124,6 +128,8 @@ func (suite *_ShardingSuite) TestSelectOrderBy() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectOrderBy2() { @@ -143,6 +149,8 @@ func (suite *_ShardingSuite) TestSelectOrderBy2() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectOrderByAndLimit() { @@ -162,6 +170,8 @@ func (suite *_ShardingSuite) TestSelectOrderByAndLimit() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectOrderByAndLimit2() { @@ -181,6 +191,8 @@ func (suite *_ShardingSuite) TestSelectOrderByAndLimit2() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectJoin1() { @@ -201,6 +213,8 @@ func (suite *_ShardingSuite) TestSelectJoin1() { id, name, countryCode, district, population, countryName) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectJoin2() { @@ -221,6 +235,8 @@ func (suite *_ShardingSuite) TestSelectJoin2() { id, name, countryCode, district, population, countryName) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectCount() { @@ -235,6 +251,8 @@ func (suite *_ShardingSuite) TestSelectCount() { suite.T().Logf("count: %d", count) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowDatabases() { @@ -249,6 +267,8 @@ func (suite *_ShardingSuite) TestShowDatabases() { suite.T().Logf("database: %s", database) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowEngines() { @@ -263,6 +283,8 @@ func (suite *_ShardingSuite) TestShowEngines() { suite.T().Logf("%s %s %s %s %s %s", engine, support, comment, transactions, xa, savepoints) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowCreateDatabase() { @@ -289,6 +311,8 @@ func (suite *_ShardingSuite) TestShowCreateDatabase() { suite.T().Logf("%s %s", database, createDatabase) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowTableStatus() { @@ -326,6 +350,8 @@ func (suite *_ShardingSuite) TestShowTableStatus() { autoIncrement, createTime, updateTime, checkTime, collation, checkSum, createOption, comment) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowTables() { @@ -340,6 +366,8 @@ func (suite *_ShardingSuite) TestShowTables() { suite.T().Logf("%s", table) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowTableMeta() { @@ -371,6 +399,8 @@ func (suite *_ShardingSuite) TestShowTableMeta() { comment, index_comment, visible, expression) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestCreateIndexAndDropIndex() { @@ -478,6 +508,8 @@ func (suite *_ShardingSuite) TestExprShadow() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestHintShadow() { @@ -501,7 +533,10 @@ func (suite *_ShardingSuite) TestHintShadow() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TearDownSuite() { + suite.db.Close() } diff --git a/test/shd_range/sharding_test.go b/test/shd_range/sharding_test.go index 50dd926b..575d616f 100644 --- a/test/shd_range/sharding_test.go +++ b/test/shd_range/sharding_test.go @@ -86,6 +86,8 @@ func (suite *_ShardingSuite) TestSelect() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectLimit() { @@ -105,6 +107,8 @@ func (suite *_ShardingSuite) TestSelectLimit() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectOrderBy() { @@ -124,6 +128,8 @@ func (suite *_ShardingSuite) TestSelectOrderBy() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectOrderBy2() { @@ -143,6 +149,8 @@ func (suite *_ShardingSuite) TestSelectOrderBy2() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectOrderByAndLimit() { @@ -162,6 +170,8 @@ func (suite *_ShardingSuite) TestSelectOrderByAndLimit() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectOrderByAndLimit2() { @@ -181,6 +191,8 @@ func (suite *_ShardingSuite) TestSelectOrderByAndLimit2() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectJoin1() { @@ -201,6 +213,8 @@ func (suite *_ShardingSuite) TestSelectJoin1() { id, name, countryCode, district, population, countryName) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectJoin2() { @@ -221,6 +235,8 @@ func (suite *_ShardingSuite) TestSelectJoin2() { id, name, countryCode, district, population, countryName) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestSelectCount() { @@ -235,6 +251,8 @@ func (suite *_ShardingSuite) TestSelectCount() { suite.T().Logf("count: %d", count) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowDatabases() { @@ -249,6 +267,8 @@ func (suite *_ShardingSuite) TestShowDatabases() { suite.T().Logf("database: %s", database) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowEngines() { @@ -263,6 +283,8 @@ func (suite *_ShardingSuite) TestShowEngines() { suite.T().Logf("%s %s %s %s %s %s", engine, support, comment, transactions, xa, savepoints) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowCreateDatabase() { @@ -289,6 +311,8 @@ func (suite *_ShardingSuite) TestShowCreateDatabase() { suite.T().Logf("%s %s", database, createDatabase) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowTableStatus() { @@ -326,6 +350,8 @@ func (suite *_ShardingSuite) TestShowTableStatus() { autoIncrement, createTime, updateTime, checkTime, collation, checkSum, createOption, comment) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowTables() { @@ -340,6 +366,8 @@ func (suite *_ShardingSuite) TestShowTables() { suite.T().Logf("%s", table) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestShowTableMeta() { @@ -371,6 +399,8 @@ func (suite *_ShardingSuite) TestShowTableMeta() { comment, index_comment, visible, expression) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestCreateIndexAndDropIndex() { @@ -468,6 +498,8 @@ func (suite *_ShardingSuite) TestExprShadow() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TestHintShadow() { @@ -491,7 +523,10 @@ func (suite *_ShardingSuite) TestHintShadow() { id, name, countryCode, district, population) } } + err = rows.Close() + assert.Nil(suite.T(), err) } func (suite *_ShardingSuite) TearDownSuite() { + suite.db.Close() }