diff --git a/internal/fn/fn.go b/internal/fn/fn.go index 7788524..1a13582 100644 --- a/internal/fn/fn.go +++ b/internal/fn/fn.go @@ -29,6 +29,13 @@ import ( "gopkg.in/yaml.v2" ) +var ( + readFile = os.ReadFile + writeFile = os.WriteFile + stat = os.Stat + mkdirAll = os.MkdirAll +) + func GetUserInputFromStdin() string { var lines []string scanner := bufio.NewScanner(os.Stdin) @@ -130,7 +137,7 @@ func GetPathContent(src string) ([]byte, error) { var content []byte for _, filePath := range filePaths { - fileContent, err := os.ReadFile(filePath) + fileContent, err := readFile(filePath) if err != nil { return nil, fmt.Errorf("no valid SSH config found in %s: %w", src, err) } @@ -145,7 +152,7 @@ func Save(dest string, content []byte) error { return err } - info, err := os.Stat(destDir) + info, err := stat(destDir) if err != nil { return fmt.Errorf("can not write to destination file: %v", err) } @@ -154,7 +161,7 @@ func Save(dest string, content []byte) error { return fmt.Errorf("can not write to destination file: directory %s is not writable", destDir) } - if err := os.WriteFile(dest, content, 0644); err != nil { + if err := writeFile(dest, content, 0644); err != nil { return fmt.Errorf("can not write to destination file: %v", err) } return nil @@ -173,7 +180,7 @@ func TidyLastEmptyLines(input []byte) []byte { } func ensureDirectory(destDir string) error { - info, err := os.Stat(destDir) + info, err := stat(destDir) if err == nil { if !info.IsDir() { return fmt.Errorf("can not create destination directory: %s is not a directory", destDir) @@ -187,7 +194,7 @@ func ensureDirectory(destDir string) error { parent := filepath.Dir(destDir) if parent != destDir { - if parentInfo, parentErr := os.Stat(parent); parentErr == nil { + if parentInfo, parentErr := stat(parent); parentErr == nil { if !parentInfo.IsDir() { return fmt.Errorf("can not create destination directory: parent %s is not a directory", parent) } @@ -197,7 +204,7 @@ func ensureDirectory(destDir string) error { } } - if err := os.MkdirAll(destDir, 0755); err != nil { + if err := mkdirAll(destDir, 0755); err != nil { return fmt.Errorf("can not create destination directory: %v", err) } diff --git a/internal/fn/fn_errors_test.go b/internal/fn/fn_errors_test.go new file mode 100644 index 0000000..b804469 --- /dev/null +++ b/internal/fn/fn_errors_test.go @@ -0,0 +1,170 @@ +package fn + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestGetPathContentReadFileError(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config") + if err := os.WriteFile(configPath, []byte("Host example\n HostName example.com\n"), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + + originalReadFile := readFile + readFile = func(path string) ([]byte, error) { + if path == configPath { + return nil, errors.New("read failure") + } + return originalReadFile(path) + } + t.Cleanup(func() { readFile = originalReadFile }) + + _, err := GetPathContent(tmpDir) + if err == nil { + t.Fatalf("expected error but got nil") + } + if !strings.Contains(err.Error(), "no valid SSH config found") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestSaveStatError(t *testing.T) { + tmpDir := t.TempDir() + dest := filepath.Join(tmpDir, "sub", "file.txt") + destDir := filepath.Dir(dest) + + callCount := 0 + originalStat := stat + stat = func(path string) (os.FileInfo, error) { + if path == destDir { + callCount++ + if callCount == 2 { + return nil, errors.New("stat failure") + } + } + return originalStat(path) + } + t.Cleanup(func() { stat = originalStat }) + + err := Save(dest, []byte("content")) + if err == nil { + t.Fatalf("expected error but got nil") + } + if !strings.Contains(err.Error(), "can not write to destination file") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestSaveWriteFileError(t *testing.T) { + tmpDir := t.TempDir() + dest := filepath.Join(tmpDir, "file.txt") + + originalWriteFile := writeFile + writeFile = func(path string, data []byte, perm os.FileMode) error { + if path == dest { + return errors.New("write failure") + } + return originalWriteFile(path, data, perm) + } + t.Cleanup(func() { writeFile = originalWriteFile }) + + err := Save(dest, []byte("content")) + if err == nil { + t.Fatalf("expected error but got nil") + } + if !strings.Contains(err.Error(), "can not write to destination file") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestEnsureDirectoryExistingFile(t *testing.T) { + tmpDir := t.TempDir() + destDir := filepath.Join(tmpDir, "existing") + + if err := os.WriteFile(destDir, []byte("content"), 0644); err != nil { + t.Fatalf("failed to create file: %v", err) + } + + err := ensureDirectory(destDir) + if err == nil { + t.Fatalf("expected error but got nil") + } + if !strings.Contains(err.Error(), "is not a directory") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestEnsureDirectoryParentNotDir(t *testing.T) { + tmpDir := t.TempDir() + parent := filepath.Join(tmpDir, "parent") + if err := os.WriteFile(parent, []byte("content"), 0644); err != nil { + t.Fatalf("failed to create parent file: %v", err) + } + destDir := filepath.Join(parent, "child") + + originalStat := stat + stat = func(path string) (os.FileInfo, error) { + if path == destDir { + return nil, os.ErrNotExist + } + return originalStat(path) + } + t.Cleanup(func() { stat = originalStat }) + + err := ensureDirectory(destDir) + if err == nil { + t.Fatalf("expected error but got nil") + } + if !strings.Contains(err.Error(), "parent") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestEnsureDirectoryMkdirError(t *testing.T) { + tmpDir := t.TempDir() + destDir := filepath.Join(tmpDir, "newdir") + + originalMkdirAll := mkdirAll + mkdirAll = func(path string, perm os.FileMode) error { + if path == destDir { + return errors.New("mkdir failure") + } + return originalMkdirAll(path, perm) + } + t.Cleanup(func() { mkdirAll = originalMkdirAll }) + + err := ensureDirectory(destDir) + if err == nil { + t.Fatalf("expected error but got nil") + } + if !strings.Contains(err.Error(), "can not create destination directory") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestEnsureDirectoryStatUnexpectedError(t *testing.T) { + tmpDir := t.TempDir() + destDir := filepath.Join(tmpDir, "unexpected") + + originalStat := stat + stat = func(path string) (os.FileInfo, error) { + if path == destDir { + return nil, os.ErrPermission + } + return originalStat(path) + } + t.Cleanup(func() { stat = originalStat }) + + err := ensureDirectory(destDir) + if err == nil { + t.Fatalf("expected error but got nil") + } + if !strings.Contains(err.Error(), "can not create destination directory") { + t.Fatalf("unexpected error message: %v", err) + } +}