From b34036a6ca25cd378406b113f22af39f6fd432bf Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Tue, 30 Dec 2025 11:38:52 +0000 Subject: [PATCH] [runner] Fix MPI hostfile * Don't set slots on CPU nodes * Move the file to /dstack/mpi and make it world-readable Fixes: https://github.com/dstackai/dstack/issues/3434 Fixes: https://github.com/dstackai/dstack/issues/3436 --- runner/cmd/runner/main.go | 18 ++++++--- runner/cmd/shim/main.go | 4 +- runner/consts/consts.go | 10 ++++- runner/internal/executor/executor.go | 39 +++++++++++------- runner/internal/executor/executor_test.go | 48 ++++++++++++----------- runner/internal/runner/api/server.go | 7 +++- 6 files changed, 77 insertions(+), 49 deletions(-) diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index 7b3bb84680..a080246d41 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -7,6 +7,7 @@ import ( "io" "os" "os/signal" + "path" "path/filepath" "syscall" @@ -121,27 +122,32 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile)) log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel)) + // NB: The Mkdir/Chown/Chmod code below relies on the fact that RunnerDstackDir path is _not_ nested (/dstack). 
+ // Adjust it if the path is changed to, e.g., /opt/dstack + const dstackDir = consts.RunnerDstackDir + dstackSshDir := path.Join(dstackDir, "ssh") + // To ensure that all components of the authorized_keys path are owned by root and no directories // are group or world writable, as required by sshd with "StrictModes yes" (the default value), // we fix `/dstack` ownership and permissions and remove `/dstack/ssh` (it will be (re)created // in Sshd.Prepare()) // See: https://github.com/openssh/openssh-portable/blob/d01efaa1c9ed84fd9011201dbc3c7cb0a82bcee3/misc.c#L2257-L2272 - if err := os.Mkdir("/dstack", 0o755); errors.Is(err, os.ErrExist) { - if err := os.Chown("/dstack", 0, 0); err != nil { + if err := os.Mkdir(dstackDir, 0o755); errors.Is(err, os.ErrExist) { + if err := os.Chown(dstackDir, 0, 0); err != nil { return fmt.Errorf("chown dstack dir: %w", err) } - if err := os.Chmod("/dstack", 0o755); err != nil { + if err := os.Chmod(dstackDir, 0o755); err != nil { return fmt.Errorf("chmod dstack dir: %w", err) } } else if err != nil { return fmt.Errorf("create dstack dir: %w", err) } - if err := os.RemoveAll("/dstack/ssh"); err != nil { + if err := os.RemoveAll(dstackSshDir); err != nil { return fmt.Errorf("remove dstack ssh dir: %w", err) } sshd := ssh.NewSshd("/usr/sbin/sshd") - if err := sshd.Prepare(ctx, "/dstack/ssh", sshPort, "INFO"); err != nil { + if err := sshd.Prepare(ctx, dstackSshDir, sshPort, "INFO"); err != nil { return fmt.Errorf("prepare sshd: %w", err) } if err := sshd.AddAuthorizedKeys(ctx, sshAuthorizedKeys...); err != nil { @@ -156,7 +162,7 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss } }() - server, err := api.NewServer(ctx, tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshd, version) + server, err := api.NewServer(ctx, tempDir, homeDir, dstackDir, sshd, fmt.Sprintf(":%d", httpPort), version) if err != nil { return fmt.Errorf("create server: %w", err) } diff --git a/runner/cmd/shim/main.go 
b/runner/cmd/shim/main.go index 4c3c951df2..644d7e80e8 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -56,7 +56,7 @@ func mainInner() int { Usage: "Set shim's home directory", Destination: &args.Shim.HomeDir, TakesFile: true, - DefaultText: path.Join("~", consts.DstackDirPath), + DefaultText: path.Join("~", consts.DstackUserDir), Sources: cli.EnvVars("DSTACK_SHIM_HOME"), }, &cli.StringFlag{ @@ -187,7 +187,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) if err != nil { return err } - shimHomeDir = filepath.Join(home, consts.DstackDirPath) + shimHomeDir = filepath.Join(home, consts.DstackUserDir) args.Shim.HomeDir = shimHomeDir } diff --git a/runner/consts/consts.go b/runner/consts/consts.go index 2c392b5ee4..4da4a139f7 100644 --- a/runner/consts/consts.go +++ b/runner/consts/consts.go @@ -1,6 +1,7 @@ package consts -const DstackDirPath string = ".dstack" +// A directory inside the user's home directory used for dstack-related files +const DstackUserDir string = ".dstack" // Runner's log filenames const ( @@ -29,6 +30,13 @@ const ( // The current user's homedir (as of 2024-12-28, it's always root) should be used // instead of the hardcoded value RunnerHomeDir = "/root" + // A directory for: + // 1. Files used by the runner and related components (e.g., sshd stores its config and log inside /dstack/ssh) + // 2. Files shared between users (e.g., sshd authorized_keys, MPI hostfile) + // The inner structure should be considered private and subject to change; users should not make assumptions + // about it. + // The only way to access its content/paths should be via public environment variables such as DSTACK_MPI_HOSTFILE.
+ RunnerDstackDir = "/dstack" ) const ( diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 3e486f9704..56a5d1cd9f 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -54,6 +54,7 @@ type ConnectionTracker interface { type RunExecutor struct { tempDir string homeDir string + dstackDir string archiveDir string sshd ssh.SshdManager @@ -91,7 +92,7 @@ func (s *stubConnectionTracker) GetNoConnectionsSecs() int64 { return 0 } func (s *stubConnectionTracker) Track(ticker <-chan time.Time) {} func (s *stubConnectionTracker) Stop() {} -func NewRunExecutor(tempDir string, homeDir string, sshd ssh.SshdManager) (*RunExecutor, error) { +func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.SshdManager) (*RunExecutor, error) { mu := &sync.RWMutex{} timestamp := NewMonotonicTimestamp() user, err := osuser.Current() @@ -124,6 +125,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshd ssh.SshdManager) (*RunE return &RunExecutor{ tempDir: tempDir, homeDir: homeDir, + dstackDir: dstackDir, archiveDir: filepath.Join(tempDir, "file_archives"), sshd: sshd, currentUid: uid, @@ -384,12 +386,12 @@ func (ex *RunExecutor) getRepoData() schemas.RepoData { } func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error { - node_rank := ex.jobSpec.JobNum - nodes_num := ex.jobSpec.JobsPerReplica - gpus_per_node_num := ex.clusterInfo.GPUSPerJob - gpus_num := nodes_num * gpus_per_node_num + nodeRank := ex.jobSpec.JobNum + nodesNum := ex.jobSpec.JobsPerReplica + gpusPerNodeNum := ex.clusterInfo.GPUSPerJob + gpusNum := nodesNum * gpusPerNodeNum - mpiHostfilePath := filepath.Join(ex.homeDir, ".dstack/mpi/hostfile") + mpiHostfilePath := filepath.Join(ex.dstackDir, "mpi/hostfile") jobEnvs := map[string]string{ "DSTACK_RUN_ID": ex.run.Id, @@ -400,10 +402,10 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error "DSTACK_WORKING_DIR": 
ex.jobWorkingDir, "DSTACK_NODES_IPS": strings.Join(ex.clusterInfo.JobIPs, "\n"), "DSTACK_MASTER_NODE_IP": ex.clusterInfo.MasterJobIP, - "DSTACK_NODE_RANK": strconv.Itoa(node_rank), - "DSTACK_NODES_NUM": strconv.Itoa(nodes_num), - "DSTACK_GPUS_PER_NODE": strconv.Itoa(gpus_per_node_num), - "DSTACK_GPUS_NUM": strconv.Itoa(gpus_num), + "DSTACK_NODE_RANK": strconv.Itoa(nodeRank), + "DSTACK_NODES_NUM": strconv.Itoa(nodesNum), + "DSTACK_GPUS_PER_NODE": strconv.Itoa(gpusPerNodeNum), + "DSTACK_GPUS_NUM": strconv.Itoa(gpusNum), "DSTACK_MPI_HOSTFILE": mpiHostfilePath, } @@ -460,7 +462,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error envMap.Update(ex.jobSpec.Env, false) const profilePath = "/etc/profile" - const dstackProfilePath = "/dstack/profile" + dstackProfilePath := path.Join(ex.dstackDir, "profile") if err := writeDstackProfile(envMap, dstackProfilePath); err != nil { log.Warning(ctx, "failed to write dstack_profile", "path", dstackProfilePath, "err", err) } else if err := includeDstackProfile(profilePath, dstackProfilePath); err != nil { @@ -508,7 +510,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error } } - err = writeMpiHostfile(ctx, ex.clusterInfo.JobIPs, gpus_per_node_num, mpiHostfilePath) + err = writeMpiHostfile(ctx, ex.clusterInfo.JobIPs, gpusPerNodeNum, mpiHostfilePath) if err != nil { return fmt.Errorf("write MPI hostfile: %w", err) } @@ -839,7 +841,7 @@ func prepareSSHDir(uid int, gid int, homeDir string) (string, error) { return sshDir, nil } -func writeMpiHostfile(ctx context.Context, ips []string, gpus_per_node int, path string) error { +func writeMpiHostfile(ctx context.Context, ips []string, gpusPerNode int, path string) error { if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return fmt.Errorf("create MPI hostfile directory: %w", err) } @@ -855,9 +857,16 @@ func writeMpiHostfile(ctx context.Context, ips []string, gpus_per_node int, path } } if len(nonEmptyIps) == 
len(ips) { + var template string + if gpusPerNode == 0 { + // CPU node: the number of slots defaults to the number of processor cores on that host + // See: https://docs.open-mpi.org/en/main/launching-apps/scheduling.html#calculating-the-number-of-slots + template = "%s\n" + } else { + template = fmt.Sprintf("%%s slots=%d\n", gpusPerNode) + } for _, ip := range nonEmptyIps { - line := fmt.Sprintf("%s slots=%d\n", ip, gpus_per_node) - if _, err = file.WriteString(line); err != nil { + if _, err = fmt.Fprintf(file, template, ip); err != nil { return fmt.Errorf("write MPI hostfile line: %w", err) } } diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go index cc5cae7b38..e3661fac0e 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -28,13 +28,13 @@ func TestExecutor_WorkingDir_Set(t *testing.T) { ex.jobSpec.WorkingDir = &workingDir ex.jobSpec.Commands = append(ex.jobSpec.Commands, "pwd") - err = ex.setJobWorkingDir(context.TODO()) + err = ex.setJobWorkingDir(t.Context()) require.NoError(t, err) require.Equal(t, workingDir, ex.jobWorkingDir) err = os.MkdirAll(workingDir, 0o755) require.NoError(t, err) - err = ex.execJob(context.TODO(), io.Writer(&b)) + err = ex.execJob(t.Context(), io.Writer(&b)) assert.NoError(t, err) // Normalize line endings for cross-platform compatibility. 
assert.Equal(t, workingDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) @@ -47,11 +47,11 @@ func TestExecutor_WorkingDir_NotSet(t *testing.T) { require.NoError(t, err) ex.jobSpec.WorkingDir = nil ex.jobSpec.Commands = append(ex.jobSpec.Commands, "pwd") - err = ex.setJobWorkingDir(context.TODO()) + err = ex.setJobWorkingDir(t.Context()) require.NoError(t, err) require.Equal(t, cwd, ex.jobWorkingDir) - err = ex.execJob(context.TODO(), io.Writer(&b)) + err = ex.execJob(t.Context(), io.Writer(&b)) assert.NoError(t, err) assert.Equal(t, cwd+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } @@ -61,7 +61,7 @@ func TestExecutor_HomeDir(t *testing.T) { ex := makeTestExecutor(t) ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo ~") - err := ex.execJob(context.TODO(), io.Writer(&b)) + err := ex.execJob(t.Context(), io.Writer(&b)) assert.NoError(t, err) assert.Equal(t, ex.homeDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } @@ -71,7 +71,7 @@ func TestExecutor_NonZeroExit(t *testing.T) { ex.jobSpec.Commands = append(ex.jobSpec.Commands, "exit 100") makeCodeTar(t, ex.codePath) - err := ex.Run(context.TODO()) + err := ex.Run(t.Context()) assert.Error(t, err) assert.NotEmpty(t, ex.jobStateHistory) exitStatus := ex.jobStateHistory[len(ex.jobStateHistory)-1].ExitStatus @@ -90,11 +90,11 @@ func TestExecutor_SSHCredentials(t *testing.T) { PrivateKey: &key, } - clean, err := ex.setupCredentials(context.TODO()) + clean, err := ex.setupCredentials(t.Context()) defer clean() require.NoError(t, err) - err = ex.execJob(context.TODO(), io.Writer(&b)) + err = ex.execJob(t.Context(), io.Writer(&b)) assert.NoError(t, err) assert.Equal(t, key, b.String()) } @@ -106,10 +106,10 @@ func TestExecutor_LocalRepo(t *testing.T) { ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd) makeCodeTar(t, ex.codePath) - err := ex.setupRepo(context.TODO()) + err := ex.setupRepo(t.Context()) require.NoError(t, err) - err = ex.execJob(context.TODO(), io.Writer(&b)) + err = 
ex.execJob(t.Context(), io.Writer(&b)) assert.NoError(t, err) assert.Equal(t, "bar\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } @@ -119,7 +119,7 @@ func TestExecutor_Recover(t *testing.T) { ex.jobSpec.Commands = nil // cause a panic makeCodeTar(t, ex.codePath) - err := ex.Run(context.TODO()) + err := ex.Run(t.Context()) assert.ErrorContains(t, err, "recovered: ") } @@ -136,7 +136,7 @@ func TestExecutor_MaxDuration(t *testing.T) { ex.jobSpec.MaxDuration = 1 // seconds makeCodeTar(t, ex.codePath) - err := ex.Run(context.TODO()) + err := ex.Run(t.Context()) assert.ErrorContains(t, err, "killed") } @@ -158,12 +158,12 @@ func TestExecutor_RemoteRepo(t *testing.T) { err := os.WriteFile(ex.codePath, []byte{}, 0o600) // empty diff require.NoError(t, err) - err = ex.setJobWorkingDir(context.TODO()) + err = ex.setJobWorkingDir(t.Context()) require.NoError(t, err) - err = ex.setupRepo(context.TODO()) + err = ex.setupRepo(t.Context()) require.NoError(t, err) - err = ex.execJob(context.TODO(), io.Writer(&b)) + err = ex.execJob(t.Context(), io.Writer(&b)) assert.NoError(t, err) expected := fmt.Sprintf("%s\n%s\n%s\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail) assert.Equal(t, expected, strings.ReplaceAll(b.String(), "\r\n", "\n")) @@ -204,11 +204,13 @@ func makeTestExecutor(t *testing.T) *RunExecutor { }, } - temp := filepath.Join(baseDir, "temp") - _ = os.Mkdir(temp, 0o700) - home := filepath.Join(baseDir, "home") - _ = os.Mkdir(home, 0o700) - ex, _ := NewRunExecutor(temp, home, new(sshdMock)) + tempDir := filepath.Join(baseDir, "temp") + require.NoError(t, os.Mkdir(tempDir, 0o700)) + homeDir := filepath.Join(baseDir, "home") + require.NoError(t, os.Mkdir(homeDir, 0o700)) + dstackDir := filepath.Join(baseDir, "dstack") + require.NoError(t, os.Mkdir(dstackDir, 0o755)) + ex, _ := NewRunExecutor(tempDir, homeDir, dstackDir, new(sshdMock)) ex.SetJob(body) ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file 
before run ex.setJobWorkingDir(context.Background()) @@ -261,7 +263,7 @@ func TestExecutor_Logs(t *testing.T) { // \033[31m = red text, \033[1;32m = bold green text, \033[0m = reset ex.jobSpec.Commands = append(ex.jobSpec.Commands, "printf '\\033[31mRed Hello World\\033[0m\\n' && printf '\\033[1;32mBold Green Line 2\\033[0m\\n' && printf 'Line 3\\n'") - err := ex.execJob(context.TODO(), io.Writer(&b)) + err := ex.execJob(t.Context(), io.Writer(&b)) assert.NoError(t, err) logHistory := ex.GetHistory(0).JobLogs @@ -285,7 +287,7 @@ func TestExecutor_LogsWithErrors(t *testing.T) { ex := makeTestExecutor(t) ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo 'Success message' && echo 'Error message' >&2 && exit 1") - err := ex.execJob(context.TODO(), io.Writer(&b)) + err := ex.execJob(t.Context(), io.Writer(&b)) assert.Error(t, err) logHistory := ex.GetHistory(0).JobLogs @@ -309,7 +311,7 @@ func TestExecutor_LogsAnsiCodeHandling(t *testing.T) { ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd) - err := ex.execJob(context.TODO(), io.Writer(&b)) + err := ex.execJob(t.Context(), io.Writer(&b)) assert.NoError(t, err) // 1. Check WebSocket logs, which should preserve ANSI codes. diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 2e8a526273..0a0b851a9f 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -34,9 +34,12 @@ type Server struct { version string } -func NewServer(ctx context.Context, tempDir string, homeDir string, address string, sshd ssh.SshdManager, version string) (*Server, error) { +func NewServer( + ctx context.Context, tempDir string, homeDir string, dstackDir string, sshd ssh.SshdManager, + address string, version string, +) (*Server, error) { r := api.NewRouter() - ex, err := executor.NewRunExecutor(tempDir, homeDir, sshd) + ex, err := executor.NewRunExecutor(tempDir, homeDir, dstackDir, sshd) if err != nil { return nil, err }