Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 16 additions & 91 deletions runner/internal/shim/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@ import (
"encoding/hex"
"math/rand"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -20,90 +17,33 @@ import (
)

// TestDocker_SSHServer pulls ubuntu image (without sshd), installs openssh-server and exits
// Basically, it indirectly tests a shell script generated by getSSHShellCommands
func TestDocker_SSHServer(t *testing.T) {
if testing.Short() || (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") {
t.Skip()
}
t.Parallel()

params := &dockerParametersMock{
commands: []string{"echo 1"},
sshPort: nextPort(),
runnerDir: t.TempDir(),
commands: []string{"/usr/sbin/sshd -V 2>&1 | grep OpenSSH"},
sshShellCommands: true,
runnerDir: t.TempDir(),
}

timeout := 180 // seconds
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), time.Duration(timeout)*time.Second)
defer cancel()

dockerRunner, err := NewDockerRunner(ctx, params)
require.NoError(t, err)

taskConfig := createTaskConfig(t)
defer dockerRunner.Remove(context.Background(), taskConfig.ID)
defer dockerRunner.Remove(t.Context(), taskConfig.ID)

assert.NoError(t, dockerRunner.Submit(ctx, taskConfig))
assert.NoError(t, dockerRunner.Run(ctx, taskConfig.ID))
}

// TestDocker_SSHServerConnect pulls ubuntu image (without sshd), installs openssh-server and tries to connect via SSH
func TestDocker_SSHServerConnect(t *testing.T) {
if testing.Short() || (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") {
t.Skip()
}
t.Parallel()

tempDir := t.TempDir()
require.NoError(t, exec.CommandContext(t.Context(), "ssh-keygen", "-t", "rsa", "-b", "2048", "-f", tempDir+"/id_rsa", "-q", "-N", "").Run())
publicBytes, err := os.ReadFile(tempDir + "/id_rsa.pub")
require.NoError(t, err)

params := &dockerParametersMock{
commands: []string{"sleep 5"},
sshPort: nextPort(),
publicSSHKey: string(publicBytes),
runnerDir: t.TempDir(),
}

timeout := 180 // seconds
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
defer cancel()

dockerRunner, err := NewDockerRunner(ctx, params)
require.NoError(t, err)

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
taskConfig := createTaskConfig(t)
defer dockerRunner.Remove(context.Background(), taskConfig.ID)

assert.NoError(t, dockerRunner.Submit(ctx, taskConfig))
assert.NoError(t, dockerRunner.Run(ctx, taskConfig.ID))
}()

for i := 0; i < timeout; i++ {
cmd := exec.CommandContext(
t.Context(),
"ssh",
"-F", "none",
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-i", tempDir+"/id_rsa",
"-p", strconv.Itoa(params.sshPort),
"root@localhost", "whoami",
)
output, err := cmd.Output()
if err == nil {
assert.Equal(t, "root\n", string(output))
break
}
time.Sleep(time.Second) // 1 attempt per second
}
wg.Wait()
}

func TestDocker_ShmNoexecByDefault(t *testing.T) {
if testing.Short() || (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") {
t.Skip()
Expand All @@ -116,14 +56,14 @@ func TestDocker_ShmNoexecByDefault(t *testing.T) {
}

timeout := 180 // seconds
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), time.Duration(timeout)*time.Second)
defer cancel()

dockerRunner, err := NewDockerRunner(ctx, params)
require.NoError(t, err)

taskConfig := createTaskConfig(t)
defer dockerRunner.Remove(context.Background(), taskConfig.ID)
defer dockerRunner.Remove(t.Context(), taskConfig.ID)

assert.NoError(t, dockerRunner.Submit(ctx, taskConfig))
assert.NoError(t, dockerRunner.Run(ctx, taskConfig.ID))
Expand All @@ -141,15 +81,15 @@ func TestDocker_ShmExecIfSizeSpecified(t *testing.T) {
}

timeout := 180 // seconds
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
ctx, cancel := context.WithTimeout(t.Context(), time.Duration(timeout)*time.Second)
defer cancel()

dockerRunner, err := NewDockerRunner(ctx, params)
require.NoError(t, err)

taskConfig := createTaskConfig(t)
taskConfig.ShmSize = 1024 * 1024
defer dockerRunner.Remove(context.Background(), taskConfig.ID)
defer dockerRunner.Remove(t.Context(), taskConfig.ID)

assert.NoError(t, dockerRunner.Submit(ctx, taskConfig))
assert.NoError(t, dockerRunner.Run(ctx, taskConfig.ID))
Expand All @@ -158,11 +98,9 @@ func TestDocker_ShmExecIfSizeSpecified(t *testing.T) {
/* Mocks */

type dockerParametersMock struct {
// If sshPort is not set (equals zero), sshd won't be started.
commands []string
sshPort int
publicSSHKey string
runnerDir string
commands []string
sshShellCommands bool
runnerDir string
}

func (c *dockerParametersMock) DockerPrivileged() bool {
Expand All @@ -174,24 +112,17 @@ func (c *dockerParametersMock) DockerPJRTDevice() string {
}

func (c *dockerParametersMock) DockerShellCommands(publicKeys []string) []string {
userPublicKey := c.publicSSHKey
if len(publicKeys) > 0 {
userPublicKey = strings.Join(publicKeys, "\n")
}
userPublicKey := strings.Join(publicKeys, "\n")
commands := make([]string, 0)
if c.sshPort != 0 {
if c.sshShellCommands {
commands = append(commands, getSSHShellCommands(userPublicKey)...)
}
commands = append(commands, c.commands...)
return commands
}

func (c *dockerParametersMock) DockerPorts() []int {
ports := make([]int, 0)
if c.sshPort != 0 {
ports = append(ports, c.sshPort)
}
return ports
return []int{}
}

func (c *dockerParametersMock) DockerMounts(string) ([]mount.Mount, error) {
Expand All @@ -204,12 +135,6 @@ func (c *dockerParametersMock) MakeRunnerDir(string) (string, error) {

/* Utilities */

var portNumber int32 = 10000

func nextPort() int {
return int(atomic.AddInt32(&portNumber, 1))
}

var (
randSrc = rand.New(rand.NewSource(time.Now().UnixNano()))
randMu = sync.Mutex{}
Expand Down