Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions internal/fn/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ import (
"gopkg.in/yaml.v2"
)

var (
readFile = os.ReadFile
writeFile = os.WriteFile
stat = os.Stat
mkdirAll = os.MkdirAll
)

func GetUserInputFromStdin() string {
var lines []string
scanner := bufio.NewScanner(os.Stdin)
Expand Down Expand Up @@ -130,7 +137,7 @@ func GetPathContent(src string) ([]byte, error) {

var content []byte
for _, filePath := range filePaths {
fileContent, err := os.ReadFile(filePath)
fileContent, err := readFile(filePath)
if err != nil {
return nil, fmt.Errorf("no valid SSH config found in %s: %w", src, err)
}
Expand All @@ -145,7 +152,7 @@ func Save(dest string, content []byte) error {
return err
}

info, err := os.Stat(destDir)
info, err := stat(destDir)
if err != nil {
return fmt.Errorf("can not write to destination file: %v", err)
}
Expand All @@ -154,7 +161,7 @@ func Save(dest string, content []byte) error {
return fmt.Errorf("can not write to destination file: directory %s is not writable", destDir)
}

if err := os.WriteFile(dest, content, 0644); err != nil {
if err := writeFile(dest, content, 0644); err != nil {
return fmt.Errorf("can not write to destination file: %v", err)
}
return nil
Expand All @@ -173,7 +180,7 @@ func TidyLastEmptyLines(input []byte) []byte {
}

func ensureDirectory(destDir string) error {
info, err := os.Stat(destDir)
info, err := stat(destDir)
if err == nil {
if !info.IsDir() {
return fmt.Errorf("can not create destination directory: %s is not a directory", destDir)
Expand All @@ -187,7 +194,7 @@ func ensureDirectory(destDir string) error {

parent := filepath.Dir(destDir)
if parent != destDir {
if parentInfo, parentErr := os.Stat(parent); parentErr == nil {
if parentInfo, parentErr := stat(parent); parentErr == nil {
if !parentInfo.IsDir() {
return fmt.Errorf("can not create destination directory: parent %s is not a directory", parent)
}
Expand All @@ -197,7 +204,7 @@ func ensureDirectory(destDir string) error {
}
}

if err := os.MkdirAll(destDir, 0755); err != nil {
if err := mkdirAll(destDir, 0755); err != nil {
return fmt.Errorf("can not create destination directory: %v", err)
}

Expand Down
170 changes: 170 additions & 0 deletions internal/fn/fn_errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package fn

import (
"errors"
"os"
"path/filepath"
"strings"
"testing"
)

func TestGetPathContentReadFileError(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config")
if err := os.WriteFile(configPath, []byte("Host example\n HostName example.com\n"), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}

originalReadFile := readFile
readFile = func(path string) ([]byte, error) {
if path == configPath {
return nil, errors.New("read failure")
}
return originalReadFile(path)
}
t.Cleanup(func() { readFile = originalReadFile })

_, err := GetPathContent(tmpDir)
if err == nil {
t.Fatalf("expected error but got nil")
}
if !strings.Contains(err.Error(), "no valid SSH config found") {
t.Fatalf("unexpected error message: %v", err)
}
}

func TestSaveStatError(t *testing.T) {
tmpDir := t.TempDir()
dest := filepath.Join(tmpDir, "sub", "file.txt")
destDir := filepath.Dir(dest)

callCount := 0
originalStat := stat
stat = func(path string) (os.FileInfo, error) {
if path == destDir {
callCount++
if callCount == 2 {
return nil, errors.New("stat failure")
}
}
return originalStat(path)
}
t.Cleanup(func() { stat = originalStat })

err := Save(dest, []byte("content"))
if err == nil {
t.Fatalf("expected error but got nil")
}
if !strings.Contains(err.Error(), "can not write to destination file") {
t.Fatalf("unexpected error message: %v", err)
}
}

func TestSaveWriteFileError(t *testing.T) {
tmpDir := t.TempDir()
dest := filepath.Join(tmpDir, "file.txt")

originalWriteFile := writeFile
writeFile = func(path string, data []byte, perm os.FileMode) error {
if path == dest {
return errors.New("write failure")
}
return originalWriteFile(path, data, perm)
}
t.Cleanup(func() { writeFile = originalWriteFile })

err := Save(dest, []byte("content"))
if err == nil {
t.Fatalf("expected error but got nil")
}
if !strings.Contains(err.Error(), "can not write to destination file") {
t.Fatalf("unexpected error message: %v", err)
}
}

func TestEnsureDirectoryExistingFile(t *testing.T) {
tmpDir := t.TempDir()
destDir := filepath.Join(tmpDir, "existing")

if err := os.WriteFile(destDir, []byte("content"), 0644); err != nil {
t.Fatalf("failed to create file: %v", err)
}

err := ensureDirectory(destDir)
if err == nil {
t.Fatalf("expected error but got nil")
}
if !strings.Contains(err.Error(), "is not a directory") {
t.Fatalf("unexpected error message: %v", err)
}
}

func TestEnsureDirectoryParentNotDir(t *testing.T) {
tmpDir := t.TempDir()
parent := filepath.Join(tmpDir, "parent")
if err := os.WriteFile(parent, []byte("content"), 0644); err != nil {
t.Fatalf("failed to create parent file: %v", err)
}
destDir := filepath.Join(parent, "child")

originalStat := stat
stat = func(path string) (os.FileInfo, error) {
if path == destDir {
return nil, os.ErrNotExist
}
return originalStat(path)
}
t.Cleanup(func() { stat = originalStat })

err := ensureDirectory(destDir)
if err == nil {
t.Fatalf("expected error but got nil")
}
if !strings.Contains(err.Error(), "parent") {
t.Fatalf("unexpected error message: %v", err)
}
}

func TestEnsureDirectoryMkdirError(t *testing.T) {
tmpDir := t.TempDir()
destDir := filepath.Join(tmpDir, "newdir")

originalMkdirAll := mkdirAll
mkdirAll = func(path string, perm os.FileMode) error {
if path == destDir {
return errors.New("mkdir failure")
}
return originalMkdirAll(path, perm)
}
t.Cleanup(func() { mkdirAll = originalMkdirAll })

err := ensureDirectory(destDir)
if err == nil {
t.Fatalf("expected error but got nil")
}
if !strings.Contains(err.Error(), "can not create destination directory") {
t.Fatalf("unexpected error message: %v", err)
}
}

func TestEnsureDirectoryStatUnexpectedError(t *testing.T) {
tmpDir := t.TempDir()
destDir := filepath.Join(tmpDir, "unexpected")

originalStat := stat
stat = func(path string) (os.FileInfo, error) {
if path == destDir {
return nil, os.ErrPermission
}
return originalStat(path)
}
t.Cleanup(func() { stat = originalStat })

err := ensureDirectory(destDir)
if err == nil {
t.Fatalf("expected error but got nil")
}
if !strings.Contains(err.Error(), "can not create destination directory") {
t.Fatalf("unexpected error message: %v", err)
}
}