diff --git a/cmd/apply.go b/cmd/apply.go index 77ea4a6..cfaab62 100644 --- a/cmd/apply.go +++ b/cmd/apply.go @@ -12,6 +12,7 @@ import ( "strings" "github.com/golang-migrate/migrate/v4" + "github.com/printeers/trek/internal/configuration" internalpostgres "github.com/printeers/trek/internal/postgres" // needed driver. @@ -49,7 +50,7 @@ func NewApplyCommand() *cobra.Command { return fmt.Errorf("failed to get working directory: %w", err) } - config, err := internal.ReadConfig(wd) + config, err := configuration.ReadConfig(wd) if err != nil { return fmt.Errorf("failed to read config: %w", err) } @@ -95,17 +96,15 @@ func NewApplyCommand() *cobra.Command { } } - for _, u := range config.DatabaseUsers { - var userExists bool - userExists, err = internalpostgres.CheckUserExists(ctx, conn, u) + for _, role := range config.Roles { + var roleExists bool + roleExists, err = internalpostgres.CheckRoleExists(ctx, conn, role.Name) if err != nil { - return fmt.Errorf("failed to check if user exists: %w", err) + return fmt.Errorf("failed to check if role exists: %w", err) } - if !userExists { - _, err = conn.Exec(ctx, fmt.Sprintf("CREATE ROLE %q WITH LOGIN", u)) - if err != nil { - return fmt.Errorf("failed to create user: %w", err) - } + if !roleExists { + //nolint:err113 + return fmt.Errorf("expected role %q to exists but it does not exist", role.Name) } } @@ -194,10 +193,10 @@ func NewApplyCommand() *cobra.Command { return fmt.Errorf("failed to connect to database: %w", err) } - for _, u := range config.DatabaseUsers { - _, err = conn.Exec(ctx, fmt.Sprintf("GRANT SELECT ON public.schema_migrations TO %q", u)) + for _, role := range config.Roles { + _, err = conn.Exec(ctx, fmt.Sprintf("GRANT SELECT ON public.schema_migrations TO %q", role.Name)) if err != nil { - return fmt.Errorf("failed to grant select permission on schema_migrations to %q: %w", u, err) + return fmt.Errorf("failed to grant select permission on schema_migrations to role %q: %w", role.Name, err) } } diff --git a/cmd/check.go b/cmd/check.go index 40d8e06..787a099 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -15,7 +15,8 @@ import ( "github.com/golang-migrate/migrate/v4" "github.com/jackc/pgx/v5" - internalpostgres "github.com/printeers/trek/internal/postgres" + "github.com/printeers/trek/internal/configuration" + "github.com/printeers/trek/internal/postgres" // needed driver. _ "github.com/golang-migrate/migrate/v4/database/postgres" @@ -41,7 +42,7 @@ func NewCheckCommand() *cobra.Command { return fmt.Errorf("failed to get working directory: %w", err) } - config, err := internal.ReadConfig(wd) + config, err := configuration.ReadConfig(wd) if err != nil { return fmt.Errorf("failed to read config: %w", err) } @@ -61,35 +62,28 @@ func NewCheckCommand() *cobra.Command { //nolint:cyclop func checkAll( ctx context.Context, - config *internal.Config, + config *configuration.Config, wd, migrationsDir string, ) error { - postgres, err := setupDatabase(5434) + tmpPostgres, err := setupPostgresInstance(5434) if err != nil { - return fmt.Errorf("failed to setup database: %w", err) + return fmt.Errorf("failed to setup tmp database: %w", err) } - defer postgres.Stop() //nolint:errcheck + defer tmpPostgres.Stop() //nolint:errcheck - dsn := postgres.DSN("postgres") + tmpPostgresDSN := tmpPostgres.DSN("postgres") - conn, err := pgx.Connect(ctx, dsn) + conn, err := pgx.Connect(ctx, tmpPostgresDSN) if err != nil { - return fmt.Errorf("failed to connect to database: %w", err) + return fmt.Errorf("failed to connect to tmp database: %w", err) } defer conn.Close(ctx) - for _, u := range config.DatabaseUsers { - var userExists bool - userExists, err = internalpostgres.CheckUserExists(ctx, conn, u) + for _, role := range config.Roles { + _, err = conn.Exec(ctx, fmt.Sprintf("CREATE ROLE %q WITH LOGIN PASSWORD 'postgres'", role.Name)) if err != nil { - return fmt.Errorf("failed to check if user exists: %w", err) - } - if !userExists { - _, err = conn.Exec(ctx, fmt.Sprintf("CREATE ROLE %q WITH LOGIN PASSWORD 'postgres'", u)) - if err != nil { - return fmt.Errorf("failed to create user: %w", err) - } + return fmt.Errorf("failed to create role %q: %w", role.Name, err) } } @@ -137,18 +131,11 @@ func checkAll( log.Println("Checking migrations and testdata") - err = checkMigrationsAndTestdata(ctx, wd, migrationsDir, dsn, migrationFiles) + err = checkMigrationsAndTestdata(ctx, wd, migrationsDir, tmpPostgresDSN, migrationFiles) if err != nil { return fmt.Errorf("failed to check migrations and testdata: %w", err) } - for _, u := range config.DatabaseUsers { - _, err = conn.Exec(ctx, fmt.Sprintf("GRANT SELECT ON public.schema_migrations TO %q", u)) - if err != nil { - return fmt.Errorf("failed to grant select permission on schema_migrations to %q: %w", u, err) - } - } - err = internal.RunHook(ctx, wd, "check-post", hookOptions) if err != nil { return fmt.Errorf("failed to run hook: %w", err) @@ -157,7 +144,7 @@ func checkAll( return nil } -func checkDBM(config *internal.Config, wd string) error { +func checkDBM(config *configuration.Config, wd string) error { model := dbm.DBModel{} m, err := os.ReadFile(filepath.Join(wd, fmt.Sprintf("%s.dbm", config.ModelName))) @@ -170,31 +157,34 @@ func checkDBM(config *internal.Config, wd string) error { return fmt.Errorf("failed to parse model: %w", err) } - modelRoles := map[string]struct{}{} + modelRoles := map[string]dbm.Role{} for _, role := range model.Roles { - if !role.SQLDisabled { - //nolint:err113 - return fmt.Errorf("role %q has sql enabled", role.Name) - } - modelRoles[role.Name] = struct{}{} + modelRoles[role.Name] = role } - configRoles := map[string]struct{}{} - for _, role := range config.DatabaseUsers { - configRoles[role] = struct{}{} + configRoles := map[string]configuration.Role{} + for _, role := range config.Roles { + configRoles[role.Name] = role } for role := range modelRoles { if _, ok := configRoles[role]; !ok { //nolint:err113 - return fmt.Errorf("role %q defined in the model not defined in the config", role) + return fmt.Errorf("role %q is defined in the model but is not defined in the config", role) } } for role := range configRoles { if _, ok := modelRoles[role]; !ok { //nolint:err113 - return fmt.Errorf("role %q defined in the config not defined in the model", role) + return fmt.Errorf("role %q is defined in the config but is not defined in the model", role) + } + } + + for _, role := range model.Roles { + if !role.SQLDisabled { + //nolint:err113 + return fmt.Errorf("role %q is missing 'sql disabled' in the model (sql must not be generated for a role)", role.Name) } } @@ -250,7 +240,7 @@ func checkMigrationFileNames(migrationFiles []string) error { return nil } -func checkTemplates(config *internal.Config, migrationsCount uint) error { +func checkTemplates(config *configuration.Config, migrationsCount uint) error { for _, ts := range config.Templates { if _, err := os.Stat(ts.Path); errors.Is(err, os.ErrNotExist) { //nolint:err113 @@ -297,7 +287,7 @@ func checkMigrationsAndTestdata(ctx context.Context, wd, migrationsDir, dsn stri if strings.HasPrefix(path.Base(p), fmt.Sprintf("%03d", index+1)) { // We have to use psql, because users might use commands like "\copy" // which don't work by directly connecting to the database - err := internalpostgres.PsqlFile(ctx, dsn, p) + err := postgres.PsqlFile(ctx, dsn, p) if err != nil { //nolint:err113 return fmt.Errorf("failed to apply testdata: %w", err) diff --git a/cmd/generate.go b/cmd/generate.go index 642616b..7a850d8 100644 --- a/cmd/generate.go +++ b/cmd/generate.go @@ -16,7 +16,8 @@ import ( "github.com/spf13/cobra" "github.com/printeers/trek/internal" - internalpostgres "github.com/printeers/trek/internal/postgres" + "github.com/printeers/trek/internal/configuration" + "github.com/printeers/trek/internal/postgres" ) //nolint:gocognit,cyclop @@ -66,7 +67,7 @@ func NewGenerateCommand() *cobra.Command { return fmt.Errorf("failed to get working directory: %w", err) } - config, err := internal.ReadConfig(wd) + config, err := configuration.ReadConfig(wd) if err != nil { return fmt.Errorf("failed to read config: %w", err) } @@ -203,20 +204,20 @@ func NewGenerateCommand() *cobra.Command { return generateCmd } -func setupDatabase(port uint32) (internalpostgres.Database, error) { - postgres := internalpostgres.NewPostgresDatabase() - err := postgres.Start(port) +func setupPostgresInstance(port uint32) (postgres.Instance, error) { + pgInstance := postgres.NewPostgresInstance() + err := pgInstance.Start(port) if err != nil { return nil, fmt.Errorf("failed to start database: %w", err) } - return postgres, nil + return pgInstance, nil } //nolint:gocognit,cyclop func runWithStdout( ctx context.Context, - config *internal.Config, + config *configuration.Config, wd, tmpDir, migrationsDir string, @@ -227,39 +228,39 @@ func runWithStdout( return fmt.Errorf("failed to check if model has been updated: %w", err) } if updated { - postgres, err := setupDatabase(5432) + targetInstance, err := setupPostgresInstance(5432) if err != nil { - return fmt.Errorf("failed to setup database: %w", err) + return fmt.Errorf("failed to setup instance: %w", err) } - defer postgres.Stop() //nolint:errcheck + defer targetInstance.Stop() //nolint:errcheck - postgresConn, err := pgx.Connect(ctx, postgres.DSN("postgres")) + postgresConn, err := pgx.Connect(ctx, targetInstance.DSN("postgres")) if err != nil { - return fmt.Errorf("failed to connect to database: %w", err) + return fmt.Errorf("failed to connect to postgres database: %w", err) } defer postgresConn.Close(ctx) _, err = postgresConn.Exec(ctx, "CREATE DATABASE target;") if err != nil { - return fmt.Errorf("failed to create database: %w", err) + return fmt.Errorf("failed to create target database: %w", err) } - targetConn, err := pgx.Connect(ctx, postgres.DSN("target")) + targetConn, err := pgx.Connect(ctx, targetInstance.DSN("target")) if err != nil { - return fmt.Errorf("failed to connect to database: %w", err) + return fmt.Errorf("failed to connect to target database: %w", err) } defer targetConn.Close(ctx) _, err = postgresConn.Exec(ctx, "CREATE DATABASE migrate;") if err != nil { - return fmt.Errorf("failed to create database: %w", err) + return fmt.Errorf("failed to create migrate database: %w", err) } - migrateConn, err := pgx.Connect(ctx, postgres.DSN("migrate")) + migrateConn, err := pgx.Connect(ctx, targetInstance.DSN("migrate")) if err != nil { - return fmt.Errorf("failed to connect to database: %w", err) + return fmt.Errorf("failed to connect to migrate database: %w", err) } - defer targetConn.Close(ctx) + defer migrateConn.Close(ctx) statements, err := generateMigrationStatements( ctx, @@ -320,7 +321,7 @@ func runWithStdout( //nolint:gocognit,cyclop func runWithFile( ctx context.Context, - config *internal.Config, + config *configuration.Config, wd, tmpDir, migrationsDir, @@ -339,39 +340,39 @@ func runWithFile( } } - postgres, err := setupDatabase(5432) + postgresInstance, err := setupPostgresInstance(5432) if err != nil { - return false, fmt.Errorf("failed to setup database: %w", err) + return false, fmt.Errorf("failed to setup instance: %w", err) } - defer postgres.Stop() //nolint:errcheck + defer postgresInstance.Stop() //nolint:errcheck - postgresConn, err := pgx.Connect(ctx, postgres.DSN("postgres")) + postgresConn, err := pgx.Connect(ctx, postgresInstance.DSN("postgres")) if err != nil { - return false, fmt.Errorf("failed to connect to database: %w", err) + return false, fmt.Errorf("failed to connect to postgres database: %w", err) } defer postgresConn.Close(ctx) _, err = postgresConn.Exec(ctx, "CREATE DATABASE target;") if err != nil { - return false, fmt.Errorf("failed to create database: %w", err) + return false, fmt.Errorf("failed to create target database: %w", err) } - targetConn, err := pgx.Connect(ctx, postgres.DSN("target")) + targetConn, err := pgx.Connect(ctx, postgresInstance.DSN("target")) if err != nil { - return false, fmt.Errorf("failed to connect to database: %w", err) + return false, fmt.Errorf("failed to connect to target database: %w", err) } defer targetConn.Close(ctx) _, err = postgresConn.Exec(ctx, "CREATE DATABASE migrate;") if err != nil { - return false, fmt.Errorf("failed to create database: %w", err) + return false, fmt.Errorf("failed to create migrate database: %w", err) } - migrateConn, err := pgx.Connect(ctx, postgres.DSN("migrate")) + migrateConn, err := pgx.Connect(ctx, postgresInstance.DSN("migrate")) if err != nil { - return false, fmt.Errorf("failed to connect to database: %w", err) + return false, fmt.Errorf("failed to connect to migrate database: %w", err) } - defer targetConn.Close(ctx) + defer migrateConn.Close(ctx) statements, err := generateMigrationStatements( ctx, @@ -417,7 +418,7 @@ func runWithFile( return false, nil } -func checkIfUpdated(config *internal.Config, wd string) (bool, error) { +func checkIfUpdated(config *configuration.Config, wd string) (bool, error) { m, err := os.ReadFile(filepath.Join(wd, fmt.Sprintf("%s.dbm", config.ModelName))) if err != nil { return false, fmt.Errorf("failed to read model file: %w", err) @@ -433,7 +434,7 @@ func checkIfUpdated(config *internal.Config, wd string) (bool, error) { return true, nil } -func writeTemplateFiles(config *internal.Config, newVersion uint) error { +func writeTemplateFiles(config *configuration.Config, newVersion uint) error { for _, ts := range config.Templates { dir := filepath.Dir(ts.Path) err := os.MkdirAll(dir, 0o755) @@ -463,7 +464,7 @@ var ( //nolint:cyclop func generateMigrationStatements( ctx context.Context, - config *internal.Config, + config *configuration.Config, wd, tmpDir, migrationsDir string, @@ -494,9 +495,11 @@ func generateMigrationStatements( } }() - err = internalpostgres.CreateUsers(ctx, postgresConn, config.DatabaseUsers) - if err != nil { - return "", fmt.Errorf("failed to create users: %w", err) + for _, role := range config.Roles { + _, err = targetConn.Exec(ctx, fmt.Sprintf("CREATE ROLE %q WITH LOGIN;", role.Name)) + if err != nil { + return "", fmt.Errorf("failed to create role %q: %w", role.Name, err) + } } err = executeTargetSQL(ctx, config, wd, targetConn) @@ -554,7 +557,7 @@ func generateMigrationStatements( } func executeMigrateSQL(migrationsDir string, migrateConn *pgx.Conn) error { - m, err := migrate.New(fmt.Sprintf("file://%s", migrationsDir), internalpostgres.DSN(migrateConn, "disable")) + m, err := migrate.New(fmt.Sprintf("file://%s", migrationsDir), postgres.DSN(migrateConn, "disable")) if err != nil { return fmt.Errorf("failed to create migrate: %w", err) } @@ -566,7 +569,7 @@ func executeMigrateSQL(migrationsDir string, migrateConn *pgx.Conn) error { return nil } -func executeTargetSQL(ctx context.Context, config *internal.Config, wd string, targetConn *pgx.Conn) error { +func executeTargetSQL(ctx context.Context, config *configuration.Config, wd string, targetConn *pgx.Conn) error { targetSQL, err := os.ReadFile(filepath.Join(wd, fmt.Sprintf("%s.sql", config.ModelName))) if err != nil { return fmt.Errorf("failed to read target sql: %w", err) @@ -603,13 +606,13 @@ func generateMissingPermissionStatements( "--exclude-table=public.schema_migrations", } - targetDump, err := internalpostgres.PgDump(ctx, internalpostgres.DSN(targetConn, "disable"), pgDumpOptions) + targetDump, err := postgres.PgDump(ctx, postgres.DSN(targetConn, "disable"), pgDumpOptions) if err != nil { //nolint:wrapcheck return "", err } - migrateDump, err := internalpostgres.PgDump(ctx, internalpostgres.DSN(migrateConn, "disable"), pgDumpOptions) + migrateDump, err := postgres.PgDump(ctx, postgres.DSN(migrateConn, "disable"), pgDumpOptions) if err != nil { //nolint:wrapcheck return "", err @@ -648,7 +651,7 @@ func generateMissingPermissionStatements( } var lines []string - for _, line := range strings.Split(string(output), "\n") { + for line := range strings.SplitSeq(string(output), "\n") { if strings.HasPrefix(line, "ALTER ") { lines = append(lines, line) } diff --git a/cmd/init.go b/cmd/init.go index 4a1289b..e6f5e2d 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -14,20 +14,21 @@ import ( "github.com/spf13/cobra" "github.com/printeers/trek/internal" + "github.com/printeers/trek/internal/configuration" "github.com/printeers/trek/internal/templates" ) var errInvalidModelName = errors.New("invalid model name") var errInvalidDatabaseName = errors.New("invalid database name") -var errInvalidDatabaseUsersList = errors.New("invalid database users list") +var errInvalidRolesList = errors.New("invalid roles list") //nolint:gocognit,cyclop func NewInitCommand() *cobra.Command { var ( - version string - modelName string - databaseName string - databaseUsers string + version string + modelName string + databaseName string + roleNames string ) initCmd := &cobra.Command{ Use: "init", @@ -54,7 +55,7 @@ func NewInitCommand() *cobra.Command { } } - if modelName == "" || databaseName == "" || databaseUsers == "" { + if modelName == "" || databaseName == "" || roleNames == "" { fmt.Printf("The following answers can only contain a-z and _\n") } @@ -88,26 +89,26 @@ func NewInitCommand() *cobra.Command { } } - if databaseUsers != "" { - if err = validateDatabaseUsers(databaseUsers); err != nil { - return fmt.Errorf("invalid database users %q: %w", databaseUsers, err) + if roleNames != "" { + if err = validateRoles(roleNames); err != nil { + return fmt.Errorf("invalid roles %q: %w", roleNames, err) } } else { dbUsersPrompt := promptui.Prompt{ - Label: "Database users (comma separated)", - Validate: validateDatabaseUsers, + Label: "Roles (comma separated)", + Validate: validateRoles, } - databaseUsers, err = dbUsersPrompt.Run() + roleNames, err = dbUsersPrompt.Run() if err != nil { - return fmt.Errorf("failed to prompt database users: %w", err) + return fmt.Errorf("failed to prompt roles: %w", err) } } - templateData := map[string]interface{}{ + templateData := map[string]any{ "trek_version": version, "model_name": modelName, "db_name": databaseName, - "db_users": strings.Split(databaseUsers, ","), + "roleNames": strings.Split(roleNames, ","), } for file, tmpl := range map[string]string{ @@ -147,7 +148,7 @@ func NewInitCommand() *cobra.Command { log.Println("New project created!") - config, err := internal.ReadConfig(wd) + config, err := configuration.ReadConfig(wd) if err != nil { return fmt.Errorf("failed to read config: %w", err) } @@ -175,13 +176,13 @@ func NewInitCommand() *cobra.Command { initCmd.Flags().StringVar(&version, "version", "", "Trek version to use (in the Dockerfile)") initCmd.Flags().StringVar(&modelName, "model-name", "", "Model (file) name") initCmd.Flags().StringVar(&databaseName, "database-name", "", "Database name") - initCmd.Flags().StringVar(&databaseUsers, "database-users", "", "Database users") + initCmd.Flags().StringVar(&roleNames, "roles", "", "Roles") return initCmd } func validateModelName(s string) error { - if !internal.ValidateIdentifier(s) { + if !configuration.ValidateIdentifier(s) { return errInvalidModelName } @@ -189,22 +190,22 @@ func validateModelName(s string) error { } func validateDatabaseName(s string) error { - if !internal.ValidateIdentifier(s) { + if !configuration.ValidateIdentifier(s) { return errInvalidDatabaseName } return nil } -func validateDatabaseUsers(s string) error { - if !internal.ValidateIdentifierList(strings.Split(s, ",")) { - return errInvalidDatabaseUsersList +func validateRoles(s string) error { + if !configuration.ValidateIdentifierList(strings.Split(s, ",")) { + return errInvalidRolesList } return nil } -func writeTemplateFile(ts, filename string, templateData map[string]interface{}) error { +func writeTemplateFile(ts, filename string, templateData map[string]any) error { t, err := template.New(filename).Parse(ts) if err != nil { return fmt.Errorf("failed to parse template: %w", err) diff --git a/internal/config.go b/internal/configuration/config.go similarity index 90% rename from internal/config.go rename to internal/configuration/config.go index 4d2b43c..394b60b 100644 --- a/internal/config.go +++ b/internal/configuration/config.go @@ -1,4 +1,4 @@ -package internal +package configuration import ( "errors" @@ -22,8 +22,11 @@ type Config struct { //nolint:tagliatelle DatabaseName string `yaml:"db_name"` //nolint:tagliatelle - DatabaseUsers []string `yaml:"db_users"` - Templates []Template `yaml:"templates"` + Roles []Role `yaml:"roles"` + Templates []Template `yaml:"templates"` +} +type Role struct { + Name string `yaml:"name"` } type Template struct { @@ -69,10 +72,10 @@ func (c *Config) validate() (problems []string) { ) problems = append(problems, p) } - for _, user := range c.DatabaseUsers { - if !ValidateIdentifier(user) { + for _, role := range c.Roles { + if !ValidateIdentifier(role.Name) { p := fmt.Sprintf("Database user %q contains invalid characters. Must match %q.", - user, + role, regexpStringValidIdentifier, ) problems = append(problems, p) diff --git a/internal/dbm/dbm.go b/internal/dbm/dbm.go index bf73a3a..170ae88 100644 --- a/internal/dbm/dbm.go +++ b/internal/dbm/dbm.go @@ -4,13 +4,17 @@ package dbm import "encoding/xml" type DBModel struct { - XMLName xml.Name `xml:"dbmodel"` - Roles []struct { - Name string `xml:"name,attr"` - SQLDisabled bool `xml:"sql-disabled,attr"` - } `xml:"role"` - Databases []struct { - Name string `xml:"name,attr"` - SQLDisabled bool `xml:"sql-disabled,attr"` - } `xml:"database"` + XMLName xml.Name `xml:"dbmodel"` + Roles []Role `xml:"role"` + Databases []Database `xml:"database"` +} + +type Role struct { + Name string `xml:"name,attr"` + SQLDisabled bool `xml:"sql-disabled,attr"` +} + +type Database struct { + Name string `xml:"name,attr"` + SQLDisabled bool `xml:"sql-disabled,attr"` } diff --git a/internal/postgres/common.go b/internal/postgres/common.go index c3b9578..67bd509 100644 --- a/internal/postgres/common.go +++ b/internal/postgres/common.go @@ -9,14 +9,14 @@ import ( "github.com/jackc/pgx/v5" ) -type Database interface { +type Instance interface { Start(port uint32) error Stop() error DSN(database string) string } -func NewPostgresDatabase() Database { - db := &postgresDatabaseEmbedded{} +func NewPostgresInstance() Instance { + db := &postgresInstanceEmbedded{} return db } @@ -63,60 +63,49 @@ func PsqlFile(ctx context.Context, dsn, file string) error { return nil } -func CreateUsers(ctx context.Context, conn *pgx.Conn, users []string) error { - for _, u := range users { - _, err := conn.Exec(ctx, fmt.Sprintf("CREATE ROLE %q WITH LOGIN;", u)) - if err != nil { - return fmt.Errorf("failed to create user: %w", err) - } - } - - return nil -} - -func CheckDatabaseExists(ctx context.Context, conn *pgx.Conn, user string) (bool, error) { - a := conn.QueryRow( +func CheckDatabaseExists(ctx context.Context, conn *pgx.Conn, database string) (bool, error) { + row := conn.QueryRow( ctx, - fmt.Sprintf("SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname='%s');", user), + fmt.Sprintf("SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname='%s');", database), ) - var b bool - err := a.Scan(&b) + var exists bool + err := row.Scan(&exists) if err != nil { return false, fmt.Errorf("failed to decode row: %w", err) } - return b, nil + return exists, nil } -func CheckUserExists(ctx context.Context, conn *pgx.Conn, user string) (bool, error) { - a := conn.QueryRow( +func CheckRoleExists(ctx context.Context, conn *pgx.Conn, role string) (bool, error) { + row := conn.QueryRow( ctx, - fmt.Sprintf("SELECT EXISTS(SELECT 1 FROM pg_roles WHERE rolname='%s');", user), + fmt.Sprintf("SELECT EXISTS(SELECT 1 FROM pg_roles WHERE rolname='%s');", role), ) - var b bool - err := a.Scan(&b) + var exists bool + err := row.Scan(&exists) if err != nil { return false, fmt.Errorf("failed to decode row: %w", err) } - return b, nil + return exists, nil } func CheckTableExists(ctx context.Context, conn *pgx.Conn, schema, name string) (bool, error) { - a := conn.QueryRow( + row := conn.QueryRow( ctx, fmt.Sprintf("SELECT EXISTS(SELECT FROM pg_tables WHERE schemaname = '%s' AND tablename = '%s');", schema, name), ) - var b bool - err := a.Scan(&b) + var exists bool + err := row.Scan(&exists) if err != nil { return false, fmt.Errorf("failed to decode row: %w", err) } - return b, nil + return exists, nil } func DSN(conn *pgx.Conn, sslmode string) string { diff --git a/internal/postgres/embedded.go b/internal/postgres/embedded.go index 329c190..41ad0d9 100644 --- a/internal/postgres/embedded.go +++ b/internal/postgres/embedded.go @@ -9,15 +9,15 @@ import ( "github.com/printeers/trek/internal" ) -var _ Database = &postgresDatabaseEmbedded{} +var _ Instance = &postgresInstanceEmbedded{} -type postgresDatabaseEmbedded struct { +type postgresInstanceEmbedded struct { port uint32 db *embeddedpostgres.EmbeddedPostgres tmpDir string } -func (p *postgresDatabaseEmbedded) Start(port uint32) error { +func (p *postgresInstanceEmbedded) Start(port uint32) error { p.port = port var buf bytes.Buffer @@ -49,7 +49,7 @@ func (p *postgresDatabaseEmbedded) Start(port uint32) error { return nil } -func (p *postgresDatabaseEmbedded) Stop() error { +func (p *postgresInstanceEmbedded) Stop() error { err := p.db.Stop() if err != nil { return fmt.Errorf("failed to stop database: %w", err) @@ -68,6 +68,6 @@ func (p *postgresDatabaseEmbedded) Stop() error { return nil } -func (p *postgresDatabaseEmbedded) DSN(database string) string { +func (p *postgresInstanceEmbedded) DSN(database string) string { return fmt.Sprintf("postgres://postgres:postgres@127.0.0.1:%d/%s?sslmode=disable", p.port, database) } diff --git a/internal/templates.go b/internal/templates.go index 0fb7cee..70c80b6 100644 --- a/internal/templates.go +++ b/internal/templates.go @@ -4,16 +4,18 @@ import ( "bytes" "fmt" "text/template" + + "github.com/printeers/trek/internal/configuration" ) -func ExecuteConfigTemplate(ts Template, version uint) (*string, error) { +func ExecuteConfigTemplate(ts configuration.Template, version uint) (*string, error) { t, err := template.New(ts.Path).Parse(ts.Content) if err != nil { return nil, fmt.Errorf("failed to parse template: %w", err) } var data bytes.Buffer - err = t.Execute(&data, map[string]interface{}{"NewVersion": version}) + err = t.Execute(&data, map[string]any{"NewVersion": version}) if err != nil { return nil, fmt.Errorf("failed to execute template: %w", err) } diff --git a/internal/templates/dbm.tmpl b/internal/templates/dbm.tmpl index 4ecd3dc..9b0d4e7 100644 --- a/internal/templates/dbm.tmpl +++ b/internal/templates/dbm.tmpl @@ -5,7 +5,7 @@ CAUTION: Do not modify this file unless you know what you are doing. --> -{{range .db_users}} +{{range .roleNames}} {{end}} diff --git a/internal/templates/trek.yaml.tmpl b/internal/templates/trek.yaml.tmpl index 44dd92c..8e59a24 100755 --- a/internal/templates/trek.yaml.tmpl +++ b/internal/templates/trek.yaml.tmpl @@ -1,4 +1,4 @@ model_name: {{.model_name}} db_name: {{.db_name}} -db_users:{{range .db_users}} - - {{.}}{{end}} +roles:{{range .roleNames}} + - name: {{.}}{{end}} diff --git a/tests/output/trek.yaml b/tests/output/trek.yaml index 59d9672..2d3f4b9 100644 --- a/tests/output/trek.yaml +++ b/tests/output/trek.yaml @@ -1,5 +1,5 @@ model_name: santas_warehouse db_name: north_pole -db_users: - - santa - - worker +roles: + - name: santa + - name: worker diff --git a/tests/run.sh b/tests/run.sh index f538a45..a83be6c 100755 --- a/tests/run.sh +++ b/tests/run.sh @@ -15,15 +15,17 @@ mkdir -p output TREK_VERSION=latest \ TREK_MODEL_NAME=santas_warehouse \ TREK_DATABASE_NAME=north_pole \ - TREK_DATABASE_USERS=santa,worker \ + TREK_ROLES=santa,worker \ trek init trek check - for file in ../stages/*; do + for file in ../stages/*.dbm; do cp "$file" santas_warehouse.dbm - trek generate "$(basename "$file" | cut -d "-" -f 2 | cut -d "." -f 1)" + stage_name=$(basename "$file" | cut -d "-" -f 2 | cut -d "." -f 1) + + trek generate "$stage_name" trek check done