diff --git a/powershell/powershell_test.go b/powershell/powershell_test.go index 44ff9c1..52b367e 100644 --- a/powershell/powershell_test.go +++ b/powershell/powershell_test.go @@ -114,3 +114,25 @@ func TestFile(t *testing.T) { } } } + +func TestNewSession(t *testing.T) { + ps, err := NewSession() + if err != nil { + t.Fatalf("NewSession() failed: %v", err) + } + if err := ps.Close(); err != nil { + t.Fatalf("Close() failed: %v", err) + } +} + +func TestExecute(t *testing.T) { + ps, err := NewSession() + if err != nil { + t.Fatalf("NewSession() failed: %v", err) + } + defer ps.Close() + _, err = ps.Execute("ipconfig") + if err != nil { + t.Fatalf("Execute() failed: %v", err) + } +} diff --git a/powershell/powershell_windows.go b/powershell/powershell_windows.go index af36b19..60fc21e 100644 --- a/powershell/powershell_windows.go +++ b/powershell/powershell_windows.go @@ -18,11 +18,15 @@ package powershell import ( + "bufio" "encoding/json" "fmt" + "io" "os" "os/exec" "path/filepath" + "strings" + "sync" ) var ( @@ -106,3 +110,69 @@ func Version() (VersionTable, error) { err = json.Unmarshal(o, &psv) return psv, err } + +// Session manages a persistent PowerShell process. +type Session struct { + cmd *exec.Cmd + stdin io.WriteCloser + stdout *bufio.Reader + mu sync.Mutex +} + +// NewSession creates and starts a new PowerShell session. +func NewSession() (*Session, error) { + cmd := exec.Command(powerShellExe, "-NoExit", "-NoProfile", "-Command", "-") + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("failed to open stdin pipe: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to open stdout pipe: %w", err) + } + cmd.Stderr = cmd.Stdout + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start PowerShell session: %w", err) + } + return &Session{ + cmd: cmd, + stdin: stdin, + stdout: bufio.NewReader(stdout), + }, nil +} + +// Execute runs a command in the PowerShell session and returns its output. +func (ps *Session) Execute(command string) (string, error) { + ps.mu.Lock() + defer ps.mu.Unlock() + + // This is annoying but we need some way of identifying the end of the command's output. + const endMarker = "----------END-OF-COMMAND-OUTPUT----------" + fullCommand := fmt.Sprintf("%s; Write-Host '%s'\n", command, endMarker) + if _, err := ps.stdin.Write([]byte(fullCommand)); err != nil { + return "", fmt.Errorf("failed to write to stdin: %w", err) + } + + var output strings.Builder + for { + line, err := ps.stdout.ReadString('\n') + if err != nil { + return "", fmt.Errorf("failed to read from stdout: %w", err) + } + if strings.TrimSpace(line) == endMarker { + break + } + output.WriteString(line) + } + return strings.TrimSpace(output.String()), nil +} + +// Close terminates the PowerShell session. +func (ps *Session) Close() error { + // Attempt to close the stdin pipe first. + if err := ps.stdin.Close(); err != nil { + return fmt.Errorf("failed to close stdin: %w", err) + } + return ps.cmd.Process.Kill() +}