Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions runner/cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ func mainInner() int {
log.DefaultEntry.Logger.SetLevel(logrus.Level(defaultLogLevel))
log.DefaultEntry.Logger.SetOutput(os.Stderr)

shimBinaryPath, err := os.Executable()
if err != nil {
shimBinaryPath = consts.ShimBinaryPath
}

cmd := &cli.Command{
Name: "dstack-shim",
Usage: "Starts dstack-runner or docker container.",
Expand All @@ -54,6 +59,14 @@ func mainInner() int {
DefaultText: path.Join("~", consts.DstackDirPath),
Sources: cli.EnvVars("DSTACK_SHIM_HOME"),
},
&cli.StringFlag{
Name: "shim-binary-path",
Usage: "Path to shim's binary",
Value: shimBinaryPath,
Destination: &args.Shim.BinaryPath,
TakesFile: true,
Sources: cli.EnvVars("DSTACK_SHIM_BINARY_PATH"),
},
&cli.IntFlag{
Name: "shim-http-port",
Usage: "Set shim's http port",
Expand Down Expand Up @@ -172,6 +185,7 @@ func mainInner() int {

func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) {
log.DefaultEntry.Logger.SetLevel(logrus.Level(args.Shim.LogLevel))
log.Info(ctx, "Starting dstack-shim", "version", Version)

shimHomeDir := args.Shim.HomeDir
if shimHomeDir == "" {
Expand Down Expand Up @@ -211,6 +225,10 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
} else if runnerErr != nil {
return runnerErr
}
shimManager, shimErr := components.NewShimManager(ctx, args.Shim.BinaryPath)
if shimErr != nil {
return shimErr
}

log.Debug(ctx, "Shim", "args", args.Shim)
log.Debug(ctx, "Runner", "args", args.Runner)
Expand Down Expand Up @@ -259,7 +277,11 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
}

address := fmt.Sprintf("localhost:%d", args.Shim.HTTPPort)
shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper, runnerManager)
shimServer := api.NewShimServer(
ctx, address, Version,
dockerRunner, dcgmExporter, dcgmWrapper,
runnerManager, shimManager,
)

if serviceMode {
if err := shim.WriteHostInfo(shimHomeDir, dockerRunner.Resources(ctx)); err != nil {
Expand All @@ -278,6 +300,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
if err := shimServer.Serve(); err != nil {
serveErrCh <- err
}
close(serveErrCh)
}()

select {
Expand All @@ -287,7 +310,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)

shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
defer cancelShutdown()
shutdownErr := shimServer.Shutdown(shutdownCtx)
shutdownErr := shimServer.Shutdown(shutdownCtx, false)
if serveErr != nil {
return serveErr
}
Expand Down
3 changes: 3 additions & 0 deletions runner/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ const (
// 2. A default path on the host unless overridden via shim CLI
const RunnerBinaryPath = "/usr/local/bin/dstack-runner"

// A fallback path on the host used if os.Executable() has failed
const ShimBinaryPath = "/usr/local/bin/dstack-shim"

// Error-containing messages will be identified by this signature
const ExecutorFailedSignature = "Executor failed"

Expand Down
51 changes: 45 additions & 6 deletions runner/docs/shim.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ openapi: 3.1.2

info:
title: dstack-shim API
version: v2/0.19.41
version: v2/0.20.1
x-logo:
url: https://avatars.githubusercontent.com/u/54146142?s=260
description: >
Expand Down Expand Up @@ -41,7 +41,7 @@ paths:

**Important**: Since this endpoint is used for negotiation, it should always stay
backward/future compatible, specifically the `version` field

tags: [shim]
responses:
"200":
description: ""
Expand All @@ -50,6 +50,29 @@ paths:
schema:
$ref: "#/components/schemas/HealthcheckResponse"

/shutdown:
post:
summary: Request shim shutdown
description: |
(since [0.20.1](https://github.com/dstackai/dstack/releases/tag/0.20.1)) Request shim to shut down itself.
Restart must be handled by an external process supervisor, e.g., `systemd`.

**Note**: background jobs (e.g., component installation) are canceled regardless of the `force` option.
tags: [shim]
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/ShutdownRequest"
responses:
"200":
description: Request accepted
$ref: "#/components/responses/PlainTextOk"
"400":
description: Malformed JSON body or validation error
$ref: "#/components/responses/PlainTextBadRequest"

/instance/health:
get:
summary: Get instance health
Expand All @@ -66,7 +89,7 @@ paths:
/components:
get:
summary: Get components
description: (since [0.19.41](https://github.com/dstackai/dstack/releases/tag/0.19.41)) Returns a list of software components (e.g., `dstack-runner`)
description: (since [0.20.0](https://github.com/dstackai/dstack/releases/tag/0.20.0)) Returns a list of software components (e.g., `dstack-runner`)
tags: [Components]
responses:
"200":
Expand All @@ -80,7 +103,7 @@ paths:
post:
summary: Install component
description: >
(since [0.19.41](https://github.com/dstackai/dstack/releases/tag/0.19.41)) Request installing/updating the software component.
(since [0.20.0](https://github.com/dstackai/dstack/releases/tag/0.20.0)) Request installing/updating the software component.
Components are installed asynchronously
tags: [Components]
requestBody:
Expand Down Expand Up @@ -410,6 +433,10 @@ components:
type: string
enum:
- dstack-runner
- dstack-shim
description: |
* (since [0.20.0](https://github.com/dstackai/dstack/releases/tag/0.20.0)) `dstack-runner`
* (since [0.20.1](https://github.com/dstackai/dstack/releases/tag/0.20.1)) `dstack-shim`

ComponentStatus:
title: shim.components.ComponentStatus
Expand All @@ -430,7 +457,7 @@ components:
type: string
description: An empty string if status != installed
examples:
- 0.19.41
- 0.20.1
status:
allOf:
- $ref: "#/components/schemas/ComponentStatus"
Expand All @@ -457,6 +484,18 @@ components:
- version
additionalProperties: false

ShutdownRequest:
title: shim.api.ShutdownRequest
type: object
properties:
force:
type: boolean
examples:
- false
description: If `true`, don't wait for background job coroutines to complete after canceling them and close HTTP server forcefully.
required:
- force

InstanceHealthResponse:
title: shim.api.InstanceHealthResponse
type: object
Expand Down Expand Up @@ -486,7 +525,7 @@ components:
url:
type: string
examples:
- https://dstack-runner-downloads.s3.eu-west-1.amazonaws.com/0.19.41/binaries/dstack-runner-linux-amd64
- https://dstack-runner-downloads.s3.eu-west-1.amazonaws.com/0.20.1/binaries/dstack-runner-linux-amd64
required:
- name
- url
Expand Down
57 changes: 39 additions & 18 deletions runner/internal/shim/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ func (s *ShimServer) HealthcheckHandler(w http.ResponseWriter, r *http.Request)
}, nil
}

func (s *ShimServer) ShutdownHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
var req ShutdownRequest
if err := api.DecodeJSONBody(w, r, &req, true); err != nil {
return nil, err
}

go func() {
if err := s.Shutdown(s.ctx, req.Force); err != nil {
log.Error(s.ctx, "Shutdown", "err", err)
}
}()

return nil, nil
}

func (s *ShimServer) InstanceHealthHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
ctx := r.Context()
response := InstanceHealthResponse{}
Expand Down Expand Up @@ -159,9 +174,11 @@ func (s *ShimServer) TaskMetricsHandler(w http.ResponseWriter, r *http.Request)
}

func (s *ShimServer) ComponentListHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
runnerStatus := s.runnerManager.GetInfo(r.Context())
response := &ComponentListResponse{
Components: []components.ComponentInfo{runnerStatus},
Components: []components.ComponentInfo{
s.runnerManager.GetInfo(r.Context()),
s.shimManager.GetInfo(r.Context()),
},
}
return response, nil
}
Expand All @@ -176,27 +193,31 @@ func (s *ShimServer) ComponentInstallHandler(w http.ResponseWriter, r *http.Requ
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty name"}
}

var componentManager components.ComponentManager
switch components.ComponentName(req.Name) {
case components.ComponentNameRunner:
if req.URL == "" {
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty url"}
}

// There is still a small chance of time-of-check race condition, but we ignore it.
runnerInfo := s.runnerManager.GetInfo(r.Context())
if runnerInfo.Status == components.ComponentStatusInstalling {
return nil, &api.Error{Status: http.StatusConflict, Msg: "already installing"}
}

s.bgJobsGroup.Go(func() {
if err := s.runnerManager.Install(s.bgJobsCtx, req.URL, true); err != nil {
log.Error(s.bgJobsCtx, "runner background install", "err", err)
}
})

componentManager = s.runnerManager
case components.ComponentNameShim:
componentManager = s.shimManager
default:
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "unknown component"}
}

if req.URL == "" {
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty url"}
}

// There is still a small chance of time-of-check race condition, but we ignore it.
componentInfo := componentManager.GetInfo(r.Context())
if componentInfo.Status == components.ComponentStatusInstalling {
return nil, &api.Error{Status: http.StatusConflict, Msg: "already installing"}
}

s.bgJobsGroup.Go(func() {
if err := componentManager.Install(s.bgJobsCtx, req.URL, true); err != nil {
log.Error(s.bgJobsCtx, "component background install", "name", componentInfo.Name, "err", err)
}
})

return nil, nil
}
4 changes: 2 additions & 2 deletions runner/internal/shim/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestHealthcheck(t *testing.T) {
request := httptest.NewRequest("GET", "/api/healthcheck", nil)
responseRecorder := httptest.NewRecorder()

server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil)
server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil)

f := common.JSONResponseHandler(server.HealthcheckHandler)
f(responseRecorder, request)
Expand All @@ -30,7 +30,7 @@ func TestHealthcheck(t *testing.T) {
}

func TestTaskSubmit(t *testing.T) {
server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil)
server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil)
requestBody := `{
"id": "dummy-id",
"name": "dummy-name",
Expand Down
4 changes: 4 additions & 0 deletions runner/internal/shim/api/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ type HealthcheckResponse struct {
Version string `json:"version"`
}

type ShutdownRequest struct {
Force bool `json:"force"`
}

type InstanceHealthResponse struct {
DCGM *dcgm.Health `json:"dcgm"`
}
Expand Down
Loading