Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
hooks:
- id: golangci-lint-full
language_version: 1.25.0 # Should match runner/go.mod
entry: bash -c 'cd runner && golangci-lint run --fix'
entry: bash -c 'cd runner && golangci-lint run'
stages: [manual]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
Expand Down
101 changes: 64 additions & 37 deletions runner/cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,157 +5,169 @@ import (
"errors"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"path"
"path/filepath"
"syscall"
"time"

"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v3"

"github.com/dstackai/dstack/runner/consts"
"github.com/dstackai/dstack/runner/internal/common"
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/shim"
"github.com/dstackai/dstack/runner/internal/shim/api"
"github.com/dstackai/dstack/runner/internal/shim/components"
"github.com/dstackai/dstack/runner/internal/shim/dcgm"
)

// Version is a build-time variable. The value is overridden by ldflags.
var Version string

// main is the process entry point. All real work happens in mainInner so that
// its deferred calls run to completion before os.Exit terminates the process
// (os.Exit does not run deferred functions).
func main() {
	exitCode := mainInner()
	os.Exit(exitCode)
}

// mainInner sets up the default logger, declares the dstack-shim command line
// (every flag writes into a field of args or into serviceMode and can also be
// supplied through an environment variable), runs the command, and returns the
// process exit status: 1 when the run reports an error, 0 otherwise.
//
// NOTE(review): this listing is a rendered diff without +/- markers, so lines
// for the urfave/cli v2 API (cli.App, cli.PathFlag, EnvVars, cli.Context) sit
// directly next to their v3 replacements (cli.Command, cli.StringFlag,
// Sources, context.Context + *cli.Command). Only one of each pair belongs to
// either version of the file.
func mainInner() int {
var args shim.CLIArgs
var serviceMode bool

const defaultLogLevel = int(logrus.InfoLevel)

ctx := context.Background()

// Default logger configuration: info level, written to stderr.
log.DefaultEntry.Logger.SetLevel(logrus.Level(defaultLogLevel))
log.DefaultEntry.Logger.SetOutput(os.Stderr)

app := &cli.App{
cmd := &cli.Command{
Name: "dstack-shim",
Usage: "Starts dstack-runner or docker container.",
Version: Version,
Flags: []cli.Flag{
/* Shim Parameters */
&cli.PathFlag{
&cli.StringFlag{
Name: "shim-home",
Usage: "Set shim's home directory",
Destination: &args.Shim.HomeDir,
TakesFile: true,
// DefaultText is set without a Value: no default is assigned to
// args.Shim.HomeDir by the flag itself.
DefaultText: path.Join("~", consts.DstackDirPath),
EnvVars: []string{"DSTACK_SHIM_HOME"},
Sources: cli.EnvVars("DSTACK_SHIM_HOME"),
},
&cli.IntFlag{
Name: "shim-http-port",
Usage: "Set shim's http port",
Value: 10998,
Destination: &args.Shim.HTTPPort,
EnvVars: []string{"DSTACK_SHIM_HTTP_PORT"},
Sources: cli.EnvVars("DSTACK_SHIM_HTTP_PORT"),
},
&cli.IntFlag{
Name: "shim-log-level",
Usage: "Set shim's log level",
Value: defaultLogLevel,
Destination: &args.Shim.LogLevel,
EnvVars: []string{"DSTACK_SHIM_LOG_LEVEL"},
Sources: cli.EnvVars("DSTACK_SHIM_LOG_LEVEL"),
},
/* Runner Parameters */
&cli.StringFlag{
Name: "runner-download-url",
Usage: "Set runner's download URL",
Destination: &args.Runner.DownloadURL,
EnvVars: []string{"DSTACK_RUNNER_DOWNLOAD_URL"},
Sources: cli.EnvVars("DSTACK_RUNNER_DOWNLOAD_URL"),
},
&cli.PathFlag{
&cli.StringFlag{
Name: "runner-binary-path",
Usage: "Path to runner's binary",
Value: consts.RunnerBinaryPath,
Destination: &args.Runner.BinaryPath,
EnvVars: []string{"DSTACK_RUNNER_BINARY_PATH"},
TakesFile: true,
Sources: cli.EnvVars("DSTACK_RUNNER_BINARY_PATH"),
},
&cli.IntFlag{
Name: "runner-http-port",
Usage: "Set runner's http port",
Value: consts.RunnerHTTPPort,
Destination: &args.Runner.HTTPPort,
EnvVars: []string{"DSTACK_RUNNER_HTTP_PORT"},
Sources: cli.EnvVars("DSTACK_RUNNER_HTTP_PORT"),
},
&cli.IntFlag{
Name: "runner-ssh-port",
Usage: "Set runner's ssh port",
Value: consts.RunnerSSHPort,
Destination: &args.Runner.SSHPort,
EnvVars: []string{"DSTACK_RUNNER_SSH_PORT"},
Sources: cli.EnvVars("DSTACK_RUNNER_SSH_PORT"),
},
&cli.IntFlag{
Name: "runner-log-level",
Usage: "Set runner's log level",
Value: defaultLogLevel,
Destination: &args.Runner.LogLevel,
EnvVars: []string{"DSTACK_RUNNER_LOG_LEVEL"},
Sources: cli.EnvVars("DSTACK_RUNNER_LOG_LEVEL"),
},
/* DCGM Exporter Parameters */
&cli.IntFlag{
Name: "dcgm-exporter-http-port",
Usage: "DCGM Exporter http port",
Value: 10997,
Destination: &args.DCGMExporter.HTTPPort,
EnvVars: []string{"DSTACK_DCGM_EXPORTER_HTTP_PORT"},
Sources: cli.EnvVars("DSTACK_DCGM_EXPORTER_HTTP_PORT"),
},
&cli.IntFlag{
Name: "dcgm-exporter-interval",
Usage: "DCGM Exporter collect interval, milliseconds",
Value: 5000,
Destination: &args.DCGMExporter.Interval,
EnvVars: []string{"DSTACK_DCGM_EXPORTER_INTERVAL"},
Sources: cli.EnvVars("DSTACK_DCGM_EXPORTER_INTERVAL"),
},
/* DCGM Parameters */
&cli.StringFlag{
Name: "dcgm-address",
Usage: "nv-hostengine `hostname`, e.g., `localhost`",
DefaultText: "start libdcgm in embedded mode",
Destination: &args.DCGM.Address,
EnvVars: []string{"DSTACK_DCGM_ADDRESS"},
Sources: cli.EnvVars("DSTACK_DCGM_ADDRESS"),
},
/* Docker Parameters */
&cli.BoolFlag{
Name: "privileged",
Usage: "Give extended privileges to the container",
Destination: &args.Docker.Privileged,
EnvVars: []string{"DSTACK_DOCKER_PRIVILEGED"},
Sources: cli.EnvVars("DSTACK_DOCKER_PRIVILEGED"),
},
&cli.StringFlag{
Name: "ssh-key",
Usage: "Public SSH key",
Destination: &args.Docker.ConcatinatedPublicSSHKeys,
EnvVars: []string{"DSTACK_PUBLIC_SSH_KEY"},
Sources: cli.EnvVars("DSTACK_PUBLIC_SSH_KEY"),
},
&cli.StringFlag{
Name: "pjrt-device",
Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)",
Destination: &args.Docker.PJRTDevice,
// Unlike the other flags, this one reads an environment variable
// without the DSTACK_ prefix.
EnvVars: []string{"PJRT_DEVICE"},
Sources: cli.EnvVars("PJRT_DEVICE"),
},
/* Misc Parameters */
&cli.BoolFlag{
Name: "service",
Usage: "Start as a service",
Destination: &serviceMode,
EnvVars: []string{"DSTACK_SERVICE_MODE"},
Sources: cli.EnvVars("DSTACK_SERVICE_MODE"),
},
},
// The action ignores the command object and hands the parsed args and the
// context it receives straight to start.
Action: func(c *cli.Context) error {
Action: func(ctx context.Context, cmd *cli.Command) error {
return start(ctx, args, serviceMode)
},
}

if err := app.Run(os.Args); err != nil {
log.Fatal(ctx, err.Error())
// The context given to cmd.Run is canceled on os.Interrupt or SIGTERM; stop
// releases the signal registration when mainInner returns.
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

if err := cmd.Run(ctx, os.Args); err != nil {
log.Error(ctx, err.Error())
// Returning a status (rather than exiting here) lets the deferred stop()
// run before main calls os.Exit.
return 1
}

return 0
}

// start prepares the runner binary, builds the shim API server, serves until
// either the server reports an error or ctx is canceled, and then shuts the
// server down. It returns the serve error if there is one, otherwise the
// shutdown error.
//
// NOTE(review): parts of this function are collapsed in this view (the
// "Expand Down / Expand Up" marker lines), and removed and added diff lines
// appear next to each other; only the visible steps are described here.
func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) {
Expand Down Expand Up @@ -191,8 +203,13 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
}
}()

if err := args.DownloadRunner(ctx); err != nil {
return err
// NOTE(review): when a download URL is set, runnerErr is never inspected and
// runnerManager.Install is called regardless. That is only safe if
// NewRunnerManager returns a usable manager together with a non-nil error
// (e.g. "binary not present yet") — confirm against its implementation.
runnerManager, runnerErr := components.NewRunnerManager(ctx, args.Runner.BinaryPath)
if args.Runner.DownloadURL != "" {
if err := runnerManager.Install(ctx, args.Runner.DownloadURL, false); err != nil {
return err
}
} else if runnerErr != nil {
return runnerErr
}

log.Debug(ctx, "Shim", "args", args.Shim)
Expand Down Expand Up @@ -242,13 +259,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
}

// The API server is addressed on localhost at the configured shim HTTP port.
address := fmt.Sprintf("localhost:%d", args.Shim.HTTPPort)
shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper)

defer func() {
shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
defer cancelShutdown()
_ = shimServer.HttpServer.Shutdown(shutdownCtx)
}()
shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper, runnerManager)

if serviceMode {
if err := shim.WriteHostInfo(shimHomeDir, dockerRunner.Resources(ctx)); err != nil {
Expand All @@ -260,9 +271,25 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
}
}

if err := shimServer.HttpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
var serveErr error
// NOTE(review): serveErrCh is unbuffered and its only receiver is the select
// below. If the select leaves through ctx.Done() and Serve afterwards returns
// a non-nil error, the goroutine blocks on the send forever. A buffer of 1
// would let it finish. Whether Serve filters http.ErrServerClosed is not
// visible here — confirm.
serveErrCh := make(chan error)

go func() {
if err := shimServer.Serve(); err != nil {
serveErrCh <- err
}
}()

// Wait for whichever comes first: a serve failure or cancellation of ctx.
select {
case serveErr = <-serveErrCh:
case <-ctx.Done():
}

return nil
// NOTE(review): on the ctx.Done() path ctx is already canceled, so a context
// derived from it is canceled from the moment it is created and the 5-second
// grace period never applies. Deriving the timeout from
// context.WithoutCancel(ctx) or context.Background() would give Shutdown the
// intended window — confirm how shimServer.Shutdown uses its context.
shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
defer cancelShutdown()
shutdownErr := shimServer.Shutdown(shutdownCtx)
// An error from Serve takes precedence over an error from Shutdown.
if serveErr != nil {
return serveErr
}
return shutdownErr
}
Loading