diff --git a/cmd/ssh-tpm-agent/main_test.go b/cmd/ssh-tpm-agent/main_test.go index 051febf..8335058 100644 --- a/cmd/ssh-tpm-agent/main_test.go +++ b/cmd/ssh-tpm-agent/main_test.go @@ -174,10 +174,10 @@ func runSSHAuth(t *testing.T, keytype tpm2.TPMAlgID, bits int, pin []byte, keyfn } defer session.Close() - session.Shell() - var b bytes.Buffer session.Stdout = &b + session.Shell() + session.Wait() <-msgSent if b.String() != "connected" { diff --git a/internal/keyring/keyring_test.go b/internal/keyring/keyring_test.go index cf1e67d..c2669bd 100644 --- a/internal/keyring/keyring_test.go +++ b/internal/keyring/keyring_test.go @@ -7,7 +7,7 @@ import ( "testing" ) -func TestSaveandGetData(t *testing.T) { +func TestSaveAndGetData(t *testing.T) { keyring, err := SessionKeyring.CreateKeyring() if err != nil { t.Fatalf("failed getting keyring: %v", err) @@ -27,7 +27,7 @@ func TestSaveandGetData(t *testing.T) { } } -func TestNokey(t *testing.T) { +func TestNoKey(t *testing.T) { keyring, err := SessionKeyring.CreateKeyring() if err != nil { t.Fatalf("failed getting keyring: %v", err) diff --git a/internal/keyring/threadkeyring.go b/internal/keyring/threadkeyring.go index 12bf933..a24d7f4 100644 --- a/internal/keyring/threadkeyring.go +++ b/internal/keyring/threadkeyring.go @@ -69,14 +69,19 @@ func NewThreadKeyring(ctx context.Context, keyring *Keyring) (*ThreadKeyring, er tk.removekey = make(chan *removekeyMsg) tk.readkey = make(chan *readkeyMsg) + // Channel for initialization to prevent Data Race + errCh := make(chan error, 1) + tk.wg.Add(1) go func() { var ak *Keyring runtime.LockOSThread() ak, err = keyring.CreateKeyring() if err != nil { + errCh <- err return } + errCh <- nil for { select { case msg := <-tk.addkey: @@ -91,5 +96,11 @@ func NewThreadKeyring(ctx context.Context, keyring *Keyring) (*ThreadKeyring, er } } }() + + // Wait for initialization to complete + if err := <-errCh; err != nil { + return nil, err + } + return &tk, err } diff --git a/internal/keyring/threadkeyring_test.go b/internal/keyring/threadkeyring_test.go index 10a9a2d..6fcda59 100644 --- a/internal/keyring/threadkeyring_test.go +++ b/internal/keyring/threadkeyring_test.go @@ -12,7 +12,7 @@ var ( ctx = context.Background() ) -func TestSaveandGetDataThreaded(t *testing.T) { +func TestSaveAndGetDataThreaded(t *testing.T) { keyring, err := NewThreadKeyring(ctx, SessionKeyring) if err != nil { t.Fatalf("failed getting keyring: %v", err) @@ -32,14 +32,11 @@ func TestSaveandGetDataThreaded(t *testing.T) { } } -func TestNokeyThreaded(t *testing.T) { +func TestNoKeyThreaded(t *testing.T) { keyring, err := NewThreadKeyring(ctx, SessionKeyring) if err != nil { t.Fatalf("failed getting keyring: %v", err) } - if err != nil { - t.Fatalf("failed getting keyring: %v", err) - } _, err = keyring.ReadKey("this.key.does.not.exist") if !errors.Is(err, syscall.ENOKEY) { t.Fatalf("err: %v", err) @@ -51,9 +48,6 @@ func TestRemoveKeyThreaded(t *testing.T) { if err != nil { t.Fatalf("failed getting keyring: %v", err) } - if err != nil { - t.Fatalf("failed getting keyring: %v", err) - } b := []byte("test string") if err := keyring.AddKey("test-2", b); err != nil { t.Fatalf("err: %v", err)