diff --git a/database/dialect.go b/database/dialect.go index 9f138f560..bf0575f3f 100644 --- a/database/dialect.go +++ b/database/dialect.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "strings" "github.com/pressly/goose/v3/internal/dialect/dialectquery" ) @@ -26,6 +27,10 @@ const ( DialectStarrocks Dialect = "starrocks" ) +var ( + ClickhouseStore = &dialectquery.Clickhouse{} +) + // NewStore returns a new [Store] implementation for the given dialect. func NewStore(dialect Dialect, tablename string) (Store, error) { if tablename == "" { @@ -35,7 +40,7 @@ func NewStore(dialect Dialect, tablename string) (Store, error) { return nil, errors.New("dialect must not be empty") } lookup := map[Dialect]dialectquery.Querier{ - DialectClickHouse: &dialectquery.Clickhouse{}, + DialectClickHouse: ClickhouseStore, DialectMSSQL: &dialectquery.Sqlserver{}, DialectMySQL: &dialectquery.Mysql{}, DialectPostgres: &dialectquery.Postgres{}, @@ -69,9 +74,15 @@ func (s *store) Tablename() string { } func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error { - q := s.querier.CreateTable(s.tablename) - if _, err := db.ExecContext(ctx, q); err != nil { - return fmt.Errorf("failed to create version table %q: %w", s.tablename, err) + queries := strings.Split(s.querier.CreateTable(s.tablename), ";") + for _, q := range queries { + q = strings.TrimSpace(q) + if q == "" { + continue + } + if _, err := db.ExecContext(ctx, q); err != nil { + return err + } } return nil } diff --git a/internal/dialect/dialectquery/clickhouse.go b/internal/dialect/dialectquery/clickhouse.go index 723efd4cc..42a941517 100644 --- a/internal/dialect/dialectquery/clickhouse.go +++ b/internal/dialect/dialectquery/clickhouse.go @@ -1,21 +1,53 @@ package dialectquery -import "fmt" +import ( + "fmt" + "strings" +) -type Clickhouse struct{} +type Clickhouse struct { + ClusterName string +} var _ Querier = (*Clickhouse)(nil) func (c *Clickhouse) CreateTable(tableName string) string { - q := `CREATE TABLE IF NOT EXISTS %s ( + if c.ClusterName != "" { + var dbName string + split := strings.SplitN(tableName, ".", 2) + if len(split) != 2 { + dbName = "default" + } else { + dbName = split[0] + tableName = split[1] + } + + fullTableName := fmt.Sprintf("%s.%s", dbName, tableName) + const localPostfix = "_local_v1" + + return `CREATE TABLE IF NOT EXISTS ` + fullTableName + localPostfix + ` ON CLUSTER '` + c.ClusterName + `' ( version_id Int64, is_applied UInt8, date Date default now(), tstamp DateTime default now() - ) + ) + ENGINE = ReplicatedMergeTree(' + /clickhouse/{installation}/{cluster}/tables/{shard}/` + dbName + `/` + tableName + localPostfix + `', '{replica}') + ORDER BY (date); + + CREATE TABLE IF NOT EXISTS ` + fullTableName + ` ON CLUSTER '` + c.ClusterName + `' AS ` + fullTableName + localPostfix + ` + ENGINE = Distributed('` + c.ClusterName + `', ` + dbName + `, '` + tableName + localPostfix + `', rand()); + ORDER BY (date); ` + } + + return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( + version_id Int64, + is_applied UInt8, + date Date default now(), + tstamp DateTime default now() + ) ENGINE = MergeTree() - ORDER BY (date)` - return fmt.Sprintf(q, tableName) + ORDER BY (date)`, tableName) } func (c *Clickhouse) InsertVersion(tableName string) string { diff --git a/internal/dialect/store.go b/internal/dialect/store.go index e9b768f91..afb3a5047 100644 --- a/internal/dialect/store.go +++ b/internal/dialect/store.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "strings" "time" "github.com/pressly/goose/v3/internal/dialect/dialectquery" @@ -94,9 +95,17 @@ type store struct { var _ Store = (*store)(nil) func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error { - q := s.querier.CreateTable(tableName) - _, err := tx.ExecContext(ctx, q) - return err + queries := strings.Split(s.querier.CreateTable(tableName), ";") + for _, q := range queries { + q = strings.TrimSpace(q) + if q == "" { + continue + } + if _, err := tx.ExecContext(ctx, q); err != nil { + return err + } + } + return nil } func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {