diff --git a/runner/cmd/runner/cmd.go b/runner/cmd/runner/cmd.go deleted file mode 100644 index 08f3d5b018..0000000000 --- a/runner/cmd/runner/cmd.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "log" - "os" - - "github.com/urfave/cli/v2" - - "github.com/dstackai/dstack/runner/consts" -) - -// Version is a build-time variable. The value is overridden by ldflags. -var Version string - -func App() { - var tempDir string - var homeDir string - var httpPort int - var sshPort int - var logLevel int - - app := &cli.App{ - Name: "dstack-runner", - Usage: "configure and start dstack-runner", - Version: Version, - Flags: []cli.Flag{ - &cli.IntFlag{ - Name: "log-level", - Value: 2, - DefaultText: "4 (Info)", - Usage: "log verbosity level: 2 (Error), 3 (Warning), 4 (Info), 5 (Debug), 6 (Trace)", - Destination: &logLevel, - }, - }, - Commands: []*cli.Command{ - { - Name: "start", - Usage: "Start dstack-runner", - Flags: []cli.Flag{ - &cli.PathFlag{ - Name: "temp-dir", - Usage: "Temporary directory for logs and other files", - Value: consts.RunnerTempDir, - Destination: &tempDir, - }, - &cli.PathFlag{ - Name: "home-dir", - Usage: "HomeDir directory for credentials and $HOME", - Value: consts.RunnerHomeDir, - Destination: &homeDir, - }, - &cli.IntFlag{ - Name: "http-port", - Usage: "Set a http port", - Value: consts.RunnerHTTPPort, - Destination: &httpPort, - }, - &cli.IntFlag{ - Name: "ssh-port", - Usage: "Set the ssh port", - Value: consts.RunnerSSHPort, - Destination: &sshPort, - }, - }, - Action: func(c *cli.Context) error { - err := start(tempDir, homeDir, httpPort, sshPort, logLevel, Version) - if err != nil { - return cli.Exit(err, 1) - } - return nil - }, - }, - }, - } - err := app.Run(os.Args) - if err != nil { - log.Fatal(err) - } -} diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index 27c07292b9..b34ee7b05a 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -4,22 +4,94 @@ import ( "context" "fmt" "io" - _ "net/http/pprof" "os" "path/filepath" "github.com/sirupsen/logrus" + "github.com/urfave/cli/v3" "github.com/dstackai/dstack/runner/consts" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/runner/api" ) +// Version is a build-time variable. The value is overridden by ldflags. +var Version string + func main() { - App() + os.Exit(mainInner()) +} + +func mainInner() int { + var tempDir string + var homeDir string + var httpPort int + var sshPort int + var logLevel int + + cmd := &cli.Command{ + Name: "dstack-runner", + Usage: "configure and start dstack-runner", + Version: Version, + Flags: []cli.Flag{ + &cli.IntFlag{ + Name: "log-level", + Value: 2, + DefaultText: "4 (Info)", + Usage: "log verbosity level: 2 (Error), 3 (Warning), 4 (Info), 5 (Debug), 6 (Trace)", + Destination: &logLevel, + }, + }, + Commands: []*cli.Command{ + { + Name: "start", + Usage: "Start dstack-runner", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "temp-dir", + Usage: "Temporary directory for logs and other files", + Value: consts.RunnerTempDir, + Destination: &tempDir, + TakesFile: true, + }, + &cli.StringFlag{ + Name: "home-dir", + Usage: "HomeDir directory for credentials and $HOME", + Value: consts.RunnerHomeDir, + Destination: &homeDir, + TakesFile: true, + }, + &cli.IntFlag{ + Name: "http-port", + Usage: "Set a http port", + Value: consts.RunnerHTTPPort, + Destination: &httpPort, + }, + &cli.IntFlag{ + Name: "ssh-port", + Usage: "Set the ssh port", + Value: consts.RunnerSSHPort, + Destination: &sshPort, + }, + }, + Action: func(cxt context.Context, cmd *cli.Command) error { + return start(cxt, tempDir, homeDir, httpPort, sshPort, logLevel, Version) + }, + }, + }, + } + + ctx := context.Background() + + if err := cmd.Run(ctx, os.Args); err != nil { + log.Error(ctx, err.Error()) + return 1 + } + + return 0 } -func start(tempDir string, homeDir string, httpPort int, sshPort int, logLevel int, version string) error { +func start(ctx context.Context, tempDir string, homeDir string, httpPort int, sshPort int, logLevel int, version string) error { if err := os.MkdirAll(tempDir, 0o755); err != nil { return fmt.Errorf("create temp directory: %w", err) } @@ -31,20 +103,20 @@ func start(tempDir string, homeDir string, httpPort int, sshPort int, logLevel i defer func() { closeErr := defaultLogFile.Close() if closeErr != nil { - log.Error(context.TODO(), "Failed to close default log file", "err", closeErr) + log.Error(ctx, "Failed to close default log file", "err", closeErr) } }() log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile)) log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel)) - server, err := api.NewServer(context.TODO(), tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshPort, version) + server, err := api.NewServer(ctx, tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshPort, version) if err != nil { return fmt.Errorf("create server: %w", err) } - log.Trace(context.TODO(), "Starting API server", "port", httpPort) - if err := server.Run(); err != nil { + log.Trace(ctx, "Starting API server", "port", httpPort) + if err := server.Run(ctx); err != nil { return fmt.Errorf("server failed: %w", err) } diff --git a/runner/go.mod b/runner/go.mod index b317f6c7b0..260fb880ae 100644 --- a/runner/go.mod +++ b/runner/go.mod @@ -20,7 +20,6 @@ require ( github.com/shirou/gopsutil/v4 v4.24.11 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.11.1 - github.com/urfave/cli/v2 v2.27.7 github.com/urfave/cli/v3 v3.6.1 golang.org/x/crypto v0.22.0 golang.org/x/sys v0.26.0 @@ -33,7 +32,6 @@ require ( github.com/bits-and-blooms/bitset v1.22.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/containerd/log v0.1.0 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/distribution/reference v0.6.0 // indirect @@ -62,7 +60,6 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect - github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect github.com/tidwall/btree v1.7.0 // indirect @@ -70,7 +67,6 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/ulikunitz/xz v0.5.12 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect - github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.50.0 // indirect go.opentelemetry.io/otel v1.25.0 // indirect diff --git a/runner/go.sum b/runner/go.sum index de734fa39a..20c4568f9f 100644 --- a/runner/go.sum +++ b/runner/go.sum @@ -34,8 +34,6 @@ github.com/codeclysm/extract/v4 v4.0.0 h1:H87LFsUNaJTu2e/8p/oiuiUsOK/TaPQ5wxsjPn github.com/codeclysm/extract/v4 v4.0.0/go.mod h1:SFju1lj6as7FvUgalpSct7torJE0zttbJUWtryPRG6s= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= -github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= -github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg= @@ -155,8 +153,6 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= -github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/shirou/gopsutil/v4 v4.24.11 h1:WaU9xqGFKvFfsUv94SXcUPD7rCkU0vr/asVdQOBZNj8= @@ -185,14 +181,10 @@ github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/ulikunitz/xz v0.5.12 h1:37Nm15o69RwBkXM0J6A5OlE67RZTfzUxTj8fB3dfcsc= github.com/ulikunitz/xz v0.5.12/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= -github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= -github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/urfave/cli/v3 v3.6.1 h1:j8Qq8NyUawj/7rTYdBGrxcH7A/j7/G8Q5LhWEW4G3Mo= github.com/urfave/cli/v3 v3.6.1/go.mod h1:ysVLtOEmg2tOy6PknnYVhDoouyC/6N42TMeoMzskhso= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= -github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= -github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 9d98315b1b..c973f45e1a 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + _ "net/http/pprof" "os" "os/signal" "syscall" @@ -80,21 +81,21 @@ func NewServer(ctx context.Context, tempDir string, homeDir string, address stri return s, nil } -func (s *Server) Run() error { - signals := []os.Signal{os.Interrupt, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGQUIT} +func (s *Server) Run(ctx context.Context) error { + signals := []os.Signal{os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT} signalCh := make(chan os.Signal, 1) go func() { if err := s.srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Error(context.TODO(), "Server failed", "err", err) + log.Error(ctx, "Server failed", "err", err) } }() - defer func() { _ = s.srv.Shutdown(context.TODO()) }() + defer func() { _ = s.srv.Shutdown(ctx) }() select { case <-s.jobBarrierCh: // job started case <-time.After(s.submitWaitDuration): - log.Error(context.TODO(), "Job didn't start in time, shutting down") + log.Error(ctx, "Job didn't start in time, shutting down") return errors.New("no job submitted") } @@ -103,10 +104,10 @@ func (s *Server) Run() error { signal.Notify(signalCh, signals...) select { case <-signalCh: - log.Error(context.TODO(), "Received interrupt signal, shutting down") + log.Error(ctx, "Received interrupt signal, shutting down") s.stop() case <-s.jobBarrierCh: - log.Info(context.TODO(), "Job finished, shutting down") + log.Info(ctx, "Job finished, shutting down") } close(s.shutdownCh) signal.Reset(signals...) @@ -123,9 +124,9 @@ loop: for _, ch := range logsToWait { select { case <-ch.ch: - log.Info(context.TODO(), "Logs streaming finished", "endpoint", ch.name) + log.Info(ctx, "Logs streaming finished", "endpoint", ch.name) case <-waitLogsDone: - log.Error(context.TODO(), "Logs streaming didn't finish in time") + log.Error(ctx, "Logs streaming didn't finish in time") break loop // break the loop, not the select } }