From 9afdf2dc30fd9980cb5818812064db503c00fee7 Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Wed, 17 Dec 2025 09:51:30 +0000 Subject: [PATCH] Implement shim auto-update shim binary is replaced at any time, but restart is postponed until all tasks are terminated, as safe restart with running tasks requires additional work (see _get_restart_safe_task_statuses() comment). Closes: https://github.com/dstackai/dstack/issues/3288 --- runner/cmd/shim/main.go | 27 +- runner/consts/consts.go | 3 + runner/docs/shim.openapi.yaml | 51 ++- runner/internal/shim/api/handlers.go | 57 ++- runner/internal/shim/api/handlers_test.go | 4 +- runner/internal/shim/api/schemas.go | 4 + runner/internal/shim/api/server.go | 36 +- runner/internal/shim/components/runner.go | 41 +- runner/internal/shim/components/shim.go | 61 +++ runner/internal/shim/components/types.go | 12 +- runner/internal/shim/components/utils.go | 29 ++ runner/internal/shim/models.go | 7 +- .../_internal/core/backends/base/compute.py | 82 +++- .../background/tasks/process_instances.py | 179 ++++++-- src/dstack/_internal/server/schemas/runner.py | 7 +- .../server/services/gateways/__init__.py | 2 +- .../server/services/runner/client.py | 158 +++++-- .../_internal/server/utils/provisioning.py | 15 +- src/dstack/_internal/settings.py | 6 + .../core/backends/base/test_compute.py | 7 +- .../tasks/test_process_instances.py | 423 ++++++++++++++---- .../server/services/runner/test_client.py | 91 +++- 22 files changed, 1043 insertions(+), 259 deletions(-) create mode 100644 runner/internal/shim/components/shim.go diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index af468a6a93..79aefbda6a 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -40,6 +40,11 @@ func mainInner() int { log.DefaultEntry.Logger.SetLevel(logrus.Level(defaultLogLevel)) log.DefaultEntry.Logger.SetOutput(os.Stderr) + shimBinaryPath, err := os.Executable() + if err != nil { + shimBinaryPath = consts.ShimBinaryPath + } + cmd := &cli.Command{ Name: "dstack-shim", Usage: "Starts dstack-runner or docker container.", @@ -54,6 +59,14 @@ func mainInner() int { DefaultText: path.Join("~", consts.DstackDirPath), Sources: cli.EnvVars("DSTACK_SHIM_HOME"), }, + &cli.StringFlag{ + Name: "shim-binary-path", + Usage: "Path to shim's binary", + Value: shimBinaryPath, + Destination: &args.Shim.BinaryPath, + TakesFile: true, + Sources: cli.EnvVars("DSTACK_SHIM_BINARY_PATH"), + }, &cli.IntFlag{ Name: "shim-http-port", Usage: "Set shim's http port", @@ -172,6 +185,7 @@ func mainInner() int { func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) { log.DefaultEntry.Logger.SetLevel(logrus.Level(args.Shim.LogLevel)) + log.Info(ctx, "Starting dstack-shim", "version", Version) shimHomeDir := args.Shim.HomeDir if shimHomeDir == "" { @@ -211,6 +225,10 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) } else if runnerErr != nil { return runnerErr } + shimManager, shimErr := components.NewShimManager(ctx, args.Shim.BinaryPath) + if shimErr != nil { + return shimErr + } log.Debug(ctx, "Shim", "args", args.Shim) log.Debug(ctx, "Runner", "args", args.Runner) @@ -259,7 +277,11 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) } address := fmt.Sprintf("localhost:%d", args.Shim.HTTPPort) - shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper, runnerManager) + shimServer := api.NewShimServer( + ctx, address, Version, + dockerRunner, dcgmExporter, dcgmWrapper, + runnerManager, shimManager, + ) if serviceMode { if err := shim.WriteHostInfo(shimHomeDir, dockerRunner.Resources(ctx)); err != nil { @@ -278,6 +300,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) if err := shimServer.Serve(); err != nil { serveErrCh <- err } + close(serveErrCh) }() select { @@ -287,7 +310,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second) defer cancelShutdown() - shutdownErr := shimServer.Shutdown(shutdownCtx) + shutdownErr := shimServer.Shutdown(shutdownCtx, false) if serveErr != nil { return serveErr } diff --git a/runner/consts/consts.go b/runner/consts/consts.go index aa0b8d056f..2c392b5ee4 100644 --- a/runner/consts/consts.go +++ b/runner/consts/consts.go @@ -13,6 +13,9 @@ const ( // 2. A default path on the host unless overridden via shim CLI const RunnerBinaryPath = "/usr/local/bin/dstack-runner" +// A fallback path on the host used if os.Executable() has failed +const ShimBinaryPath = "/usr/local/bin/dstack-shim" + // Error-containing messages will be identified by this signature const ExecutorFailedSignature = "Executor failed" diff --git a/runner/docs/shim.openapi.yaml b/runner/docs/shim.openapi.yaml index e6f49fa079..e375e4e9d3 100644 --- a/runner/docs/shim.openapi.yaml +++ b/runner/docs/shim.openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.1.2 info: title: dstack-shim API - version: v2/0.19.41 + version: v2/0.20.1 x-logo: url: https://avatars.githubusercontent.com/u/54146142?s=260 description: > @@ -41,7 +41,7 @@ paths: **Important**: Since this endpoint is used for negotiation, it should always stay backward/future compatible, specifically the `version` field - + tags: [shim] responses: "200": description: "" @@ -50,6 +50,29 @@ paths: schema: $ref: "#/components/schemas/HealthcheckResponse" + /shutdown: + post: + summary: Request shim shutdown + description: | + (since [0.20.1](https://github.com/dstackai/dstack/releases/tag/0.20.1)) Request shim to shut down itself. + Restart must be handled by an external process supervisor, e.g., `systemd`. + + **Note**: background jobs (e.g., component installation) are canceled regardless of the `force` option. + tags: [shim] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ShutdownRequest" + responses: + "200": + description: Request accepted + $ref: "#/components/responses/PlainTextOk" + "400": + description: Malformed JSON body or validation error + $ref: "#/components/responses/PlainTextBadRequest" + /instance/health: get: summary: Get instance health @@ -66,7 +89,7 @@ paths: /components: get: summary: Get components - description: (since [0.19.41](https://github.com/dstackai/dstack/releases/tag/0.19.41)) Returns a list of software components (e.g., `dstack-runner`) + description: (since [0.20.0](https://github.com/dstackai/dstack/releases/tag/0.20.0)) Returns a list of software components (e.g., `dstack-runner`) tags: [Components] responses: "200": @@ -80,7 +103,7 @@ paths: post: summary: Install component description: > - (since [0.19.41](https://github.com/dstackai/dstack/releases/tag/0.19.41)) Request installing/updating the software component. + (since [0.20.0](https://github.com/dstackai/dstack/releases/tag/0.20.0)) Request installing/updating the software component. Components are installed asynchronously tags: [Components] requestBody: @@ -410,6 +433,10 @@ components: type: string enum: - dstack-runner + - dstack-shim + description: | + * (since [0.20.0](https://github.com/dstackai/dstack/releases/tag/0.20.0)) `dstack-runner` + * (since [0.20.1](https://github.com/dstackai/dstack/releases/tag/0.20.1)) `dstack-shim` ComponentStatus: title: shim.components.ComponentStatus @@ -430,7 +457,7 @@ components: type: string description: An empty string if status != installed examples: - - 0.19.41 + - 0.20.1 status: allOf: - $ref: "#/components/schemas/ComponentStatus" @@ -457,6 +484,18 @@ components: - version additionalProperties: false + ShutdownRequest: + title: shim.api.ShutdownRequest + type: object + properties: + force: + type: boolean + examples: + - false + description: If `true`, don't wait for background job coroutines to complete after canceling them and close HTTP server forcefully. + required: + - force + InstanceHealthResponse: title: shim.api.InstanceHealthResponse type: object @@ -486,7 +525,7 @@ components: url: type: string examples: - - https://dstack-runner-downloads.s3.eu-west-1.amazonaws.com/0.19.41/binaries/dstack-runner-linux-amd64 + - https://dstack-runner-downloads.s3.eu-west-1.amazonaws.com/0.20.1/binaries/dstack-runner-linux-amd64 required: - name - url diff --git a/runner/internal/shim/api/handlers.go b/runner/internal/shim/api/handlers.go index 7e4f172272..dc1be824cb 100644 --- a/runner/internal/shim/api/handlers.go +++ b/runner/internal/shim/api/handlers.go @@ -22,6 +22,21 @@ func (s *ShimServer) HealthcheckHandler(w http.ResponseWriter, r *http.Request) }, nil } +func (s *ShimServer) ShutdownHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { + var req ShutdownRequest + if err := api.DecodeJSONBody(w, r, &req, true); err != nil { + return nil, err + } + + go func() { + if err := s.Shutdown(s.ctx, req.Force); err != nil { + log.Error(s.ctx, "Shutdown", "err", err) + } + }() + + return nil, nil +} + func (s *ShimServer) InstanceHealthHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { ctx := r.Context() response := InstanceHealthResponse{} @@ -159,9 +174,11 @@ func (s *ShimServer) TaskMetricsHandler(w http.ResponseWriter, r *http.Request) } func (s *ShimServer) ComponentListHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { - runnerStatus := s.runnerManager.GetInfo(r.Context()) response := &ComponentListResponse{ - Components: []components.ComponentInfo{runnerStatus}, + Components: []components.ComponentInfo{ + s.runnerManager.GetInfo(r.Context()), + s.shimManager.GetInfo(r.Context()), + }, } return response, nil } @@ -176,27 +193,31 @@ func (s *ShimServer) ComponentInstallHandler(w http.ResponseWriter, r *http.Requ return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty name"} } + var componentManager components.ComponentManager switch components.ComponentName(req.Name) { case components.ComponentNameRunner: - if req.URL == "" { - return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty url"} - } - - // There is still a small chance of time-of-check race condition, but we ignore it. - runnerInfo := s.runnerManager.GetInfo(r.Context()) - if runnerInfo.Status == components.ComponentStatusInstalling { - return nil, &api.Error{Status: http.StatusConflict, Msg: "already installing"} - } - - s.bgJobsGroup.Go(func() { - if err := s.runnerManager.Install(s.bgJobsCtx, req.URL, true); err != nil { - log.Error(s.bgJobsCtx, "runner background install", "err", err) - } - }) - + componentManager = s.runnerManager + case components.ComponentNameShim: + componentManager = s.shimManager default: return nil, &api.Error{Status: http.StatusBadRequest, Msg: "unknown component"} } + if req.URL == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty url"} + } + + // There is still a small chance of time-of-check race condition, but we ignore it. + componentInfo := componentManager.GetInfo(r.Context()) + if componentInfo.Status == components.ComponentStatusInstalling { + return nil, &api.Error{Status: http.StatusConflict, Msg: "already installing"} + } + + s.bgJobsGroup.Go(func() { + if err := componentManager.Install(s.bgJobsCtx, req.URL, true); err != nil { + log.Error(s.bgJobsCtx, "component background install", "name", componentInfo.Name, "err", err) + } + }) + return nil, nil } diff --git a/runner/internal/shim/api/handlers_test.go b/runner/internal/shim/api/handlers_test.go index c04621eb0a..9bc829a94c 100644 --- a/runner/internal/shim/api/handlers_test.go +++ b/runner/internal/shim/api/handlers_test.go @@ -13,7 +13,7 @@ func TestHealthcheck(t *testing.T) { request := httptest.NewRequest("GET", "/api/healthcheck", nil) responseRecorder := httptest.NewRecorder() - server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil) + server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil) f := common.JSONResponseHandler(server.HealthcheckHandler) f(responseRecorder, request) @@ -30,7 +30,7 @@ func TestHealthcheck(t *testing.T) { } func TestTaskSubmit(t *testing.T) { - server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil) + server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil) requestBody := `{ "id": "dummy-id", "name": "dummy-name", diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index a7d5fa7d48..cd0db6a202 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -11,6 +11,10 @@ type HealthcheckResponse struct { Version string `json:"version"` } +type ShutdownRequest struct { + Force bool `json:"force"` +} + type InstanceHealthResponse struct { DCGM *dcgm.Health `json:"dcgm"` } diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index 15e0191354..0482db7945 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/dstackai/dstack/runner/internal/api" + "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/shim" "github.com/dstackai/dstack/runner/internal/shim/components" "github.com/dstackai/dstack/runner/internal/shim/dcgm" @@ -26,8 +27,11 @@ type TaskRunner interface { } type ShimServer struct { - httpServer *http.Server - mu sync.RWMutex + httpServer *http.Server + mu sync.RWMutex + ctx context.Context + inShutdown bool + inForceShutdown bool bgJobsCtx context.Context bgJobsCancel context.CancelFunc @@ -38,7 +42,8 @@ type ShimServer struct { dcgmExporter *dcgm.DCGMExporter dcgmWrapper dcgm.DCGMWrapperInterface // interface with nil value normalized to plain nil - runnerManager *components.RunnerManager + runnerManager components.ComponentManager + shimManager components.ComponentManager version string } @@ -46,7 +51,7 @@ type ShimServer struct { func NewShimServer( ctx context.Context, address string, version string, runner TaskRunner, dcgmExporter *dcgm.DCGMExporter, dcgmWrapper dcgm.DCGMWrapperInterface, - runnerManager *components.RunnerManager, + runnerManager components.ComponentManager, shimManager components.ComponentManager, ) *ShimServer { bgJobsCtx, bgJobsCancel := context.WithCancel(ctx) if dcgmWrapper != nil && reflect.ValueOf(dcgmWrapper).IsNil() { @@ -59,6 +64,7 @@ func NewShimServer( Handler: r, BaseContext: func(l net.Listener) context.Context { return ctx }, }, + ctx: ctx, bgJobsCtx: bgJobsCtx, bgJobsCancel: bgJobsCancel, @@ -70,12 +76,14 @@ func NewShimServer( dcgmWrapper: dcgmWrapper, runnerManager: runnerManager, + shimManager: shimManager, version: version, } // The healthcheck endpoint should stay backward compatible, as it is used for negotiation r.AddHandler("GET", "/api/healthcheck", s.HealthcheckHandler) + r.AddHandler("POST", "/api/shutdown", s.ShutdownHandler) r.AddHandler("GET", "/api/instance/health", s.InstanceHealthHandler) r.AddHandler("GET", "/api/components", s.ComponentListHandler) r.AddHandler("POST", "/api/components/install", s.ComponentInstallHandler) @@ -96,8 +104,26 @@ func (s *ShimServer) Serve() error { return nil } -func (s *ShimServer) Shutdown(ctx context.Context) error { +func (s *ShimServer) Shutdown(ctx context.Context, force bool) error { + s.mu.Lock() + + if s.inForceShutdown || s.inShutdown && !force { + log.Info(ctx, "Already shutting down, ignoring request") + s.mu.Unlock() + return nil + } + + s.inShutdown = true + if force { + s.inForceShutdown = true + } + s.mu.Unlock() + + log.Info(ctx, "Shutting down", "force", force) s.bgJobsCancel() + if force { + return s.httpServer.Close() + } err := s.httpServer.Shutdown(ctx) s.bgJobsGroup.Wait() return err diff --git a/runner/internal/shim/components/runner.go b/runner/internal/shim/components/runner.go index b18f51d3c3..3dc361a251 100644 --- a/runner/internal/shim/components/runner.go +++ b/runner/internal/shim/components/runner.go @@ -2,13 +2,8 @@ package components import ( "context" - "errors" "fmt" - "os/exec" - "strings" "sync" - - "github.com/dstackai/dstack/runner/internal/common" ) type RunnerManager struct { @@ -42,7 +37,7 @@ func (m *RunnerManager) Install(ctx context.Context, url string, force bool) err m.mu.Lock() if m.status == ComponentStatusInstalling { m.mu.Unlock() - return errors.New("install runner: already installing") + return fmt.Errorf("install %s: already installing", ComponentNameRunner) } m.status = ComponentStatusInstalling m.version = "" @@ -57,38 +52,10 @@ func (m *RunnerManager) Install(ctx context.Context, url string, force bool) err return checkErr } -func (m *RunnerManager) check(ctx context.Context) error { +func (m *RunnerManager) check(ctx context.Context) (err error) { m.mu.Lock() defer m.mu.Unlock() - exists, err := common.PathExists(m.path) - if err != nil { - m.status = ComponentStatusError - m.version = "" - return fmt.Errorf("check runner: %w", err) - } - if !exists { - m.status = ComponentStatusNotInstalled - m.version = "" - return nil - } - - cmd := exec.CommandContext(ctx, m.path, "--version") - output, err := cmd.Output() - if err != nil { - m.status = ComponentStatusError - m.version = "" - return fmt.Errorf("check runner: %w", err) - } - - rawVersion := string(output) // dstack-runner version 0.19.38 - versionFields := strings.Fields(rawVersion) - if len(versionFields) != 3 { - m.status = ComponentStatusError - m.version = "" - return fmt.Errorf("check runner: unexpected version output: %s", rawVersion) - } - m.status = ComponentStatusInstalled - m.version = versionFields[2] - return nil + m.status, m.version, err = checkDstackComponent(ctx, ComponentNameRunner, m.path) + return err } diff --git a/runner/internal/shim/components/shim.go b/runner/internal/shim/components/shim.go new file mode 100644 index 0000000000..5ac9b08d39 --- /dev/null +++ b/runner/internal/shim/components/shim.go @@ -0,0 +1,61 @@ +package components + +import ( + "context" + "fmt" + "sync" +) + +type ShimManager struct { + path string + version string + status ComponentStatus + + mu *sync.RWMutex +} + +func NewShimManager(ctx context.Context, pth string) (*ShimManager, error) { + m := ShimManager{ + path: pth, + mu: &sync.RWMutex{}, + } + err := m.check(ctx) + return &m, err +} + +func (m *ShimManager) GetInfo(ctx context.Context) ComponentInfo { + m.mu.RLock() + defer m.mu.RUnlock() + return ComponentInfo{ + Name: ComponentNameShim, + Version: m.version, + Status: m.status, + } +} + +func (m *ShimManager) Install(ctx context.Context, url string, force bool) error { + m.mu.Lock() + if m.status == ComponentStatusInstalling { + m.mu.Unlock() + return fmt.Errorf("install %s: already installing", ComponentNameShim) + } + m.status = ComponentStatusInstalling + m.version = "" + m.mu.Unlock() + + downloadErr := downloadFile(ctx, url, m.path, 0o755, force) + // Recheck the binary even if the download has failed, just in case. + checkErr := m.check(ctx) + if downloadErr != nil { + return downloadErr + } + return checkErr +} + +func (m *ShimManager) check(ctx context.Context) (err error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.status, m.version, err = checkDstackComponent(ctx, ComponentNameShim, m.path) + return err +} diff --git a/runner/internal/shim/components/types.go b/runner/internal/shim/components/types.go index 13d1af857e..57c205af53 100644 --- a/runner/internal/shim/components/types.go +++ b/runner/internal/shim/components/types.go @@ -1,8 +1,13 @@ package components +import "context" + type ComponentName string -const ComponentNameRunner ComponentName = "dstack-runner" +const ( + ComponentNameRunner ComponentName = "dstack-runner" + ComponentNameShim ComponentName = "dstack-shim" +) type ComponentStatus string @@ -18,3 +23,8 @@ type ComponentInfo struct { Version string `json:"version"` Status ComponentStatus `json:"status"` } + +type ComponentManager interface { + GetInfo(ctx context.Context) ComponentInfo + Install(ctx context.Context, url string, force bool) error +} diff --git a/runner/internal/shim/components/utils.go b/runner/internal/shim/components/utils.go index 9161a64499..073832133d 100644 --- a/runner/internal/shim/components/utils.go +++ b/runner/internal/shim/components/utils.go @@ -7,9 +7,12 @@ import ( "io" "net/http" "os" + "os/exec" "path/filepath" + "strings" "time" + "github.com/dstackai/dstack/runner/internal/common" "github.com/dstackai/dstack/runner/internal/log" ) @@ -85,3 +88,29 @@ func downloadFile(ctx context.Context, url string, path string, mode os.FileMode return nil } + +func checkDstackComponent(ctx context.Context, name ComponentName, pth string) (status ComponentStatus, version string, err error) { + exists, err := common.PathExists(pth) + if err != nil { + return ComponentStatusError, "", fmt.Errorf("check %s: %w", name, err) + } + if !exists { + return ComponentStatusNotInstalled, "", nil + } + + cmd := exec.CommandContext(ctx, pth, "--version") + output, err := cmd.Output() + if err != nil { + return ComponentStatusError, "", fmt.Errorf("check %s: %w", name, err) + } + + rawVersion := string(output) // dstack-{shim,runner} version 0.19.38 + versionFields := strings.Fields(rawVersion) + if len(versionFields) != 3 { + return ComponentStatusError, "", fmt.Errorf("check %s: unexpected version output: %s", name, rawVersion) + } + if versionFields[0] != string(name) { + return ComponentStatusError, "", fmt.Errorf("check %s: unexpected component name: %s", name, versionFields[0]) + } + return ComponentStatusInstalled, versionFields[2], nil +} diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index b8da12670d..0a0c697eec 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -15,9 +15,10 @@ type DockerParameters interface { type CLIArgs struct { Shim struct { - HTTPPort int - HomeDir string - LogLevel int + HTTPPort int + HomeDir string + BinaryPath string + LogLevel int } Runner struct { diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index a0ff70c1ba..802aecb654 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -51,6 +51,7 @@ logger = get_logger(__name__) DSTACK_SHIM_BINARY_NAME = "dstack-shim" +DSTACK_SHIM_RESTART_INTERVAL_SECONDS = 3 DSTACK_RUNNER_BINARY_NAME = "dstack-runner" DEFAULT_PRIVATE_SUBNETS = ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16") NVIDIA_GPUS_REQUIRING_PROPRIETARY_KERNEL_MODULES = frozenset( @@ -758,13 +759,35 @@ def get_shim_commands( return commands -def get_dstack_runner_version() -> str: - if settings.DSTACK_VERSION is not None: - return settings.DSTACK_VERSION - version = os.environ.get("DSTACK_RUNNER_VERSION", None) - if version is None and settings.DSTACK_USE_LATEST_FROM_BRANCH: - version = get_latest_runner_build() - return version or "latest" +def get_dstack_runner_version() -> Optional[str]: + if version := settings.DSTACK_VERSION: + return version + if version := settings.DSTACK_RUNNER_VERSION: + return version + if version_url := settings.DSTACK_RUNNER_VERSION_URL: + return _fetch_version(version_url) + if settings.DSTACK_USE_LATEST_FROM_BRANCH: + return get_latest_runner_build() + return None + + +def get_dstack_shim_version() -> Optional[str]: + if version := settings.DSTACK_VERSION: + return version + if version := settings.DSTACK_SHIM_VERSION: + return version + if version := settings.DSTACK_RUNNER_VERSION: + logger.warning( + "DSTACK_SHIM_VERSION is not set, using DSTACK_RUNNER_VERSION." + " Future versions will not fall back to DSTACK_RUNNER_VERSION." + " Set DSTACK_SHIM_VERSION to supress this warning." + ) + return version + if version_url := settings.DSTACK_SHIM_VERSION_URL: + return _fetch_version(version_url) + if settings.DSTACK_USE_LATEST_FROM_BRANCH: + return get_latest_runner_build() + return None def normalize_arch(arch: Optional[str] = None) -> GoArchType: @@ -789,7 +812,7 @@ def normalize_arch(arch: Optional[str] = None) -> GoArchType: def get_dstack_runner_download_url( arch: Optional[str] = None, version: Optional[str] = None ) -> str: - url_template = os.environ.get("DSTACK_RUNNER_DOWNLOAD_URL") + url_template = settings.DSTACK_RUNNER_DOWNLOAD_URL if not url_template: if settings.DSTACK_VERSION is not None: bucket = "dstack-runner-downloads" @@ -800,12 +823,12 @@ def get_dstack_runner_download_url( "/{version}/binaries/dstack-runner-linux-{arch}" ) if version is None: - version = get_dstack_runner_version() - return url_template.format(version=version, arch=normalize_arch(arch).value) + version = get_dstack_runner_version() or "latest" + return _format_download_url(url_template, version, arch) -def get_dstack_shim_download_url(arch: Optional[str] = None) -> str: - url_template = os.environ.get("DSTACK_SHIM_DOWNLOAD_URL") +def get_dstack_shim_download_url(arch: Optional[str] = None, version: Optional[str] = None) -> str: + url_template = settings.DSTACK_SHIM_DOWNLOAD_URL if not url_template: if settings.DSTACK_VERSION is not None: bucket = "dstack-runner-downloads" @@ -815,8 +838,9 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str: f"https://{bucket}.s3.eu-west-1.amazonaws.com" "/{version}/binaries/dstack-shim-linux-{arch}" ) - version = get_dstack_runner_version() - return url_template.format(version=version, arch=normalize_arch(arch).value) + if version is None: + version = get_dstack_shim_version() or "latest" + return _format_download_url(url_template, version, arch) def get_setup_cloud_instance_commands( @@ -878,8 +902,16 @@ def get_run_shim_script( dstack_shim_binary_path = get_dstack_shim_binary_path(bin_path) privileged_flag = "--privileged" if is_privileged else "" pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else "" + # TODO: Use a proper process supervisor? return [ - f"nohup {dstack_shim_binary_path} {privileged_flag} {pjrt_device_env} &", + f""" + nohup sh -c ' + while true; do + {dstack_shim_binary_path} {privileged_flag} {pjrt_device_env} + sleep {DSTACK_SHIM_RESTART_INTERVAL_SECONDS} + done + ' & + """, ] @@ -1022,9 +1054,7 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non channel = "release" if settings.DSTACK_RELEASE else "stgn" base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}" if build == "latest": - r = requests.get(f"{base_url}/latest-version", timeout=5) - r.raise_for_status() - build = r.text.strip() + build = _fetch_version(f"{base_url}/latest-version") or "latest" logger.debug("Found the latest gateway build: %s", build) wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" # Build package spec with extras if router is specified @@ -1034,7 +1064,7 @@ def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = Non def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]: - build = get_dstack_runner_version() + build = get_dstack_runner_version() or "latest" gateway_package = get_dstack_gateway_wheel(build, router) return [ "mkdir -p /home/ubuntu/dstack", @@ -1069,3 +1099,17 @@ def requires_nvidia_proprietary_kernel_modules(gpu_name: str) -> bool: instead of open kernel modules. """ return gpu_name.lower() in NVIDIA_GPUS_REQUIRING_PROPRIETARY_KERNEL_MODULES + + +def _fetch_version(url: str) -> Optional[str]: + r = requests.get(url, timeout=5) + r.raise_for_status() + version = r.text.strip() + if not version: + logger.warning("Empty version response from URL: %s", url) + return None + return version + + +def _format_download_url(template: str, version: str, arch: Optional[str]) -> str: + return template.format(version=version, arch=normalize_arch(arch).value) diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 30ed2b1ec3..7d54171765 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -4,6 +4,7 @@ from datetime import timedelta from typing import Any, Dict, Optional, cast +import gpuhunt import requests from paramiko.pkey import PKey from paramiko.ssh_exception import PasswordRequiredException @@ -21,6 +22,8 @@ get_dstack_runner_download_url, get_dstack_runner_version, get_dstack_shim_binary_path, + get_dstack_shim_download_url, + get_dstack_shim_version, get_dstack_working_dir, get_shim_env, get_shim_pre_start_commands, @@ -65,6 +68,7 @@ ) from dstack._internal.server.schemas.instances import InstanceCheck from dstack._internal.server.schemas.runner import ( + ComponentInfo, ComponentStatus, HealthcheckResponse, InstanceHealthResponse, @@ -122,7 +126,6 @@ from dstack._internal.utils.ssh import ( pkey_from_str, ) -from dstack._internal.utils.version import parse_version MIN_PROCESSING_INTERVAL = timedelta(seconds=10) @@ -918,76 +921,170 @@ def _check_instance_inner( logger.exception(template, *args) return InstanceCheck(reachable=False, message=template % args) - _maybe_update_runner(instance, shim_client) - try: remove_dangling_tasks_from_instance(shim_client, instance) except Exception as e: logger.exception("%s: error removing dangling tasks: %s", fmt(instance), e) + # There should be no shim API calls after this function call since it can request shim restart. + _maybe_install_components(instance, shim_client) + return runner_client.healthcheck_response_to_instance_check( healthcheck_response, instance_health_response ) -def _maybe_update_runner(instance: InstanceModel, shim_client: runner_client.ShimClient) -> None: - # To auto-update to the latest runner dev build from the CI, see DSTACK_USE_LATEST_FROM_BRANCH. - expected_version_str = get_dstack_runner_version() +def _maybe_install_components( + instance: InstanceModel, shim_client: runner_client.ShimClient +) -> None: try: - expected_version = parse_version(expected_version_str) - except ValueError as e: - logger.warning("Failed to parse expected runner version: %s", e) + components = shim_client.get_components() + except requests.RequestException as e: + logger.warning("Instance %s: shim.get_components(): request error: %s", instance.name, e) return - if expected_version is None: - logger.debug("Cannot determine the expected runner version") + if components is None: + logger.debug("Instance %s: no components info", instance.name) return - try: - runner_info = shim_client.get_runner_info() - except requests.RequestException as e: - logger.warning("Instance %s: shim.get_runner_info(): request error: %s", instance.name, e) - return - if runner_info is None: + installed_shim_version: Optional[str] = None + installation_requested = False + + if (runner_info := components.runner) is not None: + installation_requested |= _maybe_install_runner(instance, shim_client, runner_info) + else: logger.debug("Instance %s: no runner info", instance.name) + + if (shim_info := components.shim) is not None: + if shim_info.status == ComponentStatus.INSTALLED: + installed_shim_version = shim_info.version + installation_requested |= _maybe_install_shim(instance, shim_client, shim_info) + else: + logger.debug("Instance %s: no shim info", instance.name) + + running_shim_version = shim_client.get_version_string() + if ( + # old shim without `dstack-shim` component and `/api/shutdown` support + installed_shim_version is None + # or the same version is already running + or installed_shim_version == running_shim_version + # or we just requested installation of at least one component + or installation_requested + # or at least one component is already being installed + or any(c.status == ComponentStatus.INSTALLING for c in components) + # or at least one shim task won't survive restart + or not shim_client.is_safe_to_restart() + ): return + if shim_client.shutdown(force=False): + logger.debug( + "Instance %s: restarting shim %s -> %s", + instance.name, + running_shim_version, + installed_shim_version, + ) + else: + logger.debug("Instance %s: cannot restart shim", instance.name) + + +def _maybe_install_runner( + instance: InstanceModel, shim_client: runner_client.ShimClient, runner_info: ComponentInfo +) -> bool: + # For developers: + # * To install the latest dev build for the current branch from the CI, + # set DSTACK_USE_LATEST_FROM_BRANCH=1. + # * To provide your own build, set DSTACK_RUNNER_VERSION_URL and DSTACK_RUNNER_DOWNLOAD_URL. + expected_version = get_dstack_runner_version() + if expected_version is None: + logger.debug("Cannot determine the expected runner version") + return False + + installed_version = runner_info.version logger.debug( - "Instance %s: runner status=%s version=%s", + "Instance %s: runner status=%s installed_version=%s", instance.name, runner_info.status.value, - runner_info.version, + installed_version or "(no version)", ) - if runner_info.status == ComponentStatus.INSTALLING: - return - if runner_info.version: - try: - current_version = parse_version(runner_info.version) - except ValueError as e: - logger.warning("Instance %s: failed to parse runner version: %s", instance.name, e) - return - - if current_version is None or current_version >= expected_version: - logger.debug("Instance %s: the latest runner version already installed", instance.name) - return + if runner_info.status == ComponentStatus.INSTALLING: + logger.debug("Instance %s: runner is already being installed", instance.name) + return False - logger.debug( - "Instance %s: updating runner %s -> %s", - instance.name, - current_version, - expected_version, - ) - else: - logger.debug("Instance %s: installing runner %s", instance.name, expected_version) + if installed_version and installed_version == expected_version: + logger.debug("Instance %s: expected runner version already installed", instance.name) + return False - job_provisioning_data = get_or_error(get_instance_provisioning_data(instance)) url = get_dstack_runner_download_url( - arch=job_provisioning_data.instance_type.resources.cpu_arch, version=expected_version_str + arch=_get_instance_cpu_arch(instance), version=expected_version + ) + logger.debug( + "Instance %s: installing runner %s -> %s from %s", + instance.name, + installed_version or "(no version)", + expected_version, + url, ) try: shim_client.install_runner(url) + return True except requests.RequestException as e: logger.warning("Instance %s: shim.install_runner(): %s", instance.name, e) + return False + + +def _maybe_install_shim( + instance: InstanceModel, shim_client: runner_client.ShimClient, shim_info: ComponentInfo +) -> bool: + # For developers: + # * To install the latest dev build for the current branch from the CI, + # set DSTACK_USE_LATEST_FROM_BRANCH=1. + # * To provide your own build, set DSTACK_SHIM_VERSION_URL and DSTACK_SHIM_DOWNLOAD_URL. + expected_version = get_dstack_shim_version() + if expected_version is None: + logger.debug("Cannot determine the expected shim version") + return False + + installed_version = shim_info.version + logger.debug( + "Instance %s: shim status=%s installed_version=%s running_version=%s", + instance.name, + shim_info.status.value, + installed_version or "(no version)", + shim_client.get_version_string(), + ) + + if shim_info.status == ComponentStatus.INSTALLING: + logger.debug("Instance %s: shim is already being installed", instance.name) + return False + + if installed_version and installed_version == expected_version: + logger.debug("Instance %s: expected shim version already installed", instance.name) + return False + + url = get_dstack_shim_download_url( + arch=_get_instance_cpu_arch(instance), version=expected_version + ) + logger.debug( + "Instance %s: installing shim %s -> %s from %s", + instance.name, + installed_version or "(no version)", + expected_version, + url, + ) + try: + shim_client.install_shim(url) + return True + except requests.RequestException as e: + logger.warning("Instance %s: shim.install_shim(): %s", instance.name, e) + return False + + +def _get_instance_cpu_arch(instance: InstanceModel) -> Optional[gpuhunt.CPUArchitecture]: + jpd = get_instance_provisioning_data(instance) + if jpd is None: + return None + return jpd.instance_type.resources.cpu_arch async def _terminate(instance: InstanceModel) -> None: diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index f3c3614b58..12ff6c6825 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -121,8 +121,13 @@ class InstanceHealthResponse(CoreModel): dcgm: Optional[DCGMHealthResponse] = None +class ShutdownRequest(CoreModel): + force: bool + + class ComponentName(str, Enum): RUNNER = "dstack-runner" + SHIM = "dstack-shim" class ComponentStatus(str, Enum): @@ -133,7 +138,7 @@ class ComponentStatus(str, Enum): class ComponentInfo(CoreModel): - name: ComponentName + name: str # Not using ComponentName enum for compatibility of newer shim with older server version: str status: ComponentStatus diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 682feaf31b..4ab80a8331 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -412,7 +412,7 @@ async def init_gateways(session: AsyncSession): if settings.SKIP_GATEWAY_UPDATE: logger.debug("Skipping gateways update due to DSTACK_SKIP_GATEWAY_UPDATE env variable") else: - build = get_dstack_runner_version() + build = get_dstack_runner_version() or "latest" for gateway_compute, res in await gather_map_async( gateway_computes, diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index b270d4ea5f..c83a42b744 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -1,10 +1,12 @@ import uuid +from collections.abc import Generator from http import HTTPStatus from typing import BinaryIO, Dict, List, Literal, Optional, TypeVar, Union, overload import packaging.version import requests import requests.exceptions +from typing_extensions import Self from dstack._internal.core.errors import DstackError from dstack._internal.core.models.common import CoreModel, NetworkMode @@ -28,9 +30,11 @@ MetricsResponse, PullResponse, ShimVolumeInfo, + ShutdownRequest, SubmitBody, TaskInfoResponse, TaskListResponse, + TaskStatus, TaskSubmitRequest, TaskTerminateRequest, ) @@ -143,7 +147,7 @@ class ShimError(DstackError): pass -class ShimHTTPError(DstackError): +class ShimHTTPError(ShimError): """ An HTTP error wrapper for `requests.exceptions.HTTPError`. Should be used as follows: @@ -185,6 +189,47 @@ class ShimAPIVersionError(ShimError): pass +class ComponentList: + _items: dict[ComponentName, ComponentInfo] + + def __init__(self) -> None: + self._items = {} + + def __iter__(self) -> Generator[ComponentInfo, None, None]: + for component_info in self._items.values(): + yield component_info + + @classmethod + def from_response(cls, response: ComponentListResponse) -> Self: + components = cls() + for component_info in response.components: + try: + components.add(component_info) + except ValueError as e: + logger.warning("Error processing ComponentInfo: %s", e) + return components + + @property + def runner(self) -> Optional[ComponentInfo]: + return self.get(ComponentName.RUNNER) + + @property + def shim(self) -> Optional[ComponentInfo]: + return self.get(ComponentName.SHIM) + + def get(self, name: ComponentName) -> Optional[ComponentInfo]: + return self._items.get(name) + + def add(self, component_info: ComponentInfo) -> None: + try: + name = ComponentName(component_info.name) + except ValueError as e: + raise ValueError(f"Unknown component: {component_info.name}") from e + if name in self._items: + raise ValueError(f"Duplicate component: {component_info.name}") + self._items[name] = component_info + + class ShimClient: # API v2 (a.k.a. Future API) — `/api/tasks/[:id[/{terminate,remove}]]` # API v1 (a.k.a. Legacy API) — `/api/{submit,pull,stop}` @@ -194,14 +239,16 @@ class ShimClient: _INSTANCE_HEALTH_MIN_SHIM_VERSION = (0, 19, 22) # `/api/components` - _COMPONENTS_RUNNER_MIN_SHIM_VERSION = (0, 19, 41) + _COMPONENTS_MIN_SHIM_VERSION = (0, 20, 0) + + # `/api/shutdown` + _SHUTDOWN_MIN_SHIM_VERSION = (0, 20, 1) - _shim_version: Optional["_Version"] + _shim_version_string: str + _shim_version_tuple: Optional["_Version"] _api_version: int _negotiated: bool = False - _components: Optional[dict[ComponentName, ComponentInfo]] = None - def __init__( self, port: int, @@ -212,6 +259,16 @@ def __init__( # Methods shared by all API versions + def get_version_string(self) -> str: + if not self._negotiated: + self._negotiate() + return self._shim_version_string + + def get_version_tuple(self) -> Optional["_Version"]: + if not self._negotiated: + self._negotiate() + return self._shim_version_tuple + def is_api_v2_supported(self) -> bool: if not self._negotiated: self._negotiate() @@ -221,16 +278,24 @@ def is_instance_health_supported(self) -> bool: if not self._negotiated: self._negotiate() return ( - self._shim_version is None - or self._shim_version >= self._INSTANCE_HEALTH_MIN_SHIM_VERSION + self._shim_version_tuple is None + or self._shim_version_tuple >= self._INSTANCE_HEALTH_MIN_SHIM_VERSION ) - def is_runner_component_supported(self) -> bool: + def are_components_supported(self) -> bool: if not self._negotiated: self._negotiate() return ( - self._shim_version is None - or self._shim_version >= self._COMPONENTS_RUNNER_MIN_SHIM_VERSION + self._shim_version_tuple is None + or self._shim_version_tuple >= self._COMPONENTS_MIN_SHIM_VERSION + ) + + def is_shutdown_supported(self) -> bool: + if not self._negotiated: + self._negotiate() + return ( + self._shim_version_tuple is None + or self._shim_version_tuple >= self._SHUTDOWN_MIN_SHIM_VERSION ) @overload @@ -254,7 +319,7 @@ def healthcheck(self, unmask_exceptions: bool = False) -> Optional[HealthcheckRe def get_instance_health(self) -> Optional[InstanceHealthResponse]: if not self.is_instance_health_supported(): - logger.debug("instance health is not supported: %s", self._shim_version) + logger.debug("instance health is not supported: %s", self._shim_version_string) return None resp = self._request("GET", "/api/instance/health") if resp.status_code == HTTPStatus.NOT_FOUND: @@ -263,12 +328,37 @@ def get_instance_health(self) -> Optional[InstanceHealthResponse]: self._raise_for_status(resp) return self._response(InstanceHealthResponse, resp) - def get_runner_info(self) -> Optional[ComponentInfo]: - if not self.is_runner_component_supported(): - logger.debug("runner info is not supported: %s", self._shim_version) + def shutdown(self, *, force: bool) -> bool: + if not self.is_shutdown_supported(): + logger.debug("shim shutdown is not supported: %s", self._shim_version_string) + return False + body = ShutdownRequest(force=force) + resp = self._request("POST", "/api/shutdown", body) + # TODO: Remove this check after 0.20.1 release, use _request(..., raise_for_status=True) + if resp.status_code == HTTPStatus.NOT_FOUND and self._shim_version_tuple is None: + # Old dev build of shim + logger.debug("shim shutdown is not supported: %s", self._shim_version_string) + return False + self._raise_for_status(resp) + return True + + def is_safe_to_restart(self) -> bool: + if not self.is_api_v2_supported(): + # old shim, `/api/shutdown` is not supported anyway + return False + task_list = self.list_tasks() + if (tasks := task_list.tasks) is None: + # old shim, `/api/shutdown` is not supported anyway + return False + restart_safe_task_statuses = self._get_restart_safe_task_statuses() + return all(t.status in restart_safe_task_statuses for t in tasks) + + def get_components(self) -> Optional[ComponentList]: + if not self.are_components_supported(): + logger.debug("components are not supported: %s", self._shim_version_string) return None - components = self._get_components() - return components.get(ComponentName.RUNNER) + resp = self._request("GET", "/api/components", raise_for_status=True) + return ComponentList.from_response(self._response(ComponentListResponse, resp)) def install_runner(self, url: str) -> None: body = ComponentInstallRequest( @@ -277,6 +367,13 @@ def install_runner(self, url: str) -> None: ) self._request("POST", "/api/components/install", body, raise_for_status=True) + def install_shim(self, url: str) -> None: + body = ComponentInstallRequest( + name=ComponentName.SHIM, + url=url, + ) + self._request("POST", "/api/components/install", body, raise_for_status=True) + def list_tasks(self) -> TaskListResponse: if not self.is_api_v2_supported(): raise ShimAPIVersionError() @@ -459,30 +556,23 @@ def _raise_for_status(self, response: requests.Response) -> None: def _negotiate(self, healthcheck_response: Optional[requests.Response] = None) -> None: if healthcheck_response is None: healthcheck_response = self._request("GET", "/api/healthcheck", raise_for_status=True) - raw_version = self._response(HealthcheckResponse, healthcheck_response).version - version = _parse_version(raw_version) - if version is None or version >= self._API_V2_MIN_SHIM_VERSION: + version_string = self._response(HealthcheckResponse, healthcheck_response).version + version_tuple = _parse_version(version_string) + if version_tuple is None or version_tuple >= self._API_V2_MIN_SHIM_VERSION: api_version = 2 else: api_version = 1 - logger.debug( - "shim version: %s %s (API v%s)", - raw_version, - version or "(latest)", - api_version, - ) - self._shim_version = version + self._shim_version_string = version_string + self._shim_version_tuple = version_tuple self._api_version = api_version self._negotiated = True - def _get_components(self) -> dict[ComponentName, ComponentInfo]: - resp = self._request("GET", "/api/components") - # TODO: Remove this check after 0.19.41 release, use _request(..., raise_for_status=True) - if resp.status_code == HTTPStatus.NOT_FOUND and self._shim_version is None: - # Old dev build of shim - return {} - resp.raise_for_status() - return {c.name: c for c in self._response(ComponentListResponse, resp).components} + def _get_restart_safe_task_statuses(self) -> list[TaskStatus]: + # TODO: Rework shim's DockerRunner.Run() so that it does not wait for container termination + # (this at least requires replacing .waitContainer() with periodic polling of container + # statuses and moving some cleanup defer calls to .Terminate() and/or .Remove()) and add + # TaskStatus.RUNNING to the list of restart-safe task statuses for supported shim versions. + return [TaskStatus.TERMINATED] def healthcheck_response_to_instance_check( diff --git a/src/dstack/_internal/server/utils/provisioning.py b/src/dstack/_internal/server/utils/provisioning.py index 632dce777a..fcbe3bf086 100644 --- a/src/dstack/_internal/server/utils/provisioning.py +++ b/src/dstack/_internal/server/utils/provisioning.py @@ -8,7 +8,11 @@ import paramiko from gpuhunt import AcceleratorVendor, correct_gpu_memory_gib -from dstack._internal.core.backends.base.compute import GoArchType, normalize_arch +from dstack._internal.core.backends.base.compute import ( + DSTACK_SHIM_RESTART_INTERVAL_SECONDS, + GoArchType, + normalize_arch, +) from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT # FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute @@ -116,16 +120,23 @@ def run_pre_start_commands( def run_shim_as_systemd_service( client: paramiko.SSHClient, binary_path: str, working_dir: str, dev: bool ) -> None: + # Stop restart attempts after ≈ 1 hour + start_limit_interval_seconds = 3600 + start_limit_burst = int( + start_limit_interval_seconds / DSTACK_SHIM_RESTART_INTERVAL_SECONDS * 0.9 + ) shim_service = dedent(f"""\ [Unit] Description=dstack-shim After=network-online.target + StartLimitIntervalSec={start_limit_interval_seconds} + StartLimitBurst={start_limit_burst} [Service] Type=simple User=root Restart=always - RestartSec=10 + RestartSec={DSTACK_SHIM_RESTART_INTERVAL_SECONDS} WorkingDirectory={working_dir} EnvironmentFile={working_dir}/{DSTACK_SHIM_ENV_FILE} ExecStart={binary_path} diff --git a/src/dstack/_internal/settings.py b/src/dstack/_internal/settings.py index 245681411d..81682480a2 100644 --- a/src/dstack/_internal/settings.py +++ b/src/dstack/_internal/settings.py @@ -10,6 +10,12 @@ # TODO: update the code to treat 0.0.0 as dev version. DSTACK_VERSION = None DSTACK_RELEASE = os.getenv("DSTACK_RELEASE") is not None or version.__is_release__ +DSTACK_RUNNER_VERSION = os.getenv("DSTACK_RUNNER_VERSION") +DSTACK_RUNNER_VERSION_URL = os.getenv("DSTACK_RUNNER_VERSION_URL") +DSTACK_RUNNER_DOWNLOAD_URL = os.getenv("DSTACK_RUNNER_DOWNLOAD_URL") +DSTACK_SHIM_VERSION = os.getenv("DSTACK_SHIM_VERSION") +DSTACK_SHIM_VERSION_URL = os.getenv("DSTACK_SHIM_VERSION_URL") +DSTACK_SHIM_DOWNLOAD_URL = os.getenv("DSTACK_SHIM_DOWNLOAD_URL") DSTACK_USE_LATEST_FROM_BRANCH = os.getenv("DSTACK_USE_LATEST_FROM_BRANCH") is not None diff --git a/src/tests/_internal/core/backends/base/test_compute.py b/src/tests/_internal/core/backends/base/test_compute.py index 848aea822c..7892a3f0f5 100644 --- a/src/tests/_internal/core/backends/base/test_compute.py +++ b/src/tests/_internal/core/backends/base/test_compute.py @@ -1,6 +1,7 @@ import re from typing import Optional +import gpuhunt import pytest from dstack._internal.core.backends.base.compute import ( @@ -62,11 +63,13 @@ def test_validates_project_name(self): class TestNormalizeArch: - @pytest.mark.parametrize("arch", [None, "", "X86", "x86_64", "AMD64"]) + @pytest.mark.parametrize( + "arch", [None, "", "X86", "x86_64", "AMD64", gpuhunt.CPUArchitecture.X86] + ) def test_amd64(self, arch: Optional[str]): assert normalize_arch(arch) is GoArchType.AMD64 - @pytest.mark.parametrize("arch", ["arm", "ARM64", "AArch64"]) + @pytest.mark.parametrize("arch", ["arm", "ARM64", "AArch64", gpuhunt.CPUArchitecture.ARM]) def test_arm64(self, arch: str): assert normalize_arch(arch) is GoArchType.ARM64 diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py b/src/tests/_internal/server/background/tasks/test_process_instances.py index e7c44ab434..cb5028c42b 100644 --- a/src/tests/_internal/server/background/tasks/test_process_instances.py +++ b/src/tests/_internal/server/background/tasks/test_process_instances.py @@ -8,6 +8,7 @@ import gpuhunt import pytest +import pytest_asyncio from freezegun import freeze_time from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -41,7 +42,11 @@ delete_instance_health_checks, process_instances, ) -from dstack._internal.server.models import InstanceHealthCheckModel, PlacementGroupModel +from dstack._internal.server.models import ( + InstanceHealthCheckModel, + InstanceModel, + PlacementGroupModel, +) from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse, DCGMHealthResult from dstack._internal.server.schemas.instances import InstanceCheck from dstack._internal.server.schemas.runner import ( @@ -54,7 +59,7 @@ TaskListResponse, TaskStatus, ) -from dstack._internal.server.services.runner.client import ShimClient +from dstack._internal.server.services.runner.client import ComponentList, ShimClient from dstack._internal.server.testing.common import ( ComputeMockSpec, create_fleet, @@ -390,14 +395,14 @@ async def test_check_shim_check_instance_health(self, test_db, session: AsyncSes assert health_check.response == health_response.json() +@pytest.mark.usefixtures("disable_maybe_install_components") class TestRemoveDanglingTasks: - @pytest.fixture(autouse=True) - def disable_runner_update_check(self) -> Generator[None, None, None]: - with patch( - "dstack._internal.server.background.tasks.process_instances.get_dstack_runner_version" - ) as get_dstack_runner_version_mock: - get_dstack_runner_version_mock.return_value = "latest" - yield + @pytest.fixture + def disable_maybe_install_components(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "dstack._internal.server.background.tasks.process_instances._maybe_install_components", + Mock(return_value=None), + ) @pytest.fixture def ssh_tunnel_mock(self) -> Generator[Mock, None, None]: @@ -1163,33 +1168,71 @@ async def test_deletes_instance_health_checks( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) -@pytest.mark.usefixtures( - "test_db", "ssh_tunnel_mock", "shim_client_mock", "get_dstack_runner_version_mock" -) -class TestMaybeUpdateRunner: +@pytest.mark.usefixtures("test_db", "instance", "ssh_tunnel_mock", "shim_client_mock") +class BaseTestMaybeInstallComponents: + EXPECTED_VERSION = "0.20.1" + + @pytest_asyncio.fixture + async def instance(self, session: AsyncSession) -> InstanceModel: + project = await create_project(session=session) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + return instance + + @pytest.fixture + def component_list(self) -> ComponentList: + return ComponentList() + + @pytest.fixture + def debug_task_log(self, caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture: + caplog.set_level( + level=logging.DEBUG, + logger="dstack._internal.server.background.tasks.process_instances", + ) + return caplog + @pytest.fixture def ssh_tunnel_mock(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("dstack._internal.server.services.runner.ssh.SSHTunnel", MagicMock()) @pytest.fixture - def shim_client_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + def shim_client_mock( + self, + monkeypatch: pytest.MonkeyPatch, + component_list: ComponentList, + ) -> Mock: mock = Mock(spec_set=ShimClient) mock.healthcheck.return_value = HealthcheckResponse( - service="dstack-shim", version="0.19.40" + service="dstack-shim", version=self.EXPECTED_VERSION ) mock.get_instance_health.return_value = InstanceHealthResponse() - mock.get_runner_info.return_value = ComponentInfo( - name=ComponentName.RUNNER, version="0.19.40", status=ComponentStatus.INSTALLED - ) + mock.get_components.return_value = component_list mock.list_tasks.return_value = TaskListResponse(tasks=[]) + mock.is_safe_to_restart.return_value = False monkeypatch.setattr( "dstack._internal.server.services.runner.client.ShimClient", Mock(return_value=mock) ) return mock + +@pytest.mark.usefixtures("get_dstack_runner_version_mock") +class TestMaybeInstallRunner(BaseTestMaybeInstallComponents): + @pytest.fixture + def component_list(self) -> ComponentList: + components = ComponentList() + components.add( + ComponentInfo( + name=ComponentName.RUNNER, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + return components + @pytest.fixture def get_dstack_runner_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: - mock = Mock(return_value="0.19.41") + mock = Mock(return_value=self.EXPECTED_VERSION) monkeypatch.setattr( "dstack._internal.server.background.tasks.process_instances.get_dstack_runner_version", mock, @@ -1207,112 +1250,328 @@ def get_dstack_runner_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) - async def test_cannot_determine_expected_version( self, - caplog: pytest.LogCaptureFixture, - session: AsyncSession, + debug_task_log: pytest.LogCaptureFixture, shim_client_mock: Mock, get_dstack_runner_version_mock: Mock, ): - caplog.set_level(logging.DEBUG) - project = await create_project(session=session) - await create_instance(session=session, project=project, status=InstanceStatus.IDLE) - get_dstack_runner_version_mock.return_value = "latest" + get_dstack_runner_version_mock.return_value = None await process_instances() - assert "Cannot determine the expected runner version" in caplog.text - shim_client_mock.get_runner_info.assert_not_called() + assert "Cannot determine the expected runner version" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() shim_client_mock.install_runner.assert_not_called() - async def test_failed_to_parse_current_version( - self, - caplog: pytest.LogCaptureFixture, - session: AsyncSession, - shim_client_mock: Mock, + async def test_expected_version_already_installed( + self, debug_task_log: pytest.LogCaptureFixture, shim_client_mock: Mock ): - caplog.set_level(logging.WARNING) - project = await create_project(session=session) - await create_instance(session=session, project=project, status=InstanceStatus.IDLE) - shim_client_mock.get_runner_info.return_value.version = "invalid" + shim_client_mock.get_components.return_value.runner.version = self.EXPECTED_VERSION await process_instances() - assert "failed to parse runner version" in caplog.text - shim_client_mock.get_runner_info.assert_called_once() + assert "expected runner version already installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() shim_client_mock.install_runner.assert_not_called() - @pytest.mark.parametrize("current_version", ["latest", "0.0.0", "0.19.41", "0.19.42"]) - async def test_latest_version_already_installed( + @pytest.mark.parametrize("status", [ComponentStatus.NOT_INSTALLED, ComponentStatus.ERROR]) + async def test_install_not_installed_or_error( self, - caplog: pytest.LogCaptureFixture, - session: AsyncSession, + debug_task_log: pytest.LogCaptureFixture, shim_client_mock: Mock, - current_version: str, + get_dstack_runner_download_url_mock: Mock, + status: ComponentStatus, ): - caplog.set_level(logging.DEBUG) - project = await create_project(session=session) - await create_instance(session=session, project=project, status=InstanceStatus.IDLE) - shim_client_mock.get_runner_info.return_value.version = current_version + shim_client_mock.get_components.return_value.runner.version = "" + shim_client_mock.get_components.return_value.runner.status = status await process_instances() - assert "the latest runner version already installed" in caplog.text - shim_client_mock.get_runner_info.assert_called_once() - shim_client_mock.install_runner.assert_not_called() + assert f"installing runner (no version) -> {self.EXPECTED_VERSION}" in debug_task_log.text + get_dstack_runner_download_url_mock.assert_called_once_with( + arch=None, version=self.EXPECTED_VERSION + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_called_once_with( + get_dstack_runner_download_url_mock.return_value + ) - async def test_install_not_installed( + @pytest.mark.parametrize("installed_version", ["0.19.40", "0.21.0", "dev"]) + async def test_install_installed( self, - caplog: pytest.LogCaptureFixture, - session: AsyncSession, + debug_task_log: pytest.LogCaptureFixture, shim_client_mock: Mock, get_dstack_runner_download_url_mock: Mock, + installed_version: str, ): - caplog.set_level(logging.DEBUG) - project = await create_project(session=session) - await create_instance(session=session, project=project, status=InstanceStatus.IDLE) - shim_client_mock.get_runner_info.return_value.version = "" - shim_client_mock.get_runner_info.return_value.status = ComponentStatus.NOT_INSTALLED + shim_client_mock.get_components.return_value.runner.version = installed_version await process_instances() - assert "installing runner 0.19.41" in caplog.text - get_dstack_runner_download_url_mock.assert_called_once_with(arch=None, version="0.19.41") - shim_client_mock.get_runner_info.assert_called_once() + assert ( + f"installing runner {installed_version} -> {self.EXPECTED_VERSION}" + in debug_task_log.text + ) + get_dstack_runner_download_url_mock.assert_called_once_with( + arch=None, version=self.EXPECTED_VERSION + ) + shim_client_mock.get_components.assert_called_once() shim_client_mock.install_runner.assert_called_once_with( get_dstack_runner_download_url_mock.return_value ) - async def test_update_outdated( + async def test_already_installing( + self, debug_task_log: pytest.LogCaptureFixture, shim_client_mock: Mock + ): + shim_client_mock.get_components.return_value.runner.version = "dev" + shim_client_mock.get_components.return_value.runner.status = ComponentStatus.INSTALLING + + await process_instances() + + assert "runner is already being installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_not_called() + + +@pytest.mark.usefixtures("get_dstack_shim_version_mock") +class TestMaybeInstallShim(BaseTestMaybeInstallComponents): + @pytest.fixture + def component_list(self) -> ComponentList: + components = ComponentList() + components.add( + ComponentInfo( + name=ComponentName.SHIM, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + return components + + @pytest.fixture + def get_dstack_shim_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=self.EXPECTED_VERSION) + monkeypatch.setattr( + "dstack._internal.server.background.tasks.process_instances.get_dstack_shim_version", + mock, + ) + return mock + + @pytest.fixture + def get_dstack_shim_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value="https://example.com/shim") + monkeypatch.setattr( + "dstack._internal.server.background.tasks.process_instances.get_dstack_shim_download_url", + mock, + ) + return mock + + async def test_cannot_determine_expected_version( self, - caplog: pytest.LogCaptureFixture, - session: AsyncSession, + debug_task_log: pytest.LogCaptureFixture, shim_client_mock: Mock, - get_dstack_runner_download_url_mock: Mock, + get_dstack_shim_version_mock: Mock, ): - caplog.set_level(logging.DEBUG) - project = await create_project(session=session) - await create_instance(session=session, project=project, status=InstanceStatus.IDLE) - shim_client_mock.get_runner_info.return_value.version = "0.19.38" + get_dstack_shim_version_mock.return_value = None await process_instances() - assert "updating runner 0.19.38 -> 0.19.41" in caplog.text - get_dstack_runner_download_url_mock.assert_called_once_with(arch=None, version="0.19.41") - shim_client_mock.get_runner_info.assert_called_once() - shim_client_mock.install_runner.assert_called_once_with( - get_dstack_runner_download_url_mock.return_value + assert "Cannot determine the expected shim version" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_not_called() + + async def test_expected_version_already_installed( + self, debug_task_log: pytest.LogCaptureFixture, shim_client_mock: Mock + ): + shim_client_mock.get_components.return_value.shim.version = self.EXPECTED_VERSION + + await process_instances() + + assert "expected shim version already installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_not_called() + + @pytest.mark.parametrize("status", [ComponentStatus.NOT_INSTALLED, ComponentStatus.ERROR]) + async def test_install_not_installed_or_error( + self, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_shim_download_url_mock: Mock, + status: ComponentStatus, + ): + shim_client_mock.get_components.return_value.shim.version = "" + shim_client_mock.get_components.return_value.shim.status = status + + await process_instances() + + assert f"installing shim (no version) -> {self.EXPECTED_VERSION}" in debug_task_log.text + get_dstack_shim_download_url_mock.assert_called_once_with( + arch=None, version=self.EXPECTED_VERSION + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_called_once_with( + get_dstack_shim_download_url_mock.return_value ) - async def test_already_updating( + @pytest.mark.parametrize("installed_version", ["0.19.40", "0.21.0", "dev"]) + async def test_install_installed( self, - session: AsyncSession, + debug_task_log: pytest.LogCaptureFixture, shim_client_mock: Mock, + get_dstack_shim_download_url_mock: Mock, + installed_version: str, ): - project = await create_project(session=session) - await create_instance(session=session, project=project, status=InstanceStatus.IDLE) - shim_client_mock.get_runner_info.return_value.version = "0.19.38" - shim_client_mock.get_runner_info.return_value.status = ComponentStatus.INSTALLING + shim_client_mock.get_components.return_value.shim.version = installed_version await process_instances() - shim_client_mock.get_runner_info.assert_called_once() - shim_client_mock.install_runner.assert_not_called() + assert ( + f"installing shim {installed_version} -> {self.EXPECTED_VERSION}" + in debug_task_log.text + ) + get_dstack_shim_download_url_mock.assert_called_once_with( + arch=None, version=self.EXPECTED_VERSION + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_called_once_with( + get_dstack_shim_download_url_mock.return_value + ) + + async def test_already_installing( + self, debug_task_log: pytest.LogCaptureFixture, shim_client_mock: Mock + ): + shim_client_mock.get_components.return_value.shim.version = "dev" + shim_client_mock.get_components.return_value.shim.status = ComponentStatus.INSTALLING + + await process_instances() + + assert "shim is already being installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_not_called() + + +@pytest.mark.usefixtures("maybe_install_runner_mock", "maybe_install_shim_mock") +class TestMaybeRestartShim(BaseTestMaybeInstallComponents): + @pytest.fixture + def component_list(self) -> ComponentList: + components = ComponentList() + components.add( + ComponentInfo( + name=ComponentName.RUNNER, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + components.add( + ComponentInfo( + name=ComponentName.SHIM, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + return components + + @pytest.fixture + def maybe_install_runner_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=False) + monkeypatch.setattr( + "dstack._internal.server.background.tasks.process_instances._maybe_install_runner", + mock, + ) + return mock + + @pytest.fixture + def maybe_install_shim_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=False) + monkeypatch.setattr( + "dstack._internal.server.background.tasks.process_instances._maybe_install_shim", + mock, + ) + return mock + + async def test_up_to_date(self, shim_client_mock: Mock): + shim_client_mock.get_version_string.return_value = self.EXPECTED_VERSION + shim_client_mock.is_safe_to_restart.return_value = True + + await process_instances() + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_no_shim_component_info(self, shim_client_mock: Mock): + shim_client_mock.get_components.return_value = ComponentList() + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + + await process_instances() + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_shutdown_requested(self, shim_client_mock: Mock): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + + await process_instances() + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_called_once_with(force=False) + + async def test_outdated_but_task_wont_survive_restart(self, shim_client_mock: Mock): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = False + + await process_instances() + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_runner_installation_in_progress( + self, shim_client_mock: Mock, component_list: ComponentList + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + runner_info = component_list.runner + assert runner_info is not None + runner_info.status = ComponentStatus.INSTALLING + + await process_instances() + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_shim_installation_in_progress( + self, shim_client_mock: Mock, component_list: ComponentList + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + shim_info = component_list.shim + assert shim_info is not None + shim_info.status = ComponentStatus.INSTALLING + + await process_instances() + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_runner_installation_requested( + self, shim_client_mock: Mock, maybe_install_runner_mock: Mock + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + maybe_install_runner_mock.return_value = True + + await process_instances() + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_shim_installation_requested( + self, shim_client_mock: Mock, maybe_install_shim_mock: Mock + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + maybe_install_shim_mock.return_value = True + + await process_instances() + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() diff --git a/src/tests/_internal/server/services/runner/test_client.py b/src/tests/_internal/server/services/runner/test_client.py index e68a007cff..588c231a19 100644 --- a/src/tests/_internal/server/services/runner/test_client.py +++ b/src/tests/_internal/server/services/runner/test_client.py @@ -99,7 +99,7 @@ def test( client._negotiate() - assert client._shim_version == expected_shim_version + assert client._shim_version_tuple == expected_shim_version assert client._api_version == expected_api_version assert adapter.call_count == 1 self.assert_request(adapter, 0, "GET", "/api/healthcheck") @@ -129,7 +129,7 @@ def test_healthcheck(self, client: ShimClient, adapter: requests_mock.Adapter): assert adapter.call_count == 1 self.assert_request(adapter, 0, "GET", "/api/healthcheck") # healthcheck() method also performs negotiation to save API calls - assert client._shim_version == (0, 18, 30) + assert client._shim_version_tuple == (0, 18, 30) assert client._api_version == 1 def test_submit(self, client: ShimClient, adapter: requests_mock.Adapter): @@ -262,9 +262,94 @@ def test_healthcheck(self, client: ShimClient, adapter: requests_mock.Adapter): assert adapter.call_count == 1 self.assert_request(adapter, 0, "GET", "/api/healthcheck") # healthcheck() method also performs negotiation to save API calls - assert client._shim_version == (0, 18, 40) + assert client._shim_version_tuple == (0, 18, 40) assert client._api_version == 2 + def test_is_safe_to_restart_false_old_shim( + self, client: ShimClient, adapter: requests_mock.Adapter + ): + adapter.register_uri( + "GET", + "/api/tasks", + json={ + # pre-0.19.26 shim returns ids instead of tasks + "tasks": None, + "ids": [], + }, + ) + + res = client.is_safe_to_restart() + + assert res is False + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "GET", "/api/tasks") + + @pytest.mark.parametrize( + "task_status", + [ + TaskStatus.PENDING, + TaskStatus.PREPARING, + TaskStatus.PULLING, + TaskStatus.CREATING, + TaskStatus.RUNNING, + ], + ) + def test_is_safe_to_restart_false_status_not_safe( + self, client: ShimClient, adapter: requests_mock.Adapter, task_status: TaskStatus + ): + adapter.register_uri( + "GET", + "/api/tasks", + json={ + "tasks": [ + { + "id": str(uuid.uuid4()), + "status": "terminated", + }, + { + "id": str(uuid.uuid4()), + "status": task_status.value, + }, + ], + "ids": None, + }, + ) + + res = client.is_safe_to_restart() + + assert res is False + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "GET", "/api/tasks") + + def test_is_safe_to_restart_true(self, client: ShimClient, adapter: requests_mock.Adapter): + adapter.register_uri( + "GET", + "/api/tasks", + json={ + "tasks": [ + { + "id": str(uuid.uuid4()), + "status": "terminated", + }, + { + "id": str(uuid.uuid4()), + # TODO: replace with "running" once it's safe + "status": "terminated", + }, + ], + "ids": None, + }, + ) + + res = client.is_safe_to_restart() + + assert res is True + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "GET", "/api/tasks") + def test_get_task(self, client: ShimClient, adapter: requests_mock.Adapter): task_id = "d35b6e24-b556-4d6e-81e3-5982d2c34449" url = f"/api/tasks/{task_id}"