Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions runner/cmd/runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -30,6 +31,7 @@ func mainInner() int {
var homeDir string
var httpPort int
var sshPort int
var sshAuthorizedKeys []string
var logLevel int

cmd := &cli.Command{
Expand Down Expand Up @@ -76,9 +78,14 @@ func mainInner() int {
Value: consts.RunnerSSHPort,
Destination: &sshPort,
},
&cli.StringSliceFlag{
Name: "ssh-authorized-key",
Usage: "dstack server or user authorized key. May be specified multiple times",
Destination: &sshAuthorizedKeys,
},
},
Action: func(cxt context.Context, cmd *cli.Command) error {
return start(cxt, tempDir, homeDir, httpPort, sshPort, logLevel, Version)
Action: func(ctx context.Context, cmd *cli.Command) error {
return start(ctx, tempDir, homeDir, httpPort, sshPort, sshAuthorizedKeys, logLevel, Version)
},
},
},
Expand All @@ -95,7 +102,7 @@ func mainInner() int {
return 0
}

func start(ctx context.Context, tempDir string, homeDir string, httpPort int, sshPort int, logLevel int, version string) error {
func start(ctx context.Context, tempDir string, homeDir string, httpPort int, sshPort int, sshAuthorizedKeys []string, logLevel int, version string) error {
if err := os.MkdirAll(tempDir, 0o755); err != nil {
return fmt.Errorf("create temp directory: %w", err)
}
Expand All @@ -114,15 +121,32 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile))
log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel))

server, err := api.NewServer(ctx, tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshPort, version)
if err != nil {
return fmt.Errorf("create server: %w", err)
// To ensure that all components of the authorized_keys path are owned by root and no directories
// are group or world writable, as required by sshd with "StrictModes yes" (the default value),
// we fix `/dstack` ownership and permissions and remove `/dstack/ssh` (it will be (re)created
// in Sshd.Prepare())
// See: https://github.com/openssh/openssh-portable/blob/d01efaa1c9ed84fd9011201dbc3c7cb0a82bcee3/misc.c#L2257-L2272
if err := os.Mkdir("/dstack", 0o755); errors.Is(err, os.ErrExist) {
if err := os.Chown("/dstack", 0, 0); err != nil {
return fmt.Errorf("chown dstack dir: %w", err)
}
if err := os.Chmod("/dstack", 0o755); err != nil {
return fmt.Errorf("chmod dstack dir: %w", err)
}
} else if err != nil {
return fmt.Errorf("create dstack dir: %w", err)
}
if err := os.RemoveAll("/dstack/ssh"); err != nil {
return fmt.Errorf("remove dstack ssh dir: %w", err)
}

sshd := ssh.NewSshd("/usr/sbin/sshd")
if err := sshd.Prepare(ctx, "/dstack/ssh/conf", "/dstack/ssh/log", sshPort); err != nil {
if err := sshd.Prepare(ctx, "/dstack/ssh", sshPort, "INFO"); err != nil {
return fmt.Errorf("prepare sshd: %w", err)
}
if err := sshd.AddAuthorizedKeys(ctx, sshAuthorizedKeys...); err != nil {
return fmt.Errorf("add authorized keys: %w", err)
}
if err := sshd.Start(ctx); err != nil {
return fmt.Errorf("start sshd: %w", err)
}
Expand All @@ -132,6 +156,10 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
}
}()

server, err := api.NewServer(ctx, tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshd, version)
if err != nil {
return fmt.Errorf("create server: %w", err)
}
log.Trace(ctx, "Starting API server", "port", httpPort)
if err := server.Run(ctx); err != nil {
return fmt.Errorf("server failed: %w", err)
Expand Down
6 changes: 0 additions & 6 deletions runner/cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,6 @@ func mainInner() int {
Destination: &args.Docker.Privileged,
Sources: cli.EnvVars("DSTACK_DOCKER_PRIVILEGED"),
},
&cli.StringFlag{
Name: "ssh-key",
Usage: "Public SSH key",
Destination: &args.Docker.ConcatinatedPublicSSHKeys,
Sources: cli.EnvVars("DSTACK_PUBLIC_SSH_KEY"),
},
&cli.StringFlag{
Name: "pjrt-device",
Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)",
Expand Down
145 changes: 15 additions & 130 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/dstackai/dstack/runner/internal/connections"
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/schemas"
"github.com/dstackai/dstack/runner/internal/ssh"
"github.com/dstackai/dstack/runner/internal/types"
)

Expand All @@ -54,7 +55,8 @@ type RunExecutor struct {
tempDir string
homeDir string
archiveDir string
sshPort int
sshd ssh.SshdManager

currentUid uint32

run schemas.Run
Expand Down Expand Up @@ -89,7 +91,7 @@ func (s *stubConnectionTracker) GetNoConnectionsSecs() int64 { return 0 }
func (s *stubConnectionTracker) Track(ticker <-chan time.Time) {}
func (s *stubConnectionTracker) Stop() {}

func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor, error) {
func NewRunExecutor(tempDir string, homeDir string, sshd ssh.SshdManager) (*RunExecutor, error) {
mu := &sync.RWMutex{}
timestamp := NewMonotonicTimestamp()
user, err := osuser.Current()
Expand All @@ -110,7 +112,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor,
return nil, fmt.Errorf("initialize procfs: %w", err)
}
connectionTracker = connections.NewConnectionTracker(connections.ConnectionTrackerConfig{
Port: uint64(sshPort),
Port: uint64(sshd.Port()),
MinConnDuration: 10 * time.Second, // shorter connections are likely from dstack-server
Procfs: proc,
})
Expand All @@ -123,7 +125,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor,
tempDir: tempDir,
homeDir: homeDir,
archiveDir: filepath.Join(tempDir, "file_archives"),
sshPort: sshPort,
sshd: sshd,
currentUid: uid,
jobUid: -1,
jobGid: -1,
Expand Down Expand Up @@ -466,8 +468,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
}

// As of 2024-11-29, ex.homeDir is always set to /root
rootSSHDir, err := prepareSSHDir(-1, -1, ex.homeDir)
if err != nil {
if _, err := prepareSSHDir(-1, -1, ex.homeDir); err != nil {
log.Warning(ctx, "failed to prepare ssh dir", "home", ex.homeDir, "err", err)
}
userSSHDir := ""
Expand All @@ -484,14 +485,6 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
userSSHDir, err = prepareSSHDir(uid, gid, homeDir)
if err != nil {
log.Warning(ctx, "failed to prepare ssh dir", "home", homeDir, "err", err)
} else {
rootSSHKeysPath := filepath.Join(rootSSHDir, "authorized_keys")
userSSHKeysPath := filepath.Join(userSSHDir, "authorized_keys")
restoreUserSSHKeys := backupFile(ctx, userSSHKeysPath)
defer restoreUserSSHKeys(ctx)
if err := copyAuthorizedKeys(rootSSHKeysPath, uid, gid, userSSHKeysPath); err != nil {
log.Warning(ctx, "failed to copy authorized keys", "path", homeDir, "err", err)
}
}
} else {
log.Trace(ctx, "homeDir is not accessible, skipping provisioning", "path", homeDir)
Expand All @@ -504,9 +497,12 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error

if ex.jobSpec.SSHKey != nil && userSSHDir != "" {
err := configureSSH(
ex.jobSpec.SSHKey.Private, ex.jobSpec.SSHKey.Public, ex.clusterInfo.JobIPs, ex.sshPort,
ex.jobSpec.SSHKey.Private, ex.clusterInfo.JobIPs, ex.sshd.Port(),
uid, gid, userSSHDir,
)
if err == nil {
err = ex.sshd.AddAuthorizedKeys(ctx, ex.jobSpec.SSHKey.Public)
}
if err != nil {
log.Warning(ctx, "failed to configure SSH", "err", err)
}
Expand Down Expand Up @@ -914,7 +910,7 @@ func includeDstackProfile(profilePath string, dstackProfilePath string) error {
return nil
}

func configureSSH(private string, public string, ips []string, port int, uid int, gid int, sshDir string) error {
func configureSSH(private string, ips []string, port int, uid int, gid int, sshDir string) error {
privatePath := filepath.Join(sshDir, "dstack_job")
privateFile, err := os.OpenFile(privatePath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0o600)
if err != nil {
Expand All @@ -928,19 +924,9 @@ func configureSSH(private string, public string, ips []string, port int, uid int
return fmt.Errorf("write private key: %w", err)
}

akPath := filepath.Join(sshDir, "authorized_keys")
akFile, err := os.OpenFile(akPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600)
if err != nil {
return fmt.Errorf("open authorized_keys: %w", err)
}
defer akFile.Close()
if err := os.Chown(akPath, uid, gid); err != nil {
return fmt.Errorf("chown authorized_keys: %w", err)
}
if _, err := akFile.WriteString(public); err != nil {
return fmt.Errorf("write public key: %w", err)
}

// TODO: move job hosts config to ~/.dstack/ssh/config.d/current_job
// and add "Include ~/.dstack/ssh/config.d/*" directive to ~/.ssh/config if not present
// instead of appending job hosts config directly (don't bloat user's ssh_config)
configPath := filepath.Join(sshDir, "config")
configFile, err := os.OpenFile(configPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600)
if err != nil {
Expand All @@ -963,104 +949,3 @@ func configureSSH(private string, public string, ips []string, port int, uid int
}
return nil
}

// copyAuthorizedKeys appends the contents of the authorized_keys file at srcPath
// to the one at dstPath and then chowns dstPath to uid:gid.
//
// If dstPath already exists, its mode is first set to 0600 and two newlines are
// written ahead of the copied keys; if it does not exist, it is created with
// mode 0600. A dstPath that is a directory is rejected with an error.
//
// This is a makeshift way to deliver authorized_keys to a non-root user
// without modifying the existing API/bootstrap process.
// TODO: implement key delivery properly, i.e. submit keys to and have them written
// by the runner, not the outer sh script that launches sshd and runner
func copyAuthorizedKeys(srcPath string, uid int, gid int, dstPath string) error {
	keys, err := os.Open(srcPath)
	if err != nil {
		return fmt.Errorf("open source authorized_keys: %w", err)
	}
	defer keys.Close()

	// Look at the destination before opening it: a pre-existing file gets its
	// mode tightened and a separator, a missing one is simply created below.
	appending := false
	switch st, statErr := os.Stat(dstPath); {
	case statErr == nil:
		if st.IsDir() {
			return fmt.Errorf("is a directory: %s", dstPath)
		}
		if chmodErr := os.Chmod(dstPath, 0o600); chmodErr != nil {
			return fmt.Errorf("chmod destination authorized_keys: %w", chmodErr)
		}
		appending = true
	case !errors.Is(statErr, os.ErrNotExist):
		return fmt.Errorf("stat destination authorized_keys: %w", statErr)
	}

	out, err := os.OpenFile(dstPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600)
	if err != nil {
		return fmt.Errorf("open destination authorized_keys: %w", err)
	}
	defer out.Close()

	if appending {
		// Blank lines visually separate our keys from the ones already present.
		if _, err := io.WriteString(out, "\n\n"); err != nil {
			return fmt.Errorf("write separator to authorized_keys: %w", err)
		}
	}
	if _, err := io.Copy(out, keys); err != nil {
		return fmt.Errorf("copy authorized_keys: %w", err)
	}
	if err := os.Chown(dstPath, uid, gid); err != nil {
		return fmt.Errorf("chown destination authorized_keys: %w", err)
	}
	return nil
}

// backupFile moves `/path/to/file` aside to `/path/to/file.dstack.bak`, writes a
// fresh copy of its content back under the original name, and returns a function
// that undoes this by renaming the backup over the original path.
//
// When the original file is absent, the returned function instead removes the
// file at path, if one has been created in the meantime.
//
// Failures are logged rather than returned; a restore function is always returned.
// NB: the re-created file gets default uid:gid and permissions, which are
// probably not those of the original file.
func backupFile(ctx context.Context, path string) func(context.Context) {
	bakPath := path + ".dstack.bak"

	// Anything other than "not found" (including a failed rename for some other
	// reason) is treated as "the original existed".
	renameErr := os.Rename(path, bakPath)
	hadOriginal := !errors.Is(renameErr, os.ErrNotExist)

	restore := func(ctx context.Context) {
		if hadOriginal {
			if err := os.Rename(bakPath, path); err != nil && !errors.Is(err, os.ErrNotExist) {
				log.Error(ctx, "failed to restore", "path", path, "err", err)
			}
			return
		}
		if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
			log.Error(ctx, "failed to remove", "path", path, "err", err)
		}
	}

	switch {
	case !hadOriginal:
		// Nothing to back up, nothing to copy.
		return restore
	case renameErr != nil:
		log.Error(ctx, "failed to back up", "path", path, "err", renameErr)
		return restore
	}

	// Re-create the original path with the backed-up content.
	saved, err := os.Open(bakPath)
	if err != nil {
		log.Error(ctx, "failed to open backup src", "path", bakPath, "err", err)
		return restore
	}
	defer saved.Close()

	fresh, err := os.Create(path)
	if err != nil {
		log.Error(ctx, "failed to open backup dest", "path", path, "err", err)
		return restore
	}
	defer fresh.Close()

	if _, err := io.Copy(fresh, saved); err != nil {
		log.Error(ctx, "failed to copy backup", "path", bakPath, "err", err)
	}
	return restore
}
20 changes: 19 additions & 1 deletion runner/internal/executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func makeTestExecutor(t *testing.T) *RunExecutor {
_ = os.Mkdir(temp, 0o700)
home := filepath.Join(baseDir, "home")
_ = os.Mkdir(home, 0o700)
ex, _ := NewRunExecutor(temp, home, 10022)
ex, _ := NewRunExecutor(temp, home, new(sshdMock))
ex.SetJob(body)
ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file before run
ex.setJobWorkingDir(context.Background())
Expand Down Expand Up @@ -341,6 +341,24 @@ func TestExecutor_LogsAnsiCodeHandling(t *testing.T) {
}
}

// sshdMock is a do-nothing sshd manager for tests; makeTestExecutor passes it
// to NewRunExecutor in place of a real sshd manager.
type sshdMock struct{}

// Port always reports port 0.
func (*sshdMock) Port() int { return 0 }

// Start does nothing and never fails.
func (*sshdMock) Start(_ context.Context) error { return nil }

// Stop does nothing and never fails.
func (*sshdMock) Stop(_ context.Context) error { return nil }

// AddAuthorizedKeys discards the given keys and never fails.
func (*sshdMock) AddAuthorizedKeys(_ context.Context, _ ...string) error { return nil }

func combineLogMessages(logHistory []schemas.LogEvent) string {
var logOutput bytes.Buffer
for _, logEvent := range logHistory {
Expand Down
5 changes: 3 additions & 2 deletions runner/internal/runner/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/dstackai/dstack/runner/internal/executor"
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/metrics"
"github.com/dstackai/dstack/runner/internal/ssh"
)

type Server struct {
Expand All @@ -33,9 +34,9 @@ type Server struct {
version string
}

func NewServer(ctx context.Context, tempDir string, homeDir string, address string, sshPort int, version string) (*Server, error) {
func NewServer(ctx context.Context, tempDir string, homeDir string, address string, sshd ssh.SshdManager, version string) (*Server, error) {
r := api.NewRouter()
ex, err := executor.NewRunExecutor(tempDir, homeDir, sshPort)
ex, err := executor.NewRunExecutor(tempDir, homeDir, sshd)
if err != nil {
return nil, err
}
Expand Down
Loading