From 9c8a91abb148ebd8a6e438a23e6c7078804d4a2f Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Mon, 29 Dec 2025 12:23:46 +0000 Subject: [PATCH] [runner] Streamline authorized_keys management * Add keys by the runner, not the shell script * Don't touch user's ~/.ssh/authorized_keys, use our own file * Share that file between all users Part-of: https://github.com/dstackai/dstack/issues/3419 --- runner/cmd/runner/main.go | 42 ++++- runner/cmd/shim/main.go | 6 - runner/internal/executor/executor.go | 145 ++---------------- runner/internal/executor/executor_test.go | 20 ++- runner/internal/runner/api/server.go | 5 +- runner/internal/shim/docker.go | 26 ++-- runner/internal/shim/docker_test.go | 4 +- runner/internal/shim/models.go | 14 +- runner/internal/ssh/sshd.go | 78 +++++++++- .../_internal/core/backends/base/compute.py | 39 +++-- .../core/backends/cloudrift/compute.py | 2 +- .../_internal/core/backends/cudo/compute.py | 2 +- .../_internal/core/backends/gcp/compute.py | 8 +- .../core/backends/hotaisle/compute.py | 5 +- .../core/backends/lambdalabs/compute.py | 5 +- .../_internal/core/backends/verda/compute.py | 2 +- .../background/tasks/process_instances.py | 2 +- 17 files changed, 191 insertions(+), 214 deletions(-) diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index 3202a4e827..7b3bb84680 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "fmt" "io" "os" @@ -30,6 +31,7 @@ func mainInner() int { var homeDir string var httpPort int var sshPort int + var sshAuthorizedKeys []string var logLevel int cmd := &cli.Command{ @@ -76,9 +78,14 @@ func mainInner() int { Value: consts.RunnerSSHPort, Destination: &sshPort, }, + &cli.StringSliceFlag{ + Name: "ssh-authorized-key", + Usage: "dstack server or user authorized key. May be specified multiple times", + Destination: &sshAuthorizedKeys, + }, }, - Action: func(cxt context.Context, cmd *cli.Command) error { - return start(cxt, tempDir, homeDir, httpPort, sshPort, logLevel, Version) + Action: func(ctx context.Context, cmd *cli.Command) error { + return start(ctx, tempDir, homeDir, httpPort, sshPort, sshAuthorizedKeys, logLevel, Version) }, }, }, @@ -95,7 +102,7 @@ func mainInner() int { return 0 } -func start(ctx context.Context, tempDir string, homeDir string, httpPort int, sshPort int, logLevel int, version string) error { +func start(ctx context.Context, tempDir string, homeDir string, httpPort int, sshPort int, sshAuthorizedKeys []string, logLevel int, version string) error { if err := os.MkdirAll(tempDir, 0o755); err != nil { return fmt.Errorf("create temp directory: %w", err) } @@ -114,15 +121,32 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile)) log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel)) - server, err := api.NewServer(ctx, tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshPort, version) - if err != nil { - return fmt.Errorf("create server: %w", err) + // To ensure that all components of the authorized_keys path are owned by root and no directories + // are group or world writable, as required by sshd with "StrictModes yes" (the default value), + // we fix `/dstack` ownership and permissions and remove `/dstack/ssh` (it will be (re)created + // in Sshd.Prepare()) + // See: https://github.com/openssh/openssh-portable/blob/d01efaa1c9ed84fd9011201dbc3c7cb0a82bcee3/misc.c#L2257-L2272 + if err := os.Mkdir("/dstack", 0o755); errors.Is(err, os.ErrExist) { + if err := os.Chown("/dstack", 0, 0); err != nil { + return fmt.Errorf("chown dstack dir: %w", err) + } + if err := os.Chmod("/dstack", 0o755); err != nil { + return fmt.Errorf("chmod dstack dir: %w", err) + } + } else if err != nil { + return fmt.Errorf("create dstack dir: %w", err) + } + if err := os.RemoveAll("/dstack/ssh"); err != nil { + return fmt.Errorf("remove dstack ssh dir: %w", err) } sshd := ssh.NewSshd("/usr/sbin/sshd") - if err := sshd.Prepare(ctx, "/dstack/ssh/conf", "/dstack/ssh/log", sshPort); err != nil { + if err := sshd.Prepare(ctx, "/dstack/ssh", sshPort, "INFO"); err != nil { return fmt.Errorf("prepare sshd: %w", err) } + if err := sshd.AddAuthorizedKeys(ctx, sshAuthorizedKeys...); err != nil { + return fmt.Errorf("add authorized keys: %w", err) + } if err := sshd.Start(ctx); err != nil { return fmt.Errorf("start sshd: %w", err) } @@ -132,6 +156,10 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss } }() + server, err := api.NewServer(ctx, tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshd, version) + if err != nil { + return fmt.Errorf("create server: %w", err) + } log.Trace(ctx, "Starting API server", "port", httpPort) if err := server.Run(ctx); err != nil { return fmt.Errorf("server failed: %w", err) diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 79aefbda6a..4c3c951df2 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -147,12 +147,6 @@ func mainInner() int { Destination: &args.Docker.Privileged, Sources: cli.EnvVars("DSTACK_DOCKER_PRIVILEGED"), }, - &cli.StringFlag{ - Name: "ssh-key", - Usage: "Public SSH key", - Destination: &args.Docker.ConcatinatedPublicSSHKeys, - Sources: cli.EnvVars("DSTACK_PUBLIC_SSH_KEY"), - }, &cli.StringFlag{ Name: "pjrt-device", Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)", diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 6f302e81cb..3e486f9704 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -29,6 +29,7 @@ import ( "github.com/dstackai/dstack/runner/internal/connections" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/schemas" + "github.com/dstackai/dstack/runner/internal/ssh" "github.com/dstackai/dstack/runner/internal/types" ) @@ -54,7 +55,8 @@ type RunExecutor struct { tempDir string homeDir string archiveDir string - sshPort int + sshd ssh.SshdManager + currentUid uint32 run schemas.Run @@ -89,7 +91,7 @@ func (s *stubConnectionTracker) GetNoConnectionsSecs() int64 { return 0 } func (s *stubConnectionTracker) Track(ticker <-chan time.Time) {} func (s *stubConnectionTracker) Stop() {} -func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor, error) { +func NewRunExecutor(tempDir string, homeDir string, sshd ssh.SshdManager) (*RunExecutor, error) { mu := &sync.RWMutex{} timestamp := NewMonotonicTimestamp() user, err := osuser.Current() @@ -110,7 +112,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor, return nil, fmt.Errorf("initialize procfs: %w", err) } connectionTracker = connections.NewConnectionTracker(connections.ConnectionTrackerConfig{ - Port: uint64(sshPort), + Port: uint64(sshd.Port()), MinConnDuration: 10 * time.Second, // shorter connections are likely from dstack-server Procfs: proc, }) @@ -123,7 +125,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshPort int) (*RunExecutor, tempDir: tempDir, homeDir: homeDir, archiveDir: filepath.Join(tempDir, "file_archives"), - sshPort: sshPort, + sshd: sshd, currentUid: uid, jobUid: -1, jobGid: -1, @@ -466,8 +468,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error } // As of 2024-11-29, ex.homeDir is always set to /root - rootSSHDir, err := prepareSSHDir(-1, -1, ex.homeDir) - if err != nil { + if _, err := prepareSSHDir(-1, -1, ex.homeDir); err != nil { log.Warning(ctx, "failed to prepare ssh dir", "home", ex.homeDir, "err", err) } userSSHDir := "" @@ -484,14 +485,6 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error userSSHDir, err = prepareSSHDir(uid, gid, homeDir) if err != nil { log.Warning(ctx, "failed to prepare ssh dir", "home", homeDir, "err", err) - } else { - rootSSHKeysPath := filepath.Join(rootSSHDir, "authorized_keys") - userSSHKeysPath := filepath.Join(userSSHDir, "authorized_keys") - restoreUserSSHKeys := backupFile(ctx, userSSHKeysPath) - defer restoreUserSSHKeys(ctx) - if err := copyAuthorizedKeys(rootSSHKeysPath, uid, gid, userSSHKeysPath); err != nil { - log.Warning(ctx, "failed to copy authorized keys", "path", homeDir, "err", err) - } } } else { log.Trace(ctx, "homeDir is not accessible, skipping provisioning", "path", homeDir) @@ -504,9 +497,12 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error if ex.jobSpec.SSHKey != nil && userSSHDir != "" { err := configureSSH( - ex.jobSpec.SSHKey.Private, ex.jobSpec.SSHKey.Public, ex.clusterInfo.JobIPs, ex.sshPort, + ex.jobSpec.SSHKey.Private, ex.clusterInfo.JobIPs, ex.sshd.Port(), uid, gid, userSSHDir, ) + if err == nil { + err = ex.sshd.AddAuthorizedKeys(ctx, ex.jobSpec.SSHKey.Public) + } if err != nil { log.Warning(ctx, "failed to configure SSH", "err", err) } @@ -914,7 +910,7 @@ func includeDstackProfile(profilePath string, dstackProfilePath string) error { return nil } -func configureSSH(private string, public string, ips []string, port int, uid int, gid int, sshDir string) error { +func configureSSH(private string, ips []string, port int, uid int, gid int, sshDir string) error { privatePath := filepath.Join(sshDir, "dstack_job") privateFile, err := os.OpenFile(privatePath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0o600) if err != nil { @@ -928,19 +924,9 @@ func configureSSH(private string, public string, ips []string, port int, uid int return fmt.Errorf("write private key: %w", err) } - akPath := filepath.Join(sshDir, "authorized_keys") - akFile, err := os.OpenFile(akPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600) - if err != nil { - return fmt.Errorf("open authorized_keys: %w", err) - } - defer akFile.Close() - if err := os.Chown(akPath, uid, gid); err != nil { - return fmt.Errorf("chown authorized_keys: %w", err) - } - if _, err := akFile.WriteString(public); err != nil { - return fmt.Errorf("write public key: %w", err) - } - + // TODO: move job hosts config to ~/.dstack/ssh/config.d/current_job + // and add "Include ~/.dstack/ssh/config.d/*" directive to ~/.ssh/config if not present + // instead of appending job hosts config directly (don't bloat user's ssh_config) configPath := filepath.Join(sshDir, "config") configFile, err := os.OpenFile(configPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600) if err != nil { @@ -963,104 +949,3 @@ func configureSSH(private string, public string, ips []string, port int, uid int } return nil } - -// A makeshift solution to deliver authorized_keys to a non-root user -// without modifying the existing API/bootstrap process -// TODO: implement key delivery properly, i.e. sumbit keys to and write by the runner, -// not the outer sh script that launches sshd and runner -func copyAuthorizedKeys(srcPath string, uid int, gid int, dstPath string) error { - srcFile, err := os.Open(srcPath) - if err != nil { - return fmt.Errorf("open source authorized_keys: %w", err) - } - defer srcFile.Close() - - dstExists := false - info, err := os.Stat(dstPath) - if err == nil { - dstExists = true - if info.IsDir() { - return fmt.Errorf("is a directory: %s", dstPath) - } - if err = os.Chmod(dstPath, 0o600); err != nil { - return fmt.Errorf("chmod destination authorized_keys: %w", err) - } - } else if !errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("stat destination authorized_keys: %w", err) - } - - dstFile, err := os.OpenFile(dstPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600) - if err != nil { - return fmt.Errorf("open destination authorized_keys: %w", err) - } - defer dstFile.Close() - - if dstExists { - // visually separate our keys from existing ones - if _, err := dstFile.WriteString("\n\n"); err != nil { - return fmt.Errorf("write separator to authorized_keys: %w", err) - } - } - if _, err := io.Copy(dstFile, srcFile); err != nil { - return fmt.Errorf("copy authorized_keys: %w", err) - } - if err = os.Chown(dstPath, uid, gid); err != nil { - return fmt.Errorf("chown destination authorized_keys: %w", err) - } - - return nil -} - -// backupFile renames `/path/to/file` to `/path/to/file.dstack.bak`, -// creates a new file with the same content, and returns restore function that -// renames the backup back to the original name. -// If the original file does not exist, restore function removes the file if it is created. -// NB: A newly created file has default uid:gid and permissions, probably not -// the same as the original file. -func backupFile(ctx context.Context, path string) func(context.Context) { - var existed bool - backupPath := path + ".dstack.bak" - - restoreFunc := func(ctx context.Context) { - if !existed { - err := os.Remove(path) - if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Error(ctx, "failed to remove", "path", path, "err", err) - } - return - } - err := os.Rename(backupPath, path) - if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Error(ctx, "failed to restore", "path", path, "err", err) - } - } - - err := os.Rename(path, backupPath) - if errors.Is(err, os.ErrNotExist) { - existed = false - return restoreFunc - } - existed = true - if err != nil { - log.Error(ctx, "failed to back up", "path", path, "err", err) - return restoreFunc - } - - src, err := os.Open(backupPath) - if err != nil { - log.Error(ctx, "failed to open backup src", "path", backupPath, "err", err) - return restoreFunc - } - defer src.Close() - dst, err := os.Create(path) - if err != nil { - log.Error(ctx, "failed to open backup dest", "path", path, "err", err) - return restoreFunc - } - defer dst.Close() - _, err = io.Copy(dst, src) - if err != nil { - log.Error(ctx, "failed to copy backup", "path", backupPath, "err", err) - } - return restoreFunc -} diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go index 0e0b14d84e..cc5cae7b38 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -208,7 +208,7 @@ func makeTestExecutor(t *testing.T) *RunExecutor { _ = os.Mkdir(temp, 0o700) home := filepath.Join(baseDir, "home") _ = os.Mkdir(home, 0o700) - ex, _ := NewRunExecutor(temp, home, 10022) + ex, _ := NewRunExecutor(temp, home, new(sshdMock)) ex.SetJob(body) ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file before run ex.setJobWorkingDir(context.Background()) @@ -341,6 +341,24 @@ func TestExecutor_LogsAnsiCodeHandling(t *testing.T) { } } +type sshdMock struct{} + +func (d *sshdMock) Port() int { + return 0 +} + +func (d *sshdMock) Start(context.Context) error { + return nil +} + +func (d *sshdMock) Stop(context.Context) error { + return nil +} + +func (d *sshdMock) AddAuthorizedKeys(context.Context, ...string) error { + return nil +} + func combineLogMessages(logHistory []schemas.LogEvent) string { var logOutput bytes.Buffer for _, logEvent := range logHistory { diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 1a5459c96b..2e8a526273 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -11,6 +11,7 @@ import ( "github.com/dstackai/dstack/runner/internal/executor" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/metrics" + "github.com/dstackai/dstack/runner/internal/ssh" ) type Server struct { @@ -33,9 +34,9 @@ type Server struct { version string } -func NewServer(ctx context.Context, tempDir string, homeDir string, address string, sshPort int, version string) (*Server, error) { +func NewServer(ctx context.Context, tempDir string, homeDir string, address string, sshd ssh.SshdManager, version string) (*Server, error) { r := api.NewRouter() - ex, err := executor.NewRunExecutor(tempDir, homeDir, sshPort) + ex, err := executor.NewRunExecutor(tempDir, homeDir, sshd) if err != nil { return nil, err } diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 631a10c46b..7e29e92dd7 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -920,7 +920,7 @@ func encodeRegistryAuth(username string, password string) (string, error) { return base64.URLEncoding.EncodeToString(encodedConfig), nil } -func getSSHShellCommands(publicSSHKey string) []string { +func getSSHShellCommands() []string { return []string{ `( :`, // See https://github.com/dstackai/dstack/issues/1769 @@ -936,11 +936,6 @@ func getSSHShellCommands(publicSSHKey string) []string { `if exists apk; then install_pkg() { apk add -U "$1"; }; fi`, // check in sshd is here, install if not `if ! exists sshd; then install_pkg openssh-server; fi`, - // create ssh dirs and add public key - "mkdir -p ~/.ssh", - "chmod 700 ~/.ssh", - fmt.Sprintf("echo '%s' > ~/.ssh/authorized_keys", publicSSHKey), - "chmod 600 ~/.ssh/authorized_keys", `: )`, } } @@ -1190,21 +1185,20 @@ func (c *CLIArgs) DockerPJRTDevice() string { } func (c *CLIArgs) DockerShellCommands(publicKeys []string) []string { - concatinatedPublicKeys := c.Docker.ConcatinatedPublicSSHKeys - if len(publicKeys) > 0 { - concatinatedPublicKeys = strings.Join(publicKeys, "\n") - } - commands := getSSHShellCommands(concatinatedPublicKeys) - runnerArgs := []string{ + commands := getSSHShellCommands() + runnerCommand := []string{ + consts.RunnerBinaryPath, "--log-level", strconv.Itoa(c.Runner.LogLevel), "start", + "--home-dir", consts.RunnerHomeDir, + "--temp-dir", consts.RunnerTempDir, "--http-port", strconv.Itoa(c.Runner.HTTPPort), "--ssh-port", strconv.Itoa(c.Runner.SSHPort), - "--temp-dir", consts.RunnerTempDir, - "--home-dir", consts.RunnerHomeDir, } - commands = append(commands, fmt.Sprintf("%s %s", consts.RunnerBinaryPath, strings.Join(runnerArgs, " "))) - return commands + for _, key := range publicKeys { + runnerCommand = append(runnerCommand, "--ssh-authorized-key", fmt.Sprintf("'%s'", key)) + } + return append(commands, strings.Join(runnerCommand, " ")) } func (c *CLIArgs) DockerMounts(hostRunnerDir string) ([]mount.Mount, error) { diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 48fa27d5ec..faa31bbc06 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -6,7 +6,6 @@ import ( "math/rand" "os" "runtime" - "strings" "sync" "testing" "time" @@ -112,10 +111,9 @@ func (c *dockerParametersMock) DockerPJRTDevice() string { } func (c *dockerParametersMock) DockerShellCommands(publicKeys []string) []string { - userPublicKey := strings.Join(publicKeys, "\n") commands := make([]string, 0) if c.sshShellCommands { - commands = append(commands, getSSHShellCommands(userPublicKey)...) + commands = append(commands, getSSHShellCommands()...) } commands = append(commands, c.commands...) return commands diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index 0a0c697eec..5952286507 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -39,9 +39,8 @@ type CLIArgs struct { } Docker struct { - ConcatinatedPublicSSHKeys string - Privileged bool - PJRTDevice string + Privileged bool + PJRTDevice string } } @@ -98,11 +97,10 @@ type TaskConfig struct { InstanceMounts []InstanceMountPoint `json:"instance_mounts"` // GPUDevices allows the server to set gpu devices instead of relying on the runner default logic. // E.g. passing nvidia devices directly instead of using nvidia-container-toolkit. - GPUDevices []GPUDevice `json:"gpu_devices"` - HostSshUser string `json:"host_ssh_user"` - HostSshKeys []string `json:"host_ssh_keys"` - // TODO: submit keys to runner, not to shim - ContainerSshKeys []string `json:"container_ssh_keys"` + GPUDevices []GPUDevice `json:"gpu_devices"` + HostSshUser string `json:"host_ssh_user"` + HostSshKeys []string `json:"host_ssh_keys"` + ContainerSshKeys []string `json:"container_ssh_keys"` } type TaskListItem struct { diff --git a/runner/internal/ssh/sshd.go b/runner/internal/ssh/sshd.go index 49de50ad67..d46be7e24f 100644 --- a/runner/internal/ssh/sshd.go +++ b/runner/internal/ssh/sshd.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path" + "sync" "syscall" "time" @@ -14,16 +15,29 @@ import ( "github.com/dstackai/dstack/runner/internal/log" ) +type SshdManager interface { + Port() int + + Start(context.Context) error + Stop(context.Context) error + AddAuthorizedKeys(context.Context, ...string) error +} + var hostKeys = [...]string{ "ssh_host_rsa_key", "ssh_host_ecdsa_key", "ssh_host_ed25519_key", } +// Implements SshdManager type Sshd struct { binPath string confPath string logPath string + akPath string + port int + + akMu sync.Mutex cmd *exec.Cmd } @@ -34,19 +48,34 @@ func NewSshd(binPath string) *Sshd { } } -func (d *Sshd) Prepare(ctx context.Context, confDir string, logDir string, port int) error { +func (d *Sshd) Port() int { + return d.port +} + +func (d *Sshd) Prepare(ctx context.Context, baseDir string, port int, logLevel string) error { + confDir := path.Join(baseDir, "conf") if err := os.MkdirAll(confDir, 0o755); err != nil { return fmt.Errorf("create conf dir: %w", err) } + if err := generateHostKeys(ctx, confDir); err != nil { return fmt.Errorf("generate host keys: %w", err) } - confPath, err := createSshdConfig(ctx, confDir, port) + + akPath, err := prepareAuthorizedKeysFile(confDir) + if err != nil { + return fmt.Errorf("prepare authorized_keys: %w", err) + } + d.akPath = akPath + + confPath, err := createSshdConfig(ctx, confDir, port, logLevel, akPath) if err != nil { return fmt.Errorf("create sshd config: %w", err) } d.confPath = confPath + d.port = port + logDir := path.Join(baseDir, "log") logPath, err := prepareLogPath(logDir) if err != nil { return fmt.Errorf("prepare log path: %w", err) @@ -67,6 +96,29 @@ func (d *Sshd) Prepare(ctx context.Context, confDir string, logDir string, port return nil } +func (d *Sshd) AddAuthorizedKeys(ctx context.Context, authorizedKeys ...string) error { + d.akMu.Lock() + defer d.akMu.Unlock() + + file, err := os.OpenFile(d.akPath, os.O_WRONLY|os.O_APPEND, 0o700) + if err != nil { + return fmt.Errorf("open authorized_keys: %w", err) + } + defer func() { + if err := file.Close(); err != nil { + log.Error(ctx, "Close authorized_keys", "err", err) + } + }() + + for _, key := range authorizedKeys { + if _, err := fmt.Fprintln(file, key); err != nil { + return fmt.Errorf("write authorized_keys: %w", err) + } + } + + return nil +} + func (d *Sshd) Start(ctx context.Context) error { if d.confPath == "" { return errors.New("not configured") @@ -148,7 +200,23 @@ func copyHostKey(srcDir string, destDir string, key string) error { return nil } -func createSshdConfig(ctx context.Context, confDir string, port int) (string, error) { +func prepareAuthorizedKeysFile(confDir string) (string, error) { + // Ensures that the file exists, has correct ownership and permissions, and is empty + akPath := path.Join(confDir, "authorized_keys") + if _, err := common.RemoveIfExists(akPath); err != nil { + return "", err + } + file, err := os.OpenFile(akPath, os.O_CREATE|os.O_EXCL|os.O_RDONLY, 0o644) + if err != nil { + return "", err + } + if err := file.Close(); err != nil { + return "", err + } + return akPath, nil +} + +func createSshdConfig(ctx context.Context, confDir string, port int, logLevel string, akPath string) (string, error) { confPath := path.Join(confDir, "sshd_config") file, err := os.OpenFile(confPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) if err != nil { @@ -161,6 +229,7 @@ func createSshdConfig(ctx context.Context, confDir string, port int) (string, er }() lines := []string{ + fmt.Sprintf("LogLevel %s", logLevel), fmt.Sprintf("Port %d", port), "PidFile none", "Subsystem sftp internal-sftp", @@ -176,7 +245,8 @@ func createSshdConfig(ctx context.Context, confDir string, port int) (string, er // See: useradd(8) // TODO: Change to `no` if a custom OpenSSH build without LOCKED_PASSWD_PREFIX is used "UsePAM yes", - "AuthorizedKeysFile .ssh/authorized_keys", + // Keep ~/.ssh/authorized_keys as a fallback in case our sshd server is also used by the user for their purposes + fmt.Sprintf("AuthorizedKeysFile %s .ssh/authorized_keys", akPath), "AcceptEnv LANG LC_* COLORTERM NO_COLOR", "ClientAliveInterval 30", "ClientAliveCountMax 4", diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 9599c9f97b..13cba1eb53 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -1,6 +1,7 @@ import os import random import re +import shlex import string import threading from abc import ABC, abstractmethod @@ -683,7 +684,6 @@ def get_user_data( firewall_allow_from_subnets: Iterable[str] = DEFAULT_PRIVATE_SUBNETS, ) -> str: shim_commands = get_shim_commands( - authorized_keys=authorized_keys, base_path=base_path, bin_path=bin_path, backend_shim_env=backend_shim_env, @@ -698,7 +698,6 @@ def get_user_data( def get_shim_env( - authorized_keys: List[str], base_path: Optional[PathLike] = None, bin_path: Optional[PathLike] = None, backend_shim_env: Optional[Dict[str, str]] = None, @@ -714,7 +713,6 @@ def get_shim_env( "DSTACK_RUNNER_HTTP_PORT": str(DSTACK_RUNNER_HTTP_PORT), "DSTACK_RUNNER_SSH_PORT": str(DSTACK_RUNNER_SSH_PORT), "DSTACK_RUNNER_LOG_LEVEL": log_level, - "DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys), } if backend_shim_env is not None: envs |= backend_shim_env @@ -722,7 +720,6 @@ def get_shim_env( def get_shim_commands( - authorized_keys: List[str], *, is_privileged: bool = False, pjrt_device: Optional[str] = None, @@ -743,7 +740,6 @@ def get_shim_commands( arch=arch, ) shim_env = get_shim_env( - authorized_keys=authorized_keys, base_path=base_path, bin_path=bin_path, backend_shim_env=backend_shim_env, @@ -942,7 +938,6 @@ def get_docker_commands( bin_path: Optional[PathLike] = None, ) -> list[str]: dstack_runner_binary_path = get_dstack_runner_binary_path(bin_path) - authorized_keys_content = "\n".join(authorized_keys).strip() commands = [ "( :", # See https://github.com/dstackai/dstack/issues/1769 @@ -960,27 +955,31 @@ def get_docker_commands( "if ! exists sshd; then install_pkg openssh-server; fi", # install curl if necessary "if ! exists curl; then install_pkg curl; fi", - # create ssh dirs and add public key - "mkdir -p ~/.ssh", - "chmod 700 ~/.ssh", - f"echo '{authorized_keys_content}' > ~/.ssh/authorized_keys", - "chmod 600 ~/.ssh/authorized_keys", ": )", ] + runner_command = [ + dstack_runner_binary_path, + "--log-level", + "6", + "start", + "--home-dir", + "/root", + "--temp-dir", + "/tmp/runner", + "--http-port", + str(DSTACK_RUNNER_HTTP_PORT), + "--ssh-port", + str(DSTACK_RUNNER_SSH_PORT), + ] + for authorized_key in authorized_keys: + runner_command += ["--ssh-authorized-key", authorized_key] + url = get_dstack_runner_download_url() commands += [ f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {dstack_runner_binary_path} {url}", f"chmod +x {dstack_runner_binary_path}", - ( - f"{dstack_runner_binary_path}" - " --log-level 6" - " start" - f" --http-port {DSTACK_RUNNER_HTTP_PORT}" - f" --ssh-port {DSTACK_RUNNER_SSH_PORT}" - " --temp-dir /tmp/runner" - " --home-dir /root" - ), + shlex.join(runner_command), ] return commands diff --git a/src/dstack/_internal/core/backends/cloudrift/compute.py b/src/dstack/_internal/core/backends/cloudrift/compute.py index 02e8b09469..d2bd4fc755 100644 --- a/src/dstack/_internal/core/backends/cloudrift/compute.py +++ b/src/dstack/_internal/core/backends/cloudrift/compute.py @@ -73,7 +73,7 @@ def create_instance( instance_config: InstanceConfiguration, placement_group: Optional[PlacementGroup], ) -> JobProvisioningData: - commands = get_shim_commands(authorized_keys=instance_config.get_public_keys()) + commands = get_shim_commands() startup_script = " ".join([" && ".join(commands)]) logger.debug( f"Creating instance for offer {instance_offer.instance.name} in region {instance_offer.region} with commands: {startup_script}" diff --git a/src/dstack/_internal/core/backends/cudo/compute.py b/src/dstack/_internal/core/backends/cudo/compute.py index aca9e53c7d..5edae4ce46 100644 --- a/src/dstack/_internal/core/backends/cudo/compute.py +++ b/src/dstack/_internal/core/backends/cudo/compute.py @@ -75,7 +75,7 @@ def create_instance( commands = install_jq_commands() else: commands = install_docker_commands() - commands += get_shim_commands(authorized_keys=public_keys) + commands += get_shim_commands() try: resp_data = self.api_client.create_virtual_machine( diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 76a394bc7a..c2c18e3d9f 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -304,7 +304,7 @@ def create_instance( ) if is_tpu: instance_id = instance_name - startup_script = _get_tpu_startup_script(authorized_keys) + startup_script = _get_tpu_startup_script() # GCP does not allow attaching disks while TPUs is creating, # so we need to attach the disks on creation. data_disks = _get_tpu_data_disks(self.config.project_id, instance_config.volumes) @@ -1178,10 +1178,8 @@ def _get_volume_price(size: int) -> float: return size * 0.12 -def _get_tpu_startup_script(authorized_keys: List[str]) -> str: - commands = get_shim_commands( - authorized_keys=authorized_keys, is_privileged=True, pjrt_device="TPU" - ) +def _get_tpu_startup_script() -> str: + commands = get_shim_commands(is_privileged=True, pjrt_device="TPU") startup_script = " ".join([" && ".join(commands)]) startup_script = "#! /bin/bash\n" + startup_script return startup_script diff --git a/src/dstack/_internal/core/backends/hotaisle/compute.py b/src/dstack/_internal/core/backends/hotaisle/compute.py index 10013b22a2..4ebfdab288 100644 --- a/src/dstack/_internal/core/backends/hotaisle/compute.py +++ b/src/dstack/_internal/core/backends/hotaisle/compute.py @@ -103,10 +103,7 @@ def update_provisioning_data( if provisioning_data.hostname is None and provisioning_data.backend_data: backend_data = HotAisleInstanceBackendData.load(provisioning_data.backend_data) provisioning_data.hostname = backend_data.ip_address - commands = get_shim_commands( - authorized_keys=[project_ssh_public_key], - arch=provisioning_data.instance_type.resources.cpu_arch, - ) + commands = get_shim_commands(arch=provisioning_data.instance_type.resources.cpu_arch) launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands)) thread = Thread( target=_start_runner, diff --git a/src/dstack/_internal/core/backends/lambdalabs/compute.py b/src/dstack/_internal/core/backends/lambdalabs/compute.py index ae27488347..445a8b3948 100644 --- a/src/dstack/_internal/core/backends/lambdalabs/compute.py +++ b/src/dstack/_internal/core/backends/lambdalabs/compute.py @@ -95,10 +95,7 @@ def update_provisioning_data( instance_info = _get_instance_info(self.api_client, provisioning_data.instance_id) if instance_info is not None and instance_info["status"] != "booting": provisioning_data.hostname = instance_info["ip"] - commands = get_shim_commands( - authorized_keys=[project_ssh_public_key], - arch=provisioning_data.instance_type.resources.cpu_arch, - ) + commands = get_shim_commands(arch=provisioning_data.instance_type.resources.cpu_arch) # shim is assumed to be run under root launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands)) thread = Thread( diff --git a/src/dstack/_internal/core/backends/verda/compute.py b/src/dstack/_internal/core/backends/verda/compute.py index d6dbdd6ae0..4ad995d9ea 100644 --- a/src/dstack/_internal/core/backends/verda/compute.py +++ b/src/dstack/_internal/core/backends/verda/compute.py @@ -112,7 +112,7 @@ def create_instance( ) ) - commands = get_shim_commands(authorized_keys=public_keys) + commands = get_shim_commands() startup_script = " ".join([" && ".join(commands)]) script_name = f"dstack-{instance_config.instance_name}.sh" startup_script_ids = _get_or_create_startup_scrpit( diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 4b45e68b13..9d75c58756 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -503,7 +503,7 @@ def _deploy_instance( logger.debug("The script for installing dstack has been executed") # Upload envs - shim_envs = get_shim_env(authorized_keys, arch=arch) + shim_envs = get_shim_env(arch=arch) try: fleet_configuration_envs = remote_details.env.as_dict() except ValueError as e: