Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions runner/cmd/runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"os"
"os/signal"
"path"
"path/filepath"
"syscall"

Expand Down Expand Up @@ -121,27 +122,32 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile))
log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel))

// NB: The Mkdir/Chown/Chmod code below relies on the fact that RunnerDstackDir path is _not_ nested (/dstack).
// Adjust it if the path is changed to, e.g., /opt/dstack
const dstackDir = consts.RunnerDstackDir
dstackSshDir := path.Join(dstackDir, "ssh")

// To ensure that all components of the authorized_keys path are owned by root and no directories
// are group or world writable, as required by sshd with "StrictModes yes" (the default value),
// we fix `/dstack` ownership and permissions and remove `/dstack/ssh` (it will be (re)created
// in Sshd.Prepare())
// See: https://github.com/openssh/openssh-portable/blob/d01efaa1c9ed84fd9011201dbc3c7cb0a82bcee3/misc.c#L2257-L2272
if err := os.Mkdir("/dstack", 0o755); errors.Is(err, os.ErrExist) {
if err := os.Chown("/dstack", 0, 0); err != nil {
if err := os.Mkdir(dstackDir, 0o755); errors.Is(err, os.ErrExist) {
if err := os.Chown(dstackDir, 0, 0); err != nil {
return fmt.Errorf("chown dstack dir: %w", err)
}
if err := os.Chmod("/dstack", 0o755); err != nil {
if err := os.Chmod(dstackDir, 0o755); err != nil {
return fmt.Errorf("chmod dstack dir: %w", err)
}
} else if err != nil {
return fmt.Errorf("create dstack dir: %w", err)
}
if err := os.RemoveAll("/dstack/ssh"); err != nil {
if err := os.RemoveAll(dstackSshDir); err != nil {
return fmt.Errorf("remove dstack ssh dir: %w", err)
}

sshd := ssh.NewSshd("/usr/sbin/sshd")
if err := sshd.Prepare(ctx, "/dstack/ssh", sshPort, "INFO"); err != nil {
if err := sshd.Prepare(ctx, dstackSshDir, sshPort, "INFO"); err != nil {
return fmt.Errorf("prepare sshd: %w", err)
}
if err := sshd.AddAuthorizedKeys(ctx, sshAuthorizedKeys...); err != nil {
Expand All @@ -156,7 +162,7 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
}
}()

server, err := api.NewServer(ctx, tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshd, version)
server, err := api.NewServer(ctx, tempDir, homeDir, dstackDir, sshd, fmt.Sprintf(":%d", httpPort), version)
if err != nil {
return fmt.Errorf("create server: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions runner/cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func mainInner() int {
Usage: "Set shim's home directory",
Destination: &args.Shim.HomeDir,
TakesFile: true,
DefaultText: path.Join("~", consts.DstackDirPath),
DefaultText: path.Join("~", consts.DstackUserDir),
Sources: cli.EnvVars("DSTACK_SHIM_HOME"),
},
&cli.StringFlag{
Expand Down Expand Up @@ -187,7 +187,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
if err != nil {
return err
}
shimHomeDir = filepath.Join(home, consts.DstackDirPath)
shimHomeDir = filepath.Join(home, consts.DstackUserDir)
args.Shim.HomeDir = shimHomeDir
}

Expand Down
10 changes: 9 additions & 1 deletion runner/consts/consts.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package consts

const DstackDirPath string = ".dstack"
// A directory inside user's home used for dstack-related files
const DstackUserDir string = ".dstack"

// Runner's log filenames
const (
Expand Down Expand Up @@ -29,6 +30,13 @@ const (
// The current user's homedir (as of 2024-12-28, it's always root) should be used
// instead of the hardcoded value
RunnerHomeDir = "/root"
// A directory for:
// 1. Files used by the runner and related components (e.g., sshd stores its config and log inside /dstack/ssh)
// 2. Files shared between users (e.g., sshd authorized_keys, MPI hostfile)
// The inner structure should be considered private and subject to change, the users should not make assumptions
// about its structure.
// The only way to access its content/paths should be via public environment variables such as DSTACK_MPI_HOSTFILE.
RunnerDstackDir = "/dstack"
)

const (
Expand Down
39 changes: 24 additions & 15 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type ConnectionTracker interface {
type RunExecutor struct {
tempDir string
homeDir string
dstackDir string
archiveDir string
sshd ssh.SshdManager

Expand Down Expand Up @@ -91,7 +92,7 @@ func (s *stubConnectionTracker) GetNoConnectionsSecs() int64 { return 0 }
func (s *stubConnectionTracker) Track(ticker <-chan time.Time) {}
func (s *stubConnectionTracker) Stop() {}

func NewRunExecutor(tempDir string, homeDir string, sshd ssh.SshdManager) (*RunExecutor, error) {
func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.SshdManager) (*RunExecutor, error) {
mu := &sync.RWMutex{}
timestamp := NewMonotonicTimestamp()
user, err := osuser.Current()
Expand Down Expand Up @@ -124,6 +125,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshd ssh.SshdManager) (*RunE
return &RunExecutor{
tempDir: tempDir,
homeDir: homeDir,
dstackDir: dstackDir,
archiveDir: filepath.Join(tempDir, "file_archives"),
sshd: sshd,
currentUid: uid,
Expand Down Expand Up @@ -384,12 +386,12 @@ func (ex *RunExecutor) getRepoData() schemas.RepoData {
}

func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error {
node_rank := ex.jobSpec.JobNum
nodes_num := ex.jobSpec.JobsPerReplica
gpus_per_node_num := ex.clusterInfo.GPUSPerJob
gpus_num := nodes_num * gpus_per_node_num
nodeRank := ex.jobSpec.JobNum
nodesNum := ex.jobSpec.JobsPerReplica
gpusPerNodeNum := ex.clusterInfo.GPUSPerJob
gpusNum := nodesNum * gpusPerNodeNum

mpiHostfilePath := filepath.Join(ex.homeDir, ".dstack/mpi/hostfile")
mpiHostfilePath := filepath.Join(ex.dstackDir, "mpi/hostfile")

jobEnvs := map[string]string{
"DSTACK_RUN_ID": ex.run.Id,
Expand All @@ -400,10 +402,10 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
"DSTACK_WORKING_DIR": ex.jobWorkingDir,
"DSTACK_NODES_IPS": strings.Join(ex.clusterInfo.JobIPs, "\n"),
"DSTACK_MASTER_NODE_IP": ex.clusterInfo.MasterJobIP,
"DSTACK_NODE_RANK": strconv.Itoa(node_rank),
"DSTACK_NODES_NUM": strconv.Itoa(nodes_num),
"DSTACK_GPUS_PER_NODE": strconv.Itoa(gpus_per_node_num),
"DSTACK_GPUS_NUM": strconv.Itoa(gpus_num),
"DSTACK_NODE_RANK": strconv.Itoa(nodeRank),
"DSTACK_NODES_NUM": strconv.Itoa(nodesNum),
"DSTACK_GPUS_PER_NODE": strconv.Itoa(gpusPerNodeNum),
"DSTACK_GPUS_NUM": strconv.Itoa(gpusNum),
"DSTACK_MPI_HOSTFILE": mpiHostfilePath,
}

Expand Down Expand Up @@ -460,7 +462,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
envMap.Update(ex.jobSpec.Env, false)

const profilePath = "/etc/profile"
const dstackProfilePath = "/dstack/profile"
dstackProfilePath := path.Join(ex.dstackDir, "profile")
if err := writeDstackProfile(envMap, dstackProfilePath); err != nil {
log.Warning(ctx, "failed to write dstack_profile", "path", dstackProfilePath, "err", err)
} else if err := includeDstackProfile(profilePath, dstackProfilePath); err != nil {
Expand Down Expand Up @@ -508,7 +510,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
}
}

err = writeMpiHostfile(ctx, ex.clusterInfo.JobIPs, gpus_per_node_num, mpiHostfilePath)
err = writeMpiHostfile(ctx, ex.clusterInfo.JobIPs, gpusPerNodeNum, mpiHostfilePath)
if err != nil {
return fmt.Errorf("write MPI hostfile: %w", err)
}
Expand Down Expand Up @@ -839,7 +841,7 @@ func prepareSSHDir(uid int, gid int, homeDir string) (string, error) {
return sshDir, nil
}

func writeMpiHostfile(ctx context.Context, ips []string, gpus_per_node int, path string) error {
func writeMpiHostfile(ctx context.Context, ips []string, gpusPerNode int, path string) error {
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return fmt.Errorf("create MPI hostfile directory: %w", err)
}
Expand All @@ -855,9 +857,16 @@ func writeMpiHostfile(ctx context.Context, ips []string, gpus_per_node int, path
}
}
if len(nonEmptyIps) == len(ips) {
var template string
if gpusPerNode == 0 {
// CPU node: the number of slots defaults to the number of processor cores on that host
// See: https://docs.open-mpi.org/en/main/launching-apps/scheduling.html#calculating-the-number-of-slots
template = "%s\n"
} else {
template = fmt.Sprintf("%%s slots=%d\n", gpusPerNode)
}
for _, ip := range nonEmptyIps {
line := fmt.Sprintf("%s slots=%d\n", ip, gpus_per_node)
if _, err = file.WriteString(line); err != nil {
if _, err = fmt.Fprintf(file, template, ip); err != nil {
return fmt.Errorf("write MPI hostfile line: %w", err)
}
}
Expand Down
48 changes: 25 additions & 23 deletions runner/internal/executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ func TestExecutor_WorkingDir_Set(t *testing.T) {

ex.jobSpec.WorkingDir = &workingDir
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "pwd")
err = ex.setJobWorkingDir(context.TODO())
err = ex.setJobWorkingDir(t.Context())
require.NoError(t, err)
require.Equal(t, workingDir, ex.jobWorkingDir)
err = os.MkdirAll(workingDir, 0o755)
require.NoError(t, err)

err = ex.execJob(context.TODO(), io.Writer(&b))
err = ex.execJob(t.Context(), io.Writer(&b))
assert.NoError(t, err)
// Normalize line endings for cross-platform compatibility.
assert.Equal(t, workingDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
Expand All @@ -47,11 +47,11 @@ func TestExecutor_WorkingDir_NotSet(t *testing.T) {
require.NoError(t, err)
ex.jobSpec.WorkingDir = nil
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "pwd")
err = ex.setJobWorkingDir(context.TODO())
err = ex.setJobWorkingDir(t.Context())
require.NoError(t, err)
require.Equal(t, cwd, ex.jobWorkingDir)

err = ex.execJob(context.TODO(), io.Writer(&b))
err = ex.execJob(t.Context(), io.Writer(&b))
assert.NoError(t, err)
assert.Equal(t, cwd+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
}
Expand All @@ -61,7 +61,7 @@ func TestExecutor_HomeDir(t *testing.T) {
ex := makeTestExecutor(t)
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo ~")

err := ex.execJob(context.TODO(), io.Writer(&b))
err := ex.execJob(t.Context(), io.Writer(&b))
assert.NoError(t, err)
assert.Equal(t, ex.homeDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
}
Expand All @@ -71,7 +71,7 @@ func TestExecutor_NonZeroExit(t *testing.T) {
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "exit 100")
makeCodeTar(t, ex.codePath)

err := ex.Run(context.TODO())
err := ex.Run(t.Context())
assert.Error(t, err)
assert.NotEmpty(t, ex.jobStateHistory)
exitStatus := ex.jobStateHistory[len(ex.jobStateHistory)-1].ExitStatus
Expand All @@ -90,11 +90,11 @@ func TestExecutor_SSHCredentials(t *testing.T) {
PrivateKey: &key,
}

clean, err := ex.setupCredentials(context.TODO())
clean, err := ex.setupCredentials(t.Context())
defer clean()
require.NoError(t, err)

err = ex.execJob(context.TODO(), io.Writer(&b))
err = ex.execJob(t.Context(), io.Writer(&b))
assert.NoError(t, err)
assert.Equal(t, key, b.String())
}
Expand All @@ -106,10 +106,10 @@ func TestExecutor_LocalRepo(t *testing.T) {
ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd)
makeCodeTar(t, ex.codePath)

err := ex.setupRepo(context.TODO())
err := ex.setupRepo(t.Context())
require.NoError(t, err)

err = ex.execJob(context.TODO(), io.Writer(&b))
err = ex.execJob(t.Context(), io.Writer(&b))
assert.NoError(t, err)
assert.Equal(t, "bar\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
}
Expand All @@ -119,7 +119,7 @@ func TestExecutor_Recover(t *testing.T) {
ex.jobSpec.Commands = nil // cause a panic
makeCodeTar(t, ex.codePath)

err := ex.Run(context.TODO())
err := ex.Run(t.Context())
assert.ErrorContains(t, err, "recovered: ")
}

Expand All @@ -136,7 +136,7 @@ func TestExecutor_MaxDuration(t *testing.T) {
ex.jobSpec.MaxDuration = 1 // seconds
makeCodeTar(t, ex.codePath)

err := ex.Run(context.TODO())
err := ex.Run(t.Context())
assert.ErrorContains(t, err, "killed")
}

Expand All @@ -158,12 +158,12 @@ func TestExecutor_RemoteRepo(t *testing.T) {
err := os.WriteFile(ex.codePath, []byte{}, 0o600) // empty diff
require.NoError(t, err)

err = ex.setJobWorkingDir(context.TODO())
err = ex.setJobWorkingDir(t.Context())
require.NoError(t, err)
err = ex.setupRepo(context.TODO())
err = ex.setupRepo(t.Context())
require.NoError(t, err)

err = ex.execJob(context.TODO(), io.Writer(&b))
err = ex.execJob(t.Context(), io.Writer(&b))
assert.NoError(t, err)
expected := fmt.Sprintf("%s\n%s\n%s\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail)
assert.Equal(t, expected, strings.ReplaceAll(b.String(), "\r\n", "\n"))
Expand Down Expand Up @@ -204,11 +204,13 @@ func makeTestExecutor(t *testing.T) *RunExecutor {
},
}

temp := filepath.Join(baseDir, "temp")
_ = os.Mkdir(temp, 0o700)
home := filepath.Join(baseDir, "home")
_ = os.Mkdir(home, 0o700)
ex, _ := NewRunExecutor(temp, home, new(sshdMock))
tempDir := filepath.Join(baseDir, "temp")
require.NoError(t, os.Mkdir(tempDir, 0o700))
homeDir := filepath.Join(baseDir, "home")
require.NoError(t, os.Mkdir(homeDir, 0o700))
dstackDir := filepath.Join(baseDir, "dstack")
require.NoError(t, os.Mkdir(dstackDir, 0o755))
ex, _ := NewRunExecutor(tempDir, homeDir, dstackDir, new(sshdMock))
ex.SetJob(body)
ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file before run
ex.setJobWorkingDir(context.Background())
Expand Down Expand Up @@ -261,7 +263,7 @@ func TestExecutor_Logs(t *testing.T) {
// \033[31m = red text, \033[1;32m = bold green text, \033[0m = reset
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "printf '\\033[31mRed Hello World\\033[0m\\n' && printf '\\033[1;32mBold Green Line 2\\033[0m\\n' && printf 'Line 3\\n'")

err := ex.execJob(context.TODO(), io.Writer(&b))
err := ex.execJob(t.Context(), io.Writer(&b))
assert.NoError(t, err)

logHistory := ex.GetHistory(0).JobLogs
Expand All @@ -285,7 +287,7 @@ func TestExecutor_LogsWithErrors(t *testing.T) {
ex := makeTestExecutor(t)
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo 'Success message' && echo 'Error message' >&2 && exit 1")

err := ex.execJob(context.TODO(), io.Writer(&b))
err := ex.execJob(t.Context(), io.Writer(&b))
assert.Error(t, err)

logHistory := ex.GetHistory(0).JobLogs
Expand All @@ -309,7 +311,7 @@ func TestExecutor_LogsAnsiCodeHandling(t *testing.T) {

ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd)

err := ex.execJob(context.TODO(), io.Writer(&b))
err := ex.execJob(t.Context(), io.Writer(&b))
assert.NoError(t, err)

// 1. Check WebSocket logs, which should preserve ANSI codes.
Expand Down
7 changes: 5 additions & 2 deletions runner/internal/runner/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ type Server struct {
version string
}

func NewServer(ctx context.Context, tempDir string, homeDir string, address string, sshd ssh.SshdManager, version string) (*Server, error) {
func NewServer(
ctx context.Context, tempDir string, homeDir string, dstackDir string, sshd ssh.SshdManager,
address string, version string,
) (*Server, error) {
r := api.NewRouter()
ex, err := executor.NewRunExecutor(tempDir, homeDir, sshd)
ex, err := executor.NewRunExecutor(tempDir, homeDir, dstackDir, sshd)
if err != nil {
return nil, err
}
Expand Down