diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b3475efda5..efd29636a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: hooks: - id: golangci-lint-full language_version: 1.25.0 # Should match runner/go.mod - entry: bash -c 'cd runner && golangci-lint run --fix' + entry: bash -c 'cd runner && golangci-lint run' stages: [manual] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index b7f52d26a9..af468a6a93 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -5,20 +5,22 @@ import ( "errors" "fmt" "io" - "net/http" "os" + "os/signal" "path" "path/filepath" + "syscall" "time" "github.com/sirupsen/logrus" - "github.com/urfave/cli/v2" + "github.com/urfave/cli/v3" "github.com/dstackai/dstack/runner/consts" "github.com/dstackai/dstack/runner/internal/common" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/shim" "github.com/dstackai/dstack/runner/internal/shim/api" + "github.com/dstackai/dstack/runner/internal/shim/components" "github.com/dstackai/dstack/runner/internal/shim/dcgm" ) @@ -26,77 +28,81 @@ import ( var Version string func main() { + os.Exit(mainInner()) +} + +func mainInner() int { var args shim.CLIArgs var serviceMode bool const defaultLogLevel = int(logrus.InfoLevel) - ctx := context.Background() - log.DefaultEntry.Logger.SetLevel(logrus.Level(defaultLogLevel)) log.DefaultEntry.Logger.SetOutput(os.Stderr) - app := &cli.App{ + cmd := &cli.Command{ Name: "dstack-shim", Usage: "Starts dstack-runner or docker container.", Version: Version, Flags: []cli.Flag{ /* Shim Parameters */ - &cli.PathFlag{ + &cli.StringFlag{ Name: "shim-home", Usage: "Set shim's home directory", Destination: &args.Shim.HomeDir, + TakesFile: true, DefaultText: path.Join("~", consts.DstackDirPath), - EnvVars: []string{"DSTACK_SHIM_HOME"}, + Sources: cli.EnvVars("DSTACK_SHIM_HOME"), }, 
&cli.IntFlag{ Name: "shim-http-port", Usage: "Set shim's http port", Value: 10998, Destination: &args.Shim.HTTPPort, - EnvVars: []string{"DSTACK_SHIM_HTTP_PORT"}, + Sources: cli.EnvVars("DSTACK_SHIM_HTTP_PORT"), }, &cli.IntFlag{ Name: "shim-log-level", Usage: "Set shim's log level", Value: defaultLogLevel, Destination: &args.Shim.LogLevel, - EnvVars: []string{"DSTACK_SHIM_LOG_LEVEL"}, + Sources: cli.EnvVars("DSTACK_SHIM_LOG_LEVEL"), }, /* Runner Parameters */ &cli.StringFlag{ Name: "runner-download-url", Usage: "Set runner's download URL", Destination: &args.Runner.DownloadURL, - EnvVars: []string{"DSTACK_RUNNER_DOWNLOAD_URL"}, + Sources: cli.EnvVars("DSTACK_RUNNER_DOWNLOAD_URL"), }, - &cli.PathFlag{ + &cli.StringFlag{ Name: "runner-binary-path", Usage: "Path to runner's binary", Value: consts.RunnerBinaryPath, Destination: &args.Runner.BinaryPath, - EnvVars: []string{"DSTACK_RUNNER_BINARY_PATH"}, + TakesFile: true, + Sources: cli.EnvVars("DSTACK_RUNNER_BINARY_PATH"), }, &cli.IntFlag{ Name: "runner-http-port", Usage: "Set runner's http port", Value: consts.RunnerHTTPPort, Destination: &args.Runner.HTTPPort, - EnvVars: []string{"DSTACK_RUNNER_HTTP_PORT"}, + Sources: cli.EnvVars("DSTACK_RUNNER_HTTP_PORT"), }, &cli.IntFlag{ Name: "runner-ssh-port", Usage: "Set runner's ssh port", Value: consts.RunnerSSHPort, Destination: &args.Runner.SSHPort, - EnvVars: []string{"DSTACK_RUNNER_SSH_PORT"}, + Sources: cli.EnvVars("DSTACK_RUNNER_SSH_PORT"), }, &cli.IntFlag{ Name: "runner-log-level", Usage: "Set runner's log level", Value: defaultLogLevel, Destination: &args.Runner.LogLevel, - EnvVars: []string{"DSTACK_RUNNER_LOG_LEVEL"}, + Sources: cli.EnvVars("DSTACK_RUNNER_LOG_LEVEL"), }, /* DCGM Exporter Parameters */ &cli.IntFlag{ @@ -104,14 +110,14 @@ func main() { Usage: "DCGM Exporter http port", Value: 10997, Destination: &args.DCGMExporter.HTTPPort, - EnvVars: []string{"DSTACK_DCGM_EXPORTER_HTTP_PORT"}, + Sources: cli.EnvVars("DSTACK_DCGM_EXPORTER_HTTP_PORT"), }, &cli.IntFlag{ 
Name: "dcgm-exporter-interval", Usage: "DCGM Exporter collect interval, milliseconds", Value: 5000, Destination: &args.DCGMExporter.Interval, - EnvVars: []string{"DSTACK_DCGM_EXPORTER_INTERVAL"}, + Sources: cli.EnvVars("DSTACK_DCGM_EXPORTER_INTERVAL"), }, /* DCGM Parameters */ &cli.StringFlag{ @@ -119,43 +125,49 @@ func main() { Usage: "nv-hostengine `hostname`, e.g., `localhost`", DefaultText: "start libdcgm in embedded mode", Destination: &args.DCGM.Address, - EnvVars: []string{"DSTACK_DCGM_ADDRESS"}, + Sources: cli.EnvVars("DSTACK_DCGM_ADDRESS"), }, /* Docker Parameters */ &cli.BoolFlag{ Name: "privileged", Usage: "Give extended privileges to the container", Destination: &args.Docker.Privileged, - EnvVars: []string{"DSTACK_DOCKER_PRIVILEGED"}, + Sources: cli.EnvVars("DSTACK_DOCKER_PRIVILEGED"), }, &cli.StringFlag{ Name: "ssh-key", Usage: "Public SSH key", Destination: &args.Docker.ConcatinatedPublicSSHKeys, - EnvVars: []string{"DSTACK_PUBLIC_SSH_KEY"}, + Sources: cli.EnvVars("DSTACK_PUBLIC_SSH_KEY"), }, &cli.StringFlag{ Name: "pjrt-device", Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)", Destination: &args.Docker.PJRTDevice, - EnvVars: []string{"PJRT_DEVICE"}, + Sources: cli.EnvVars("PJRT_DEVICE"), }, /* Misc Parameters */ &cli.BoolFlag{ Name: "service", Usage: "Start as a service", Destination: &serviceMode, - EnvVars: []string{"DSTACK_SERVICE_MODE"}, + Sources: cli.EnvVars("DSTACK_SERVICE_MODE"), }, }, - Action: func(c *cli.Context) error { + Action: func(ctx context.Context, cmd *cli.Command) error { return start(ctx, args, serviceMode) }, } - if err := app.Run(os.Args); err != nil { - log.Fatal(ctx, err.Error()) + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + if err := cmd.Run(ctx, os.Args); err != nil { + log.Error(ctx, err.Error()) + return 1 } + + return 0 } func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) { @@ -191,8 +203,13 @@ func start(ctx 
context.Context, args shim.CLIArgs, serviceMode bool) (err error) } }() - if err := args.DownloadRunner(ctx); err != nil { - return err + runnerManager, runnerErr := components.NewRunnerManager(ctx, args.Runner.BinaryPath) + if args.Runner.DownloadURL != "" { + if err := runnerManager.Install(ctx, args.Runner.DownloadURL, false); err != nil { + return err + } + } else if runnerErr != nil { + return runnerErr } log.Debug(ctx, "Shim", "args", args.Shim) @@ -242,13 +259,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) } address := fmt.Sprintf("localhost:%d", args.Shim.HTTPPort) - shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper) - - defer func() { - shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second) - defer cancelShutdown() - _ = shimServer.HttpServer.Shutdown(shutdownCtx) - }() + shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper, runnerManager) if serviceMode { if err := shim.WriteHostInfo(shimHomeDir, dockerRunner.Resources(ctx)); err != nil { @@ -260,9 +271,25 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) } } - if err := shimServer.HttpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - return err + var serveErr error + serveErrCh := make(chan error) + + go func() { + if err := shimServer.Serve(); err != nil { + serveErrCh <- err + } + }() + + select { + case serveErr = <-serveErrCh: + case <-ctx.Done(): } - return nil + shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second) + defer cancelShutdown() + shutdownErr := shimServer.Shutdown(shutdownCtx) + if serveErr != nil { + return serveErr + } + return shutdownErr } diff --git a/runner/docs/shim.openapi.yaml b/runner/docs/shim.openapi.yaml index d612cc0bad..e6f49fa079 100644 --- a/runner/docs/shim.openapi.yaml +++ b/runner/docs/shim.openapi.yaml @@ -1,8 +1,8 @@ -openapi: 3.1.1 +openapi: 
3.1.2 info: title: dstack-shim API - version: v2/0.19.22 + version: v2/0.19.41 x-logo: url: https://avatars.githubusercontent.com/u/54146142?s=260 description: > @@ -53,7 +53,6 @@ paths: /instance/health: get: summary: Get instance health - description: (since [0.19.22](https://github.com/dstackai/dstack/releases/tag/0.19.22)) Returns an object of optional passive system checks tags: [Instance] responses: @@ -64,6 +63,43 @@ paths: schema: $ref: "#/components/schemas/InstanceHealthResponse" + /components: + get: + summary: Get components + description: (since [0.19.41](https://github.com/dstackai/dstack/releases/tag/0.19.41)) Returns a list of software components (e.g., `dstack-runner`) + tags: [Components] + responses: + "200": + description: "" + content: + application/json: + schema: + $ref: "#/components/schemas/ComponentListResponse" + + /components/install: + post: + summary: Install component + description: > + (since [0.19.41](https://github.com/dstackai/dstack/releases/tag/0.19.41)) Request installing/updating the software component. + Components are installed asynchronously + tags: [Components] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ComponentInstallRequest" + responses: + "200": + description: Request accepted + $ref: "#/components/responses/PlainTextOk" + "400": + description: Malformed JSON body or validation error + $ref: "#/components/responses/PlainTextBadRequest" + "409": + description: The component is already being installed + $ref: "#/components/responses/PlainTextConflict" + /tasks: get: summary: Get task list @@ -104,7 +140,7 @@ paths: summary: Get task info tags: [Tasks] parameters: - - $ref: "#/parameters/taskId" + - $ref: "#/components/parameters/taskId" responses: "200": $ref: "#/components/responses/TaskInfo" @@ -151,7 +187,7 @@ paths: resources: a container, logs, etc. 
tags: [Tasks] parameters: - - $ref: "#/parameters/taskId" + - $ref: "#/components/parameters/taskId" responses: "200": description: Task removed @@ -166,15 +202,15 @@ paths: description: Internal error, e.g., failed to remove a container $ref: "#/components/responses/PlainTextInternalError" -parameters: - taskId: - name: id - in: path - schema: - $ref: "#/components/schemas/TaskID" - required: true - components: + parameters: + taskId: + name: id + in: path + schema: + $ref: "#/components/schemas/TaskID" + required: true + schemas: TaskID: description: Unique task ID assigned by dstack server @@ -369,6 +405,43 @@ components: - entity_id additionalProperties: false + ComponentName: + title: shim.components.ComponentName + type: string + enum: + - dstack-runner + + ComponentStatus: + title: shim.components.ComponentStatus + type: string + enum: + - not-installed + - installed + - installing + - error + + ComponentInfo: + title: shim.components.ComponentInfo + type: object + properties: + name: + $ref: "#/components/schemas/ComponentName" + version: + type: string + description: An empty string if status != installed + examples: + - 0.19.41 + status: + allOf: + - $ref: "#/components/schemas/ComponentStatus" + - examples: + - installed + required: + - name + - version + - status + additionalProperties: false + HealthcheckResponse: title: shim.api.HealthcheckResponse type: object @@ -392,6 +465,32 @@ components: $ref: "#/components/schemas/DCGMHealth" additionalProperties: false + ComponentListResponse: + title: shim.api.ComponentListResponse + type: object + properties: + components: + type: array + items: + $ref: "#/components/schemas/ComponentInfo" + required: + - components + additionalProperties: false + + ComponentInstallRequest: + title: shim.api.ComponentInstallRequest + type: object + properties: + name: + $ref: "#/components/schemas/ComponentName" + url: + type: string + examples: + - 
https://dstack-runner-downloads.s3.eu-west-1.amazonaws.com/0.19.41/binaries/dstack-runner-linux-amd64 + required: + - name + - url + TaskListResponse: title: shim.api.TaskListResponse type: object @@ -535,8 +634,9 @@ components: examples: - 1073741824 network_mode: - $ref: "#/components/schemas/NetworkMode" - default: host + allOf: + - $ref: "#/components/schemas/NetworkMode" + - default: host volumes: type: array items: diff --git a/runner/go.mod b/runner/go.mod index 8c474cc42d..b317f6c7b0 100644 --- a/runner/go.mod +++ b/runner/go.mod @@ -19,8 +19,9 @@ require ( github.com/prometheus/procfs v0.15.1 github.com/shirou/gopsutil/v4 v4.24.11 github.com/sirupsen/logrus v1.9.3 - github.com/stretchr/testify v1.10.0 - github.com/urfave/cli/v2 v2.27.1 + github.com/stretchr/testify v1.11.1 + github.com/urfave/cli/v2 v2.27.7 + github.com/urfave/cli/v3 v3.6.1 golang.org/x/crypto v0.22.0 golang.org/x/sys v0.26.0 ) @@ -32,7 +33,7 @@ require ( github.com/bits-and-blooms/bitset v1.22.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/containerd/log v0.1.0 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/distribution/reference v0.6.0 // indirect @@ -69,7 +70,7 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/ulikunitz/xz v0.5.12 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect - github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect + github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.50.0 // indirect go.opentelemetry.io/otel v1.25.0 // indirect diff --git a/runner/go.sum b/runner/go.sum index f9e6c0e912..de734fa39a 100644 --- a/runner/go.sum +++ b/runner/go.sum @@ -34,8 +34,8 @@ 
github.com/codeclysm/extract/v4 v4.0.0 h1:H87LFsUNaJTu2e/8p/oiuiUsOK/TaPQ5wxsjPn github.com/codeclysm/extract/v4 v4.0.0/go.mod h1:SFju1lj6as7FvUgalpSct7torJE0zttbJUWtryPRG6s= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= -github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= -github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= +github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg= @@ -175,8 +175,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= @@ -185,12 +185,14 @@ github.com/tklauser/numcpus 
v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/ulikunitz/xz v0.5.12 h1:37Nm15o69RwBkXM0J6A5OlE67RZTfzUxTj8fB3dfcsc= github.com/ulikunitz/xz v0.5.12/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= -github.com/urfave/cli/v2 v2.27.1 h1:8xSQ6szndafKVRmfyeUMxkNUJQMjL1F2zmsZ+qHpfho= -github.com/urfave/cli/v2 v2.27.1/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= +github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= +github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= +github.com/urfave/cli/v3 v3.6.1 h1:j8Qq8NyUawj/7rTYdBGrxcH7A/j7/G8Q5LhWEW4G3Mo= +github.com/urfave/cli/v3 v3.6.1/go.mod h1:ysVLtOEmg2tOy6PknnYVhDoouyC/6N42TMeoMzskhso= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= -github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 h1:+qGGcbkzsfDQNPPe9UDgpxAWQrhbbBXOYJFQDq/dtJw= -github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913/go.mod h1:4aEEwZQutDLsQv2Deui4iYQ6DWTxR14g6m8Wv88+Xqk= +github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= +github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/runner/internal/common/utils.go b/runner/internal/common/utils.go index 9352b0201c..01cae30a47 100644 --- a/runner/internal/common/utils.go +++ b/runner/internal/common/utils.go @@ -10,6 +10,17 @@ import ( "github.com/dstackai/dstack/runner/internal/log" ) +func PathExists(pth 
string) (bool, error) { + _, err := os.Stat(pth) + if err == nil { + return true, nil + } + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + return false, err +} + func ExpandPath(pth string, base string, home string) (string, error) { pth = path.Clean(pth) if pth == "~" { diff --git a/runner/internal/shim/api/handlers.go b/runner/internal/shim/api/handlers.go index 91df9cb55f..7e4f172272 100644 --- a/runner/internal/shim/api/handlers.go +++ b/runner/internal/shim/api/handlers.go @@ -8,6 +8,7 @@ import ( "github.com/dstackai/dstack/runner/internal/api" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/shim" + "github.com/dstackai/dstack/runner/internal/shim/components" "github.com/dstackai/dstack/runner/internal/shim/dcgm" ) @@ -156,3 +157,46 @@ func (s *ShimServer) TaskMetricsHandler(w http.ResponseWriter, r *http.Request) response := dcgm.FilterMetrics(expfmtBody, taskInfo.GpuIDs) _, _ = w.Write(response) } + +func (s *ShimServer) ComponentListHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { + runnerStatus := s.runnerManager.GetInfo(r.Context()) + response := &ComponentListResponse{ + Components: []components.ComponentInfo{runnerStatus}, + } + return response, nil +} + +func (s *ShimServer) ComponentInstallHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { + var req ComponentInstallRequest + if err := api.DecodeJSONBody(w, r, &req, true); err != nil { + return nil, err + } + + if req.Name == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty name"} + } + + switch components.ComponentName(req.Name) { + case components.ComponentNameRunner: + if req.URL == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty url"} + } + + // There is still a small chance of time-of-check race condition, but we ignore it. 
+ runnerInfo := s.runnerManager.GetInfo(r.Context()) + if runnerInfo.Status == components.ComponentStatusInstalling { + return nil, &api.Error{Status: http.StatusConflict, Msg: "already installing"} + } + + s.bgJobsGroup.Go(func() { + if err := s.runnerManager.Install(s.bgJobsCtx, req.URL, true); err != nil { + log.Error(s.bgJobsCtx, "runner background install", "err", err) + } + }) + + default: + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "unknown component"} + } + + return nil, nil +} diff --git a/runner/internal/shim/api/handlers_test.go b/runner/internal/shim/api/handlers_test.go index c640fdb731..c04621eb0a 100644 --- a/runner/internal/shim/api/handlers_test.go +++ b/runner/internal/shim/api/handlers_test.go @@ -13,7 +13,7 @@ func TestHealthcheck(t *testing.T) { request := httptest.NewRequest("GET", "/api/healthcheck", nil) responseRecorder := httptest.NewRecorder() - server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil) + server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil) f := common.JSONResponseHandler(server.HealthcheckHandler) f(responseRecorder, request) @@ -30,7 +30,7 @@ func TestHealthcheck(t *testing.T) { } func TestTaskSubmit(t *testing.T) { - server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil) + server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil) requestBody := `{ "id": "dummy-id", "name": "dummy-name", diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index 41d09b8ac6..a7d5fa7d48 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -2,6 +2,7 @@ package api import ( "github.com/dstackai/dstack/runner/internal/shim" + "github.com/dstackai/dstack/runner/internal/shim/components" "github.com/dstackai/dstack/runner/internal/shim/dcgm" ) @@ -37,3 +38,12 @@ type 
TaskTerminateRequest struct { TerminationMessage string `json:"termination_message"` Timeout uint `json:"timeout"` } + +type ComponentListResponse struct { + Components []components.ComponentInfo `json:"components"` +} + +type ComponentInstallRequest struct { + Name string `json:"name"` + URL string `json:"url"` +} diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index 8fd7026a99..15e0191354 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -2,6 +2,7 @@ package api import ( "context" + "errors" "net" "net/http" "reflect" @@ -9,6 +10,7 @@ import ( "github.com/dstackai/dstack/runner/internal/api" "github.com/dstackai/dstack/runner/internal/shim" + "github.com/dstackai/dstack/runner/internal/shim/components" "github.com/dstackai/dstack/runner/internal/shim/dcgm" ) @@ -24,43 +26,59 @@ type TaskRunner interface { } type ShimServer struct { - HttpServer *http.Server + httpServer *http.Server mu sync.RWMutex + bgJobsCtx context.Context + bgJobsCancel context.CancelFunc + bgJobsGroup *sync.WaitGroup + runner TaskRunner dcgmExporter *dcgm.DCGMExporter dcgmWrapper dcgm.DCGMWrapperInterface // interface with nil value normalized to plain nil + runnerManager *components.RunnerManager + version string } func NewShimServer( ctx context.Context, address string, version string, runner TaskRunner, dcgmExporter *dcgm.DCGMExporter, dcgmWrapper dcgm.DCGMWrapperInterface, + runnerManager *components.RunnerManager, ) *ShimServer { + bgJobsCtx, bgJobsCancel := context.WithCancel(ctx) if dcgmWrapper != nil && reflect.ValueOf(dcgmWrapper).IsNil() { dcgmWrapper = nil } r := api.NewRouter() s := &ShimServer{ - HttpServer: &http.Server{ + httpServer: &http.Server{ Addr: address, Handler: r, BaseContext: func(l net.Listener) context.Context { return ctx }, }, + bgJobsCtx: bgJobsCtx, + bgJobsCancel: bgJobsCancel, + bgJobsGroup: &sync.WaitGroup{}, + runner: runner, dcgmExporter: dcgmExporter, dcgmWrapper: dcgmWrapper, 
+ runnerManager: runnerManager, + version: version, } // The healthcheck endpoint should stay backward compatible, as it is used for negotiation r.AddHandler("GET", "/api/healthcheck", s.HealthcheckHandler) r.AddHandler("GET", "/api/instance/health", s.InstanceHealthHandler) + r.AddHandler("GET", "/api/components", s.ComponentListHandler) + r.AddHandler("POST", "/api/components/install", s.ComponentInstallHandler) r.AddHandler("GET", "/api/tasks", s.TaskListHandler) r.AddHandler("GET", "/api/tasks/{id}", s.TaskInfoHandler) r.AddHandler("POST", "/api/tasks", s.TaskSubmitHandler) @@ -70,3 +88,17 @@ func NewShimServer( return s } + +func (s *ShimServer) Serve() error { + if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil +} + +func (s *ShimServer) Shutdown(ctx context.Context) error { + s.bgJobsCancel() + err := s.httpServer.Shutdown(ctx) + s.bgJobsGroup.Wait() + return err +} diff --git a/runner/internal/shim/components/runner.go b/runner/internal/shim/components/runner.go new file mode 100644 index 0000000000..b18f51d3c3 --- /dev/null +++ b/runner/internal/shim/components/runner.go @@ -0,0 +1,94 @@ +package components + +import ( + "context" + "errors" + "fmt" + "os/exec" + "strings" + "sync" + + "github.com/dstackai/dstack/runner/internal/common" +) + +type RunnerManager struct { + path string + version string + status ComponentStatus + + mu *sync.RWMutex +} + +func NewRunnerManager(ctx context.Context, pth string) (*RunnerManager, error) { + m := RunnerManager{ + path: pth, + mu: &sync.RWMutex{}, + } + err := m.check(ctx) + return &m, err +} + +func (m *RunnerManager) GetInfo(ctx context.Context) ComponentInfo { + m.mu.RLock() + defer m.mu.RUnlock() + return ComponentInfo{ + Name: ComponentNameRunner, + Version: m.version, + Status: m.status, + } +} + +func (m *RunnerManager) Install(ctx context.Context, url string, force bool) error { + m.mu.Lock() + if m.status == ComponentStatusInstalling 
{ + m.mu.Unlock() + return errors.New("install runner: already installing") + } + m.status = ComponentStatusInstalling + m.version = "" + m.mu.Unlock() + + downloadErr := downloadFile(ctx, url, m.path, 0o755, force) + // Recheck the binary even if the download has failed, just in case. + checkErr := m.check(ctx) + if downloadErr != nil { + return downloadErr + } + return checkErr +} + +func (m *RunnerManager) check(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + + exists, err := common.PathExists(m.path) + if err != nil { + m.status = ComponentStatusError + m.version = "" + return fmt.Errorf("check runner: %w", err) + } + if !exists { + m.status = ComponentStatusNotInstalled + m.version = "" + return nil + } + + cmd := exec.CommandContext(ctx, m.path, "--version") + output, err := cmd.Output() + if err != nil { + m.status = ComponentStatusError + m.version = "" + return fmt.Errorf("check runner: %w", err) + } + + rawVersion := string(output) // dstack-runner version 0.19.38 + versionFields := strings.Fields(rawVersion) + if len(versionFields) != 3 { + m.status = ComponentStatusError + m.version = "" + return fmt.Errorf("check runner: unexpected version output: %s", rawVersion) + } + m.status = ComponentStatusInstalled + m.version = versionFields[2] + return nil +} diff --git a/runner/internal/shim/components/types.go b/runner/internal/shim/components/types.go new file mode 100644 index 0000000000..13d1af857e --- /dev/null +++ b/runner/internal/shim/components/types.go @@ -0,0 +1,20 @@ +package components + +type ComponentName string + +const ComponentNameRunner ComponentName = "dstack-runner" + +type ComponentStatus string + +const ( + ComponentStatusNotInstalled ComponentStatus = "not-installed" + ComponentStatusInstalled ComponentStatus = "installed" + ComponentStatusInstalling ComponentStatus = "installing" + ComponentStatusError ComponentStatus = "error" +) + +type ComponentInfo struct { + Name ComponentName `json:"name"` + Version string 
`json:"version"` + Status ComponentStatus `json:"status"` +} diff --git a/runner/internal/shim/components/utils.go b/runner/internal/shim/components/utils.go new file mode 100644 index 0000000000..9161a64499 --- /dev/null +++ b/runner/internal/shim/components/utils.go @@ -0,0 +1,87 @@ +package components + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/dstackai/dstack/runner/internal/log" +) + +const downloadTimeout = 10 * time.Minute + +func downloadFile(ctx context.Context, url string, path string, mode os.FileMode, force bool) error { + if _, err := os.Stat(path); err == nil { + if force { + log.Debug(ctx, "file exists, forcing download", "path", path) + } else { + log.Debug(ctx, "file exists, skipping download", "path", path) + return nil + } + } else if !os.IsNotExist(err) { + return fmt.Errorf("check file exists: %w", err) + } + dir, name := filepath.Split(path) + tempFile, err := os.CreateTemp(dir, fmt.Sprintf(".*-%s", name)) + if err != nil { + return fmt.Errorf("create temp file for %s: %w", name, err) + } + defer func() { + if err := tempFile.Close(); err != nil { + log.Error(ctx, "close temp file", "err", err) + } + if err := os.Remove(tempFile.Name()); err != nil && !errors.Is(err, os.ErrNotExist) { + log.Error(ctx, "remove temp file", "err", err) + } + }() + + log.Debug(ctx, "downloading", "path", path, "url", url) + ctx, cancel := context.WithTimeout(ctx, downloadTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("create download request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("execute download request: %w", err) + } + + defer func() { + err := resp.Body.Close() + if err != nil { + log.Error(ctx, "downloadFile: close body error", "err", err) + } + }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code %s 
downloading %s from %s", resp.Status, name, url) + } + + written, err := io.Copy(tempFile, resp.Body) + if err != nil { + log.Error(ctx, "download file", "err", err, "bytes", written, "total", resp.ContentLength) + if err := os.Remove(tempFile.Name()); err != nil { + log.Error(ctx, "remove temp file", "err", err) + } + return fmt.Errorf("copy %s: %w", name, err) + } + log.Debug(ctx, "file has been downloaded", "path", path, "bytes", written) + + if err := tempFile.Chmod(mode); err != nil { + return fmt.Errorf("chmod %s: %w", path, err) + } + + if err := os.Rename(tempFile.Name(), path); err != nil { + return fmt.Errorf("move %s to %s: %w", name, path, err) + } + + return nil +} diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 04519f6dd7..56afe08938 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -1218,7 +1218,15 @@ func (c *CLIArgs) DockerShellCommands(publicKeys []string) []string { concatinatedPublicKeys = strings.Join(publicKeys, "\n") } commands := getSSHShellCommands(c.Runner.SSHPort, concatinatedPublicKeys) - commands = append(commands, fmt.Sprintf("%s %s", consts.RunnerBinaryPath, strings.Join(c.getRunnerArgs(), " "))) + runnerArgs := []string{ + "--log-level", strconv.Itoa(c.Runner.LogLevel), + "start", + "--http-port", strconv.Itoa(c.Runner.HTTPPort), + "--ssh-port", strconv.Itoa(c.Runner.SSHPort), + "--temp-dir", consts.RunnerTempDir, + "--home-dir", consts.RunnerHomeDir, + } + commands = append(commands, fmt.Sprintf("%s %s", consts.RunnerBinaryPath, strings.Join(runnerArgs, " "))) return commands } diff --git a/runner/internal/shim/runner.go b/runner/internal/shim/runner.go deleted file mode 100644 index 4ef8f5db6e..0000000000 --- a/runner/internal/shim/runner.go +++ /dev/null @@ -1,112 +0,0 @@ -package shim - -import ( - "context" - "errors" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strconv" - "time" - - "github.com/dstackai/dstack/runner/consts" - 
"github.com/dstackai/dstack/runner/internal/log" -) - -func (c *CLIArgs) DownloadRunner(ctx context.Context) error { - if c.Runner.DownloadURL == "" { - return nil - } - err := downloadRunner(ctx, c.Runner.DownloadURL, c.Runner.BinaryPath, false) - if err != nil { - return fmt.Errorf("download runner from %s: %w", c.Runner.DownloadURL, err) - } - return nil -} - -func (c *CLIArgs) getRunnerArgs() []string { - return []string{ - "--log-level", strconv.Itoa(c.Runner.LogLevel), - "start", - "--http-port", strconv.Itoa(c.Runner.HTTPPort), - "--ssh-port", strconv.Itoa(c.Runner.SSHPort), - "--temp-dir", consts.RunnerTempDir, - "--home-dir", consts.RunnerHomeDir, - } -} - -func downloadRunner(ctx context.Context, url string, path string, force bool) error { - if _, err := os.Stat(path); err == nil { - if force { - log.Info(ctx, "dstack-runner binary exists, forcing download", "path", path) - } else { - log.Info(ctx, "dstack-runner binary exists, skipping download", "path", path) - return nil - } - } else if !os.IsNotExist(err) { - return fmt.Errorf("check dstack-runner exists: %w", err) - } - tempFile, err := os.CreateTemp(filepath.Dir(path), "dstack-runner") - if err != nil { - return fmt.Errorf("create temp file for runner: %w", err) - } - defer func() { - err := tempFile.Close() - if err != nil { - log.Error(ctx, "close file error", "err", err) - } - }() - - log.Debug(ctx, "downloading runner", "url", url) - ctx, cancel := context.WithTimeout(ctx, 10*time.Minute) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return fmt.Errorf("create download request: %w", err) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return fmt.Errorf("execute download request: %w", err) - } - - defer func() { - err := resp.Body.Close() - if err != nil { - log.Error(ctx, "downloadRunner: close body error", "err", err) - } - }() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status 
code %s downloading runner from %s", resp.Status, url) - } - - written, err := io.Copy(tempFile, resp.Body) - if err != nil { - return fmt.Errorf("copy runner binary: %w", err) - } - - select { - case <-ctx.Done(): - err := ctx.Err() - if errors.Is(err, context.DeadlineExceeded) { - log.Error(ctx, "downloadRunner error", "err", err, "bytes", written, "total", resp.ContentLength) - return fmt.Errorf("download runner timeout after %d/%d bytes: %w", written, resp.ContentLength, err) - } - default: - log.Info(ctx, "the runner was downloaded successfully", "bytes", written) - } - - if err := tempFile.Chmod(0o755); err != nil { - return fmt.Errorf("chmod runner binary: %w", err) - } - - if err := os.Rename(tempFile.Name(), path); err != nil { - return fmt.Errorf("move runner binary to %s: %w", path, err) - } - - return nil -} diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index c680f4114a..1178068180 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -787,7 +787,9 @@ def normalize_arch(arch: Optional[str] = None) -> GoArchType: raise ValueError(f"Unsupported architecture: {arch}") -def get_dstack_runner_download_url(arch: Optional[str] = None) -> str: +def get_dstack_runner_download_url( + arch: Optional[str] = None, version: Optional[str] = None +) -> str: url_template = os.environ.get("DSTACK_RUNNER_DOWNLOAD_URL") if not url_template: if settings.DSTACK_VERSION is not None: @@ -798,7 +800,8 @@ def get_dstack_runner_download_url(arch: Optional[str] = None) -> str: f"https://{bucket}.s3.eu-west-1.amazonaws.com" "/{version}/binaries/dstack-runner-linux-{arch}" ) - version = get_dstack_runner_version() + if version is None: + version = get_dstack_runner_version() return url_template.format(version=version, arch=normalize_arch(arch).value) diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py 
b/src/dstack/_internal/server/background/tasks/process_instances.py index d7cab39701..30ed2b1ec3 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -18,6 +18,8 @@ ComputeWithPlacementGroupSupport, GoArchType, get_dstack_runner_binary_path, + get_dstack_runner_download_url, + get_dstack_runner_version, get_dstack_shim_binary_path, get_dstack_working_dir, get_shim_env, @@ -62,7 +64,11 @@ ProjectModel, ) from dstack._internal.server.schemas.instances import InstanceCheck -from dstack._internal.server.schemas.runner import HealthcheckResponse, InstanceHealthResponse +from dstack._internal.server.schemas.runner import ( + ComponentStatus, + HealthcheckResponse, + InstanceHealthResponse, +) from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.fleets import ( fleet_model_to_fleet, @@ -116,6 +122,7 @@ from dstack._internal.utils.ssh import ( pkey_from_str, ) +from dstack._internal.utils.version import parse_version MIN_PROCESSING_INTERVAL = timedelta(seconds=10) @@ -910,15 +917,79 @@ def _check_instance_inner( args = (method.__func__.__name__, e.__class__.__name__, e) logger.exception(template, *args) return InstanceCheck(reachable=False, message=template % args) + + _maybe_update_runner(instance, shim_client) + try: remove_dangling_tasks_from_instance(shim_client, instance) except Exception as e: logger.exception("%s: error removing dangling tasks: %s", fmt(instance), e) + return runner_client.healthcheck_response_to_instance_check( healthcheck_response, instance_health_response ) +def _maybe_update_runner(instance: InstanceModel, shim_client: runner_client.ShimClient) -> None: + # To auto-update to the latest runner dev build from the CI, see DSTACK_USE_LATEST_FROM_BRANCH. 
+ expected_version_str = get_dstack_runner_version() + try: + expected_version = parse_version(expected_version_str) + except ValueError as e: + logger.warning("Failed to parse expected runner version: %s", e) + return + if expected_version is None: + logger.debug("Cannot determine the expected runner version") + return + + try: + runner_info = shim_client.get_runner_info() + except requests.RequestException as e: + logger.warning("Instance %s: shim.get_runner_info(): request error: %s", instance.name, e) + return + if runner_info is None: + logger.debug("Instance %s: no runner info", instance.name) + return + + logger.debug( + "Instance %s: runner status=%s version=%s", + instance.name, + runner_info.status.value, + runner_info.version, + ) + if runner_info.status == ComponentStatus.INSTALLING: + return + + if runner_info.version: + try: + current_version = parse_version(runner_info.version) + except ValueError as e: + logger.warning("Instance %s: failed to parse runner version: %s", instance.name, e) + return + + if current_version is None or current_version >= expected_version: + logger.debug("Instance %s: the latest runner version already installed", instance.name) + return + + logger.debug( + "Instance %s: updating runner %s -> %s", + instance.name, + current_version, + expected_version, + ) + else: + logger.debug("Instance %s: installing runner %s", instance.name, expected_version) + + job_provisioning_data = get_or_error(get_instance_provisioning_data(instance)) + url = get_dstack_runner_download_url( + arch=job_provisioning_data.instance_type.resources.cpu_arch, version=expected_version_str + ) + try: + shim_client.install_runner(url) + except requests.RequestException as e: + logger.warning("Instance %s: shim.install_runner(): %s", instance.name, e) + + async def _terminate(instance: InstanceModel) -> None: if ( instance.last_termination_retry_at is not None diff --git a/src/dstack/_internal/server/schemas/runner.py 
b/src/dstack/_internal/server/schemas/runner.py index 62ec3f6e3b..f88cf47a82 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -120,6 +120,32 @@ class InstanceHealthResponse(CoreModel): dcgm: Optional[DCGMHealthResponse] = None +class ComponentName(str, Enum): + RUNNER = "dstack-runner" + + +class ComponentStatus(str, Enum): + NOT_INSTALLED = "not-installed" + INSTALLED = "installed" + INSTALLING = "installing" + ERROR = "error" + + +class ComponentInfo(CoreModel): + name: ComponentName + version: str + status: ComponentStatus + + +class ComponentListResponse(CoreModel): + components: list[ComponentInfo] + + +class ComponentInstallRequest(CoreModel): + name: ComponentName + url: str + + class GPUMetrics(CoreModel): gpu_memory_usage_bytes: int gpu_util_percent: int diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 60f6c5d8c9..b270d4ea5f 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -15,6 +15,10 @@ from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint from dstack._internal.server.schemas.instances import InstanceCheck from dstack._internal.server.schemas.runner import ( + ComponentInfo, + ComponentInstallRequest, + ComponentListResponse, + ComponentName, GPUDevice, HealthcheckResponse, InstanceHealthResponse, @@ -189,10 +193,15 @@ class ShimClient: # `/api/instance/health` _INSTANCE_HEALTH_MIN_SHIM_VERSION = (0, 19, 22) + # `/api/components` + _COMPONENTS_RUNNER_MIN_SHIM_VERSION = (0, 19, 41) + _shim_version: Optional["_Version"] _api_version: int _negotiated: bool = False + _components: Optional[dict[ComponentName, ComponentInfo]] = None + def __init__( self, port: int, @@ -216,6 +225,14 @@ def is_instance_health_supported(self) -> bool: or self._shim_version >= self._INSTANCE_HEALTH_MIN_SHIM_VERSION ) + 
def is_runner_component_supported(self) -> bool: + if not self._negotiated: + self._negotiate() + return ( + self._shim_version is None + or self._shim_version >= self._COMPONENTS_RUNNER_MIN_SHIM_VERSION + ) + @overload def healthcheck(self) -> Optional[HealthcheckResponse]: ... @@ -246,6 +263,20 @@ def get_instance_health(self) -> Optional[InstanceHealthResponse]: self._raise_for_status(resp) return self._response(InstanceHealthResponse, resp) + def get_runner_info(self) -> Optional[ComponentInfo]: + if not self.is_runner_component_supported(): + logger.debug("runner info is not supported: %s", self._shim_version) + return None + components = self._get_components() + return components.get(ComponentName.RUNNER) + + def install_runner(self, url: str) -> None: + body = ComponentInstallRequest( + name=ComponentName.RUNNER, + url=url, + ) + self._request("POST", "/api/components/install", body, raise_for_status=True) + def list_tasks(self) -> TaskListResponse: if not self.is_api_v2_supported(): raise ShimAPIVersionError() @@ -444,6 +475,15 @@ def _negotiate(self, healthcheck_response: Optional[requests.Response] = None) - self._api_version = api_version self._negotiated = True + def _get_components(self) -> dict[ComponentName, ComponentInfo]: + resp = self._request("GET", "/api/components") + # TODO: Remove this check after 0.19.41 release, use _request(..., raise_for_status=True) + if resp.status_code == HTTPStatus.NOT_FOUND and self._shim_version is None: + # Old dev build of shim + return {} + resp.raise_for_status() + return {c.name: c for c in self._response(ComponentListResponse, resp).components} + def healthcheck_response_to_instance_check( response: HealthcheckResponse, diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py b/src/tests/_internal/server/background/tasks/test_process_instances.py index 690cb71d95..8661b81ee0 100644 --- a/src/tests/_internal/server/background/tasks/test_process_instances.py +++ 
b/src/tests/_internal/server/background/tasks/test_process_instances.py @@ -1,9 +1,10 @@ import datetime as dt +import logging from collections import defaultdict from collections.abc import Generator from contextlib import contextmanager from typing import Optional -from unittest.mock import Mock, call, patch +from unittest.mock import MagicMock, Mock, call, patch import gpuhunt import pytest @@ -44,11 +45,16 @@ from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse, DCGMHealthResult from dstack._internal.server.schemas.instances import InstanceCheck from dstack._internal.server.schemas.runner import ( + ComponentInfo, + ComponentName, + ComponentStatus, + HealthcheckResponse, InstanceHealthResponse, TaskListItem, TaskListResponse, TaskStatus, ) +from dstack._internal.server.services.runner.client import ShimClient from dstack._internal.server.testing.common import ( ComputeMockSpec, create_fleet, @@ -385,6 +391,14 @@ async def test_check_shim_check_instance_health(self, test_db, session: AsyncSes class TestRemoveDanglingTasks: + @pytest.fixture(autouse=True) + def disable_runner_update_check(self) -> Generator[None, None, None]: + with patch( + "dstack._internal.server.background.tasks.process_instances.get_dstack_runner_version" + ) as get_dstack_runner_version_mock: + get_dstack_runner_version_mock.return_value = "latest" + yield + @pytest.fixture def ssh_tunnel_mock(self) -> Generator[Mock, None, None]: with patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock: @@ -1145,3 +1159,160 @@ async def test_deletes_instance_health_checks( all_checks = res.scalars().all() assert len(all_checks) == 1 assert all_checks[0] == check + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures( + "test_db", "ssh_tunnel_mock", "shim_client_mock", "get_dstack_runner_version_mock" +) +class TestMaybeUpdateRunner: + @pytest.fixture + def ssh_tunnel_mock(self, 
monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("dstack._internal.server.services.runner.ssh.SSHTunnel", MagicMock()) + + @pytest.fixture + def shim_client_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(spec_set=ShimClient) + mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-shim", version="0.19.40" + ) + mock.get_instance_health.return_value = InstanceHealthResponse() + mock.get_runner_info.return_value = ComponentInfo( + name=ComponentName.RUNNER, version="0.19.40", status=ComponentStatus.INSTALLED + ) + mock.list_tasks.return_value = TaskListResponse(tasks=[]) + monkeypatch.setattr( + "dstack._internal.server.services.runner.client.ShimClient", Mock(return_value=mock) + ) + return mock + + @pytest.fixture + def get_dstack_runner_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value="0.19.41") + monkeypatch.setattr( + "dstack._internal.server.background.tasks.process_instances.get_dstack_runner_version", + mock, + ) + return mock + + @pytest.fixture + def get_dstack_runner_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value="https://example.com/runner") + monkeypatch.setattr( + "dstack._internal.server.background.tasks.process_instances.get_dstack_runner_download_url", + mock, + ) + return mock + + async def test_cannot_determine_expected_version( + self, + caplog: pytest.LogCaptureFixture, + session: AsyncSession, + shim_client_mock: Mock, + get_dstack_runner_version_mock: Mock, + ): + caplog.set_level(logging.DEBUG) + project = await create_project(session=session) + await create_instance(session=session, project=project, status=InstanceStatus.IDLE) + get_dstack_runner_version_mock.return_value = "latest" + + await process_instances() + + assert "Cannot determine the expected runner version" in caplog.text + shim_client_mock.get_runner_info.assert_not_called() + shim_client_mock.install_runner.assert_not_called() + + async def 
test_failed_to_parse_current_version( + self, + caplog: pytest.LogCaptureFixture, + session: AsyncSession, + shim_client_mock: Mock, + ): + caplog.set_level(logging.WARNING) + project = await create_project(session=session) + await create_instance(session=session, project=project, status=InstanceStatus.IDLE) + shim_client_mock.get_runner_info.return_value.version = "invalid" + + await process_instances() + + assert "failed to parse runner version" in caplog.text + shim_client_mock.get_runner_info.assert_called_once() + shim_client_mock.install_runner.assert_not_called() + + @pytest.mark.parametrize("current_version", ["latest", "0.0.0", "0.19.41", "0.19.42"]) + async def test_latest_version_already_installed( + self, + caplog: pytest.LogCaptureFixture, + session: AsyncSession, + shim_client_mock: Mock, + current_version: str, + ): + caplog.set_level(logging.DEBUG) + project = await create_project(session=session) + await create_instance(session=session, project=project, status=InstanceStatus.IDLE) + shim_client_mock.get_runner_info.return_value.version = current_version + + await process_instances() + + assert "the latest runner version already installed" in caplog.text + shim_client_mock.get_runner_info.assert_called_once() + shim_client_mock.install_runner.assert_not_called() + + async def test_install_not_installed( + self, + caplog: pytest.LogCaptureFixture, + session: AsyncSession, + shim_client_mock: Mock, + get_dstack_runner_download_url_mock: Mock, + ): + caplog.set_level(logging.DEBUG) + project = await create_project(session=session) + await create_instance(session=session, project=project, status=InstanceStatus.IDLE) + shim_client_mock.get_runner_info.return_value.version = "" + shim_client_mock.get_runner_info.return_value.status = ComponentStatus.NOT_INSTALLED + + await process_instances() + + assert "installing runner 0.19.41" in caplog.text + get_dstack_runner_download_url_mock.assert_called_once_with(arch=None, version="0.19.41") + 
shim_client_mock.get_runner_info.assert_called_once() + shim_client_mock.install_runner.assert_called_once_with( + get_dstack_runner_download_url_mock.return_value + ) + + async def test_update_outdated( + self, + caplog: pytest.LogCaptureFixture, + session: AsyncSession, + shim_client_mock: Mock, + get_dstack_runner_download_url_mock: Mock, + ): + caplog.set_level(logging.DEBUG) + project = await create_project(session=session) + await create_instance(session=session, project=project, status=InstanceStatus.IDLE) + shim_client_mock.get_runner_info.return_value.version = "0.19.38" + + await process_instances() + + assert "updating runner 0.19.38 -> 0.19.41" in caplog.text + get_dstack_runner_download_url_mock.assert_called_once_with(arch=None, version="0.19.41") + shim_client_mock.get_runner_info.assert_called_once() + shim_client_mock.install_runner.assert_called_once_with( + get_dstack_runner_download_url_mock.return_value + ) + + async def test_already_updating( + self, + session: AsyncSession, + shim_client_mock: Mock, + ): + project = await create_project(session=session) + await create_instance(session=session, project=project, status=InstanceStatus.IDLE) + shim_client_mock.get_runner_info.return_value.version = "0.19.38" + shim_client_mock.get_runner_info.return_value.status = ComponentStatus.INSTALLING + + await process_instances() + + shim_client_mock.get_runner_info.assert_called_once() + shim_client_mock.install_runner.assert_not_called()