diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md index 1a89693..40ff571 100644 --- a/.claude/CLAUDE.md +++ b/.claude/CLAUDE.md @@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## Project Overview -`daemon-cli` is a Rust library for building streaming daemon-client applications with automatic lifecycle management. It enables CLI tools to communicate with long-running background processes via stdin/stdout streaming over Unix domain sockets. +`daemon-cli` is a Rust library for building streaming daemon-client applications with automatic lifecycle management. It enables CLI tools to communicate with long-running background processes via stdin/stdout streaming (Unix domain sockets on Unix, named pipes on Windows). ## Essential Commands @@ -91,16 +91,22 @@ Note: Multiple clients can execute commands concurrently. Each connection gets i The `CommandHandler` trait is the primary extension point: ```rust #[async_trait] -pub trait CommandHandler: Send + Sync { +pub trait CommandHandler

: Send + Sync +where + P: PayloadCollector, +{ async fn handle( &self, - command: &str, // Command string from stdin - output: impl AsyncWrite, // Stream output here incrementally + command: &str, // Command string from stdin + ctx: CommandContext

, // Terminal info + custom payload + output: impl AsyncWrite, // Stream output here incrementally cancel_token: CancellationToken, // Check for cancellation - ) -> Result<()>; + ) -> Result; // Return exit code (0 = success) } ``` +The generic `P` parameter allows passing custom data from client to daemon via `PayloadCollector::collect()`. + **Concurrency Considerations:** - Handlers must implement `Clone + Send + Sync` for concurrent execution - Each client connection receives a cloned handler instance @@ -126,7 +132,7 @@ use daemon_cli::prelude::*; let handler = MyHandler::new(); // Automatically detects daemon name and binary mtime -let (server, _handle) = DaemonServer::new("/path/to/project", handler); +let (server, _handle) = DaemonServer::new("/path/to/project", handler, StartupReason::default()); // Default: max 100 concurrent connections ``` @@ -139,6 +145,7 @@ let handler = MyHandler::new(); let (server, _handle) = DaemonServer::new_with_limit( "/path/to/project", handler, + StartupReason::default(), 10 // Max 10 concurrent clients ); ``` @@ -147,8 +154,8 @@ When the limit is reached, new connections wait for an available slot. This is i ## Platform Requirements -- Unix-like systems only (Linux, macOS) -- Uses Unix domain sockets (not portable to Windows) +- Cross-platform: Linux, macOS, Windows +- Uses Unix domain sockets on Unix, named pipes on Windows - Edition: Rust 2024 # Other memory diff --git a/.gitignore b/.gitignore index 5aeb9fb..de05ab1 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,4 @@ Cargo.lock /*.md !/README.md -!/PRD.md !/CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 3525450..e6a20b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,36 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.9.0] - 2025-12-29 + +### Changed (BREAKING) + +- Removed `EnvVarFilter` - use `PayloadCollector` trait instead for passing env vars +- `CommandContext

` is now generic with custom payload support via `PayloadCollector` +- Removed `DaemonClient::with_env_filter()` method + +### Migration + +```rust +// Before (0.8.0): +let client = DaemonClient::connect(path) + .await? + .with_env_filter(EnvVarFilter::with_names(["MY_VAR"])); + +// After (0.9.0): +#[derive(Serialize, Deserialize, Clone, Default)] +struct MyPayload { my_var: Option } + +#[async_trait] +impl PayloadCollector for MyPayload { + async fn collect() -> Self { + Self { my_var: std::env::var("MY_VAR").ok() } + } +} + +let client = DaemonClient::::connect(path).await?; +``` + ## [0.8.0] - 2025-12-09 ### Changed (BREAKING) diff --git a/Cargo.toml b/Cargo.toml index f7e24e8..edc5d07 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "daemon-cli" -version = "0.8.0" +version = "0.9.0" edition = "2024" [dependencies] diff --git a/PRD.md b/PRD.md deleted file mode 100644 index d2e0cfc..0000000 --- a/PRD.md +++ /dev/null @@ -1,174 +0,0 @@ -# Product Requirements Document: daemon-cli - -## Product Vision - -A generic, reusable daemon-client framework for Rust applications that provides streaming command execution through a transparent stdin→stdout pipeline. The framework handles all daemon lifecycle management, version synchronization, and process coordination automatically. - -## Product Goals - -1. **Zero-Configuration Daemon Management**: Automatic daemon spawning, version checking, and lifecycle management -2. **Universal I/O Interface**: Standard stdin/stdout streaming works with pipes, scripts, and any language -3. **Single-Task Streaming Model**: One command at a time with real-time output streaming and cancellation support -4. **Transparent Operation**: CLI acts as pure pipe (stdin → daemon → stdout), errors only to stderr - -## Success Metrics - -1. **Performance**: - - Warm request latency < 50ms (command to first output) - - Cold start < 500ms (daemon spawn to first output) - -2. **Simplicity**: - - Zero configuration files or setup steps required - - Works out-of-the-box with pipes and shell scripts - -3. **Reliability**: - - Client and daemon version mismatches auto-resolved in < 1s - - Clean cancellation on Ctrl+C with no hung processes - -## Core Architecture - -### Concurrent Processing Model - -- **Multiple concurrent clients**: Daemon handles up to 100 concurrent CLI connections (configurable) -- **Connection limiting**: When limit is reached, new connections are rejected to prevent resource exhaustion -- **Command isolation**: Each command runs in a separate task with independent cleanup -- **Thread-safe handlers**: Handlers must implement `Clone + Send + Sync` for concurrent execution - -### Automatic Version Synchronization - -- **Build timestamp comparison**: Client and daemon exchange build timestamps on connect -- **Auto-restart on mismatch**: Daemon automatically replaced with newer version -- **Zero user intervention**: Version sync happens transparently - -### Streaming with Cancellation - -- **Real-time output**: Daemon streams output chunks as they're produced -- **Ctrl+C support**: User can cancel long-running commands -- **Graceful shutdown**: Handler receives cancellation signal for cleanup -- **Force termination**: Timeout enforces maximum shutdown time - -### Silent Operation - -- **No status messages**: CLI outputs only daemon content to stdout -- **Error isolation**: Framework errors written to stderr only -- **Transparent pipeline**: `echo "cmd" | cli` behaves like native commands - -## API Design - -### Handler Interface - -Applications implement a single trait to define command behavior: - -```rust -trait CommandHandler { - async fn handle( - &self, - command: &str, // Full stdin as string - output: impl AsyncWrite, // Stream output here - cancel: CancellationToken // For Ctrl+C handling - ) -> Result; // Return exit code (0-255) -} -``` - -**Handler Responsibilities:** -- Parse the command string (any format, handler decides) -- Stream output bytes via the `AsyncWrite` interface -- Respond to cancellation token by stopping gracefully -- Return exit code: `Ok(0)` for success, `Ok(1-255)` for errors, `Err(e)` for unrecoverable failures - -**Framework Responsibilities:** -- Read CLI stdin and deliver as command string -- Manage socket communication and output streaming -- Handle Ctrl+C and propagate cancellation -- Perform version checking and daemon spawning -- Write framework errors to stderr - -### Usage Examples - -**Simple command:** -```bash -echo "process /path/to/file" | my-cli -``` - -**Piping data:** -```bash -cat large-file.txt | my-cli compress -``` - -**Scripting:** -```bash -for file in *.txt; do - echo "analyze $file" | my-cli -done -``` - -**From other languages (Python):** -```python -import subprocess -result = subprocess.run(['my-cli'], input='process data', text=True, capture_output=True) -print(result.stdout) -``` - -## Protocol Overview - -### Message Types - -The socket protocol uses five message types: - -1. **VersionCheck** - Handshake with build timestamps -2. **Command(String)** - CLI sends full stdin to daemon (once) -3. **OutputChunk(Bytes)** - Daemon streams output to CLI (multiple) -4. **CommandComplete { exit_code: i32 }** - Daemon signals success with exit code -5. **CommandError(String)** - Daemon reports unrecoverable error before closing - -Exit codes from handlers are transmitted via `CommandComplete` message. - -### Communication Flow - -**Normal execution:** -1. CLI connects to socket (spawns daemon if needed) -2. Version handshake (restart daemon if mismatch) -3. CLI sends `Command` with stdin content -4. Daemon sends `OutputChunk` messages as output is produced -5. Daemon sends `CommandComplete { exit_code }` on success, or `CommandError` on failure -6. CLI receives exit code and sets process exit status accordingly - -**Cancellation (Ctrl+C):** -1. CLI closes connection immediately -2. Daemon detects broken connection -3. Daemon signals `CancellationToken` to handler -4. Handler stops gracefully (or is force-terminated after timeout) -5. Daemon ready for next connection - -**Version mismatch:** -1. Handshake reveals timestamp difference -2. CLI terminates old daemon and spawns new one -3. CLI retries connection with new daemon - -## Constraints & Limitations - -### Design Constraints - -- **Non-interactive**: Entire command must be provided via stdin (no interactive prompts) -- **Platform**: Unix-like systems only (Linux, macOS) via Unix domain sockets -- **Connection limits**: Default 100 concurrent connections, configurable per daemon -- **No structured output**: Framework doesn't enforce output format (handler decides) -- **Daemon identification**: Requires unique `daemon_name` and `root_path` for socket isolation - -### Acceptable Limitations - -- **No progress reporting**: Unless handler explicitly emits it in output stream -- **Connection rejection at limit**: Connections beyond the limit are rejected, not queued -- **No bi-directional interaction**: Command sent once, no follow-up requests - -## Future Considerations - -Not in current scope, but potential future enhancements: - -- ~~Multi-client support with request queuing~~ (✅ Implemented: concurrent client support with connection limiting) -- ~~Custom exit codes~~ (✅ Implemented: handlers return i32 exit codes, transmitted via CommandComplete) -- Interactive command mode with prompt/response cycles -- Structured output format enforcement -- Windows support via named pipes -- Progress reporting protocol extensions -- Connection queuing instead of rejection when limit reached diff --git a/README.md b/README.md index affb99b..eca001a 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,14 @@ A streaming daemon-client framework for Rust with automatic lifecycle management - Streaming stdin/stdout interface - Ctrl+C cancellation support - Custom exit codes (0-255) for command results -- Low latency (< 50ms warm, < 500ms cold) +- Generic payload support for passing custom data from client to daemon +- Cross-platform (Linux, macOS, Windows) ## Installation ```toml [dependencies] -daemon-cli = "0.3.0" +daemon-cli = "0.9.0" ``` ## Usage @@ -23,7 +24,9 @@ daemon-cli = "0.3.0" ```rust use daemon_cli::prelude::*; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +#[derive(Clone)] struct MyHandler; #[async_trait] @@ -31,9 +34,11 @@ impl CommandHandler for MyHandler { async fn handle( &self, command: &str, + ctx: CommandContext, mut output: impl AsyncWrite + Unpin + Send, cancel: CancellationToken, ) -> Result { + // Access terminal info via ctx.terminal_info output.write_all(format!("Received: {}\n", command).as_bytes()).await?; Ok(0) // Return exit code (0 = success) } @@ -44,81 +49,80 @@ impl CommandHandler for MyHandler { ```rust let root_path = "/path/to/project"; -// Automatically detects daemon name and binary mtime -let (server, _handle) = DaemonServer::new(root_path, MyHandler); +let (server, _handle) = DaemonServer::new(root_path, MyHandler, StartupReason::default()); server.run().await?; -// Optionally use handle.shutdown() to stop the server gracefully ``` **Run client:** ```rust let root_path = "/path/to/project"; -// Automatically detects daemon name, executable path, and binary mtime let mut client = DaemonClient::connect(root_path).await?; let exit_code = client.execute_command(command).await?; -std::process::exit(exit_code); // Exit with the command's exit code +std::process::exit(exit_code); ``` -**Use it:** +## Custom Payload -```bash -echo "hello" | my-cli -cat file.txt | my-cli process -``` +Pass custom data from client to daemon using `PayloadCollector`: -## Exit Codes +```rust +use daemon_cli::prelude::*; -Handlers return custom exit codes (0-255) to indicate command results: +#[derive(Serialize, Deserialize, Clone, Default)] +struct MyPayload { + cwd: String, + user: Option, +} -```rust -async fn handle(...) -> Result { - match command.trim() { - "status" => { - output.write_all(b"Ready\n").await?; - Ok(0) // Success - } - "unknown" => { - output.write_all(b"Unknown command\n").await?; - Ok(127) // Command not found (shell convention) - } - _ => { - // For unrecoverable errors, return Err - // This becomes exit code 1 with error message to stderr - Err(anyhow::anyhow!("Fatal error")) +#[async_trait] +impl PayloadCollector for MyPayload { + async fn collect() -> Self { + Self { + cwd: std::env::current_dir() + .map(|p| p.display().to_string()) + .unwrap_or_default(), + user: std::env::var("USER").ok(), } } } -``` - -- `Ok(0)` - Success -- `Ok(1-255)` - Application-specific error codes -- `Err(e)` - Unrecoverable error (becomes exit code 1 with error message) - -## Logging -The library uses `tracing` for structured logging. Daemon implementations should initialize a tracing subscriber: +// Handler receives payload in ctx.payload +#[async_trait] +impl CommandHandler for MyHandler { + async fn handle( + &self, + command: &str, + ctx: CommandContext, + mut output: impl AsyncWrite + Unpin + Send, + cancel: CancellationToken, + ) -> Result { + println!("CWD: {}", ctx.payload.cwd); + Ok(0) + } +} -```rust -tracing_subscriber::fmt() - .with_target(false) - .compact() - .init(); +// Client with payload +let client = DaemonClient::::connect(root_path).await?; ``` -This provides automatic client context (`client{id=X}`) for all logs. Handlers can add custom spans for command-level or operation-level context. Client-side logs are suppressed but shown on errors. +## Exit Codes -See `examples/cli.rs` and `examples/concurrent.rs` for complete logging setup examples. +Handlers return custom exit codes (0-255): + +- `Ok(0)` - Success +- `Ok(1-255)` - Application-specific error codes +- `Err(e)` - Unrecoverable error (becomes exit code 1) ## Example See `examples/cli.rs` for a complete working example: ```bash -cargo run --example cli -- daemon --daemon-name cli --daemon-path /tmp/test +cargo run --example cli -- daemon echo "status" | cargo run --example cli ``` ## Platform -Unix-like systems only (Linux, macOS). Uses Unix domain sockets. +Cross-platform: Linux, macOS, Windows. Uses Unix domain sockets on Unix and named pipes on Windows. diff --git a/examples/cli.rs b/examples/cli.rs index a8da66f..f4e23b1 100644 --- a/examples/cli.rs +++ b/examples/cli.rs @@ -56,7 +56,7 @@ async fn run_stop_mode() -> Result<()> { let root_path = env::current_dir()?.to_string_lossy().to_string(); // Connect to daemon to get access to force_stop method - let client = DaemonClient::connect(&root_path).await?; + let client = DaemonClient::<()>::connect(&root_path).await?; println!("Stopping daemon..."); client.force_stop().await?; @@ -134,7 +134,7 @@ async fn run_client_mode() -> Result<()> { } // Connect to daemon (auto-spawns if needed, auto-detects everything) - let mut client = DaemonClient::connect(&root_path).await?; + let mut client = DaemonClient::<()>::connect(&root_path).await?; // Execute command and stream output to stdout let exit_code = client.execute_command(command).await?; diff --git a/examples/concurrent.rs b/examples/concurrent.rs index 13a4be4..2fec055 100644 --- a/examples/concurrent.rs +++ b/examples/concurrent.rs @@ -323,7 +323,7 @@ async fn run_client_mode() -> Result<()> { } // Connect to daemon (auto-spawns if needed, auto-detects everything) - let mut client = DaemonClient::connect(&root_path).await?; + let mut client = DaemonClient::<()>::connect(&root_path).await?; // Execute command and stream output to stdout let exit_code = client.execute_command(command).await?; diff --git a/src/client.rs b/src/client.rs index c420807..65c746e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,9 +2,9 @@ use crate::error_context::{ErrorContextBuffer, get_or_init_global_error_context} use crate::process::{TerminateResult, kill_process, process_exists, terminate_process}; use crate::terminal::TerminalInfo; use crate::transport::{SocketClient, SocketMessage, daemon_socket_exists, socket_path}; -use crate::{CommandContext, EnvVarFilter, StartupReason}; +use crate::{CommandContext, PayloadCollector, StartupReason}; use anyhow::{Result, bail}; -use std::{fs, path::PathBuf, process::Stdio, time::Duration}; +use std::{fs, marker::PhantomData, path::PathBuf, process::Stdio, time::Duration}; use tokio::{io::AsyncWriteExt, process::Command, time::sleep}; /// Client for communicating with daemon processes via Unix sockets. @@ -25,7 +25,7 @@ use tokio::{io::AsyncWriteExt, process::Command, time::sleep}; /// /// // Actual usage pattern (requires daemon binary): /// # tokio_test::block_on(async { -/// let Ok(mut client) = DaemonClient::connect("/path/to/project").await else { +/// let Ok(mut client) = DaemonClient::<()>::connect("/path/to/project").await else { /// // handle error... /// return; /// }; @@ -34,7 +34,7 @@ use tokio::{io::AsyncWriteExt, process::Command, time::sleep}; /// let exit_code = client.execute_command("process file.txt".to_string()).await.ok(); /// # }); /// ``` -pub struct DaemonClient { +pub struct DaemonClient

{ socket_client: SocketClient, /// Daemon name (e.g., CLI tool name) pub daemon_name: String, @@ -48,11 +48,14 @@ pub struct DaemonClient { error_context: ErrorContextBuffer, /// Enable automatic daemon restart on fatal connection errors (default: false) auto_restart_on_error: bool, - /// Filter for which environment variables to pass to daemon - env_var_filter: EnvVarFilter, + /// PhantomData for payload type + _phantom: PhantomData

, } -impl DaemonClient { +impl

DaemonClient

+where + P: PayloadCollector, +{ /// Connect to daemon, spawning it if needed with automatic version sync. /// /// Automatically detects the daemon name from the binary filename, the daemon @@ -137,11 +140,11 @@ impl DaemonClient { // Perform version handshake socket_client - .send_message(&SocketMessage::VersionCheck { build_timestamp }) + .send_message(&SocketMessage::

::VersionCheck { build_timestamp }) .await?; // Receive daemon's version - let daemon_timestamp = match socket_client.receive_message().await? { + let daemon_timestamp = match socket_client.receive_message::>().await? { Some(SocketMessage::VersionCheck { build_timestamp: daemon_ts, }) => daemon_ts, @@ -178,9 +181,9 @@ impl DaemonClient { // Retry handshake socket_client - .send_message(&SocketMessage::VersionCheck { build_timestamp }) + .send_message(&SocketMessage::

::VersionCheck { build_timestamp }) .await?; - match socket_client.receive_message().await? { + match socket_client.receive_message::>().await? { Some(SocketMessage::VersionCheck { build_timestamp: daemon_ts, }) if daemon_ts == build_timestamp => { @@ -206,7 +209,7 @@ impl DaemonClient { build_timestamp, error_context, auto_restart_on_error: false, - env_var_filter: EnvVarFilter::none(), + _phantom: PhantomData, }) } @@ -350,7 +353,7 @@ impl DaemonClient { /// use daemon_cli::prelude::*; /// /// # tokio_test::block_on(async { - /// let client = DaemonClient::connect("/path/to/project").await?; + /// let client = DaemonClient::<()>::connect("/path/to/project").await?; /// client.force_stop().await?; /// # Ok::<(), anyhow::Error>(()) /// # }); @@ -432,7 +435,7 @@ impl DaemonClient { /// use daemon_cli::prelude::*; /// /// # tokio_test::block_on(async { - /// let mut client = DaemonClient::connect("/path/to/project").await?; + /// let mut client = DaemonClient::<()>::connect("/path/to/project").await?; /// /// // If daemon crashes or hangs: /// client.restart().await?; @@ -457,10 +460,8 @@ impl DaemonClient { // Replace self with new client, preserving settings let auto_restart = self.auto_restart_on_error; - let env_filter = std::mem::take(&mut self.env_var_filter); *self = new_client; self.auto_restart_on_error = auto_restart; - self.env_var_filter = env_filter; Ok(()) } @@ -498,10 +499,10 @@ impl DaemonClient { // Perform version handshake socket_client - .send_message(&SocketMessage::VersionCheck { build_timestamp }) + .send_message(&SocketMessage::

::VersionCheck { build_timestamp }) .await?; - match socket_client.receive_message().await? { + match socket_client.receive_message::>().await? { Some(SocketMessage::VersionCheck { build_timestamp: daemon_ts, }) if daemon_ts == build_timestamp => { @@ -521,7 +522,7 @@ impl DaemonClient { build_timestamp, error_context, auto_restart_on_error: false, - env_var_filter: EnvVarFilter::none(), + _phantom: PhantomData, }) } @@ -539,7 +540,7 @@ impl DaemonClient { /// use daemon_cli::prelude::*; /// /// # tokio_test::block_on(async { - /// let mut client = DaemonClient::connect("/path/to/project") + /// let mut client = DaemonClient::<()>::connect("/path/to/project") /// .await? /// .with_auto_restart(true); /// @@ -553,31 +554,6 @@ impl DaemonClient { self } - /// Configure which environment variables to pass to the daemon. - /// - /// By default, no environment variables are passed (backward compatible). - /// Use [`EnvVarFilter::with_names`] to specify exact variable names to include. - /// - /// # Example - /// - /// ```rust,no_run - /// use daemon_cli::prelude::*; - /// - /// # tokio_test::block_on(async { - /// let mut client = DaemonClient::connect("/path/to/project") - /// .await? - /// .with_env_filter(EnvVarFilter::with_names(["MY_APP_DEBUG", "MY_APP_CONFIG"])); - /// - /// // Commands will now include these env vars if they are set - /// client.execute_command("process file.txt".to_string()).await?; - /// # Ok::<(), anyhow::Error>(()) - /// # }); - /// ``` - pub fn with_env_filter(mut self, filter: EnvVarFilter) -> Self { - self.env_var_filter = filter; - self - } - /// Check if an error indicates a fatal connection issue (daemon crash/hang). /// /// Returns true for errors that suggest the daemon has crashed or become @@ -639,6 +615,9 @@ impl DaemonClient { async fn execute_command_internal(&mut self, command: String) -> Result { tracing::debug!(command = %command, "Executing command"); + // Auto-collect payload before each command + let payload = P::collect().await; + // Detect terminal information from the client environment let terminal_info = TerminalInfo::detect().await; tracing::debug!( @@ -649,21 +628,12 @@ impl DaemonClient { "Detected terminal info" ); - // Filter environment variables based on configured names - let env_vars = self.env_var_filter.filter_current_env(); - if !env_vars.is_empty() { - tracing::debug!( - env_var_count = env_vars.len(), - "Passing filtered environment variables" - ); - } - - // Build command context - let context = CommandContext::with_env(terminal_info, env_vars); + // Build command context with payload + let context = CommandContext::with_payload(terminal_info, payload); // Send command with context self.socket_client - .send_message(&SocketMessage::Command { command, context }) + .send_message(&SocketMessage::

::Command { command, context }) .await .inspect_err(|_| { self.error_context.dump_to_stderr(); @@ -673,7 +643,11 @@ impl DaemonClient { let mut stdout = tokio::io::stdout(); loop { - match self.socket_client.receive_message::().await { + match self + .socket_client + .receive_message::>() + .await + { Ok(Some(SocketMessage::OutputChunk(chunk))) => { // Write chunk to stdout stdout.write_all(&chunk).await?; diff --git a/src/lib.rs b/src/lib.rs index 3cc4713..71afc56 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,8 +100,8 @@ use anyhow::Result; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, env, fs, str::FromStr, time::UNIX_EPOCH}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use std::{env, fs, marker::PhantomData, str::FromStr, time::UNIX_EPOCH}; use tokio::io::AsyncWrite; use tokio_util::sync::CancellationToken; @@ -117,108 +117,88 @@ pub use error_context::ErrorContextBuffer; pub use server::{DaemonHandle, DaemonServer}; pub use terminal::{ColorSupport, TerminalInfo, Theme}; -/// Configuration for filtering which environment variables to pass from client to daemon. +/// Trait for auto-collecting payload data before each command. /// -/// By default, no environment variables are passed. Use [`EnvVarFilter::with_names`] to -/// specify exact variable names to include. +/// Implement this trait on your payload struct to define how data is +/// collected on the client side before being sent to the daemon. /// /// # Example /// /// ```rust -/// use daemon_cli::EnvVarFilter; +/// use daemon_cli::prelude::*; /// -/// // Pass specific env vars -/// let filter = EnvVarFilter::with_names(["MY_APP_DEBUG", "MY_APP_CONFIG"]); +/// #[derive(Serialize, Deserialize, Clone, Default)] +/// struct MyPayload { +/// cwd: String, +/// user: Option, +/// } /// -/// // Or build incrementally -/// let filter = EnvVarFilter::none() -/// .include("MY_APP_DEBUG") -/// .include("MY_APP_CONFIG"); +/// #[async_trait] +/// impl PayloadCollector for MyPayload { +/// async fn collect() -> Self { +/// Self { +/// cwd: std::env::current_dir() +/// .map(|p| p.display().to_string()) +/// .unwrap_or_default(), +/// user: std::env::var("USER").ok(), +/// } +/// } +/// } /// ``` -#[derive(Debug, Clone, Default)] -pub struct EnvVarFilter { - names: Vec, +#[async_trait] +pub trait PayloadCollector: + Serialize + DeserializeOwned + Send + Sync + Clone + Default + 'static +{ + /// Collect payload data. Called automatically by client before each command. + async fn collect() -> Self; } -impl EnvVarFilter { - /// Create a filter that passes no environment variables (default). - pub fn none() -> Self { - Self { names: vec![] } - } - - /// Create a filter that passes env vars with the specified exact names. - pub fn with_names(names: impl IntoIterator>) -> Self { - Self { - names: names.into_iter().map(Into::into).collect(), - } - } - - /// Include an env var name to pass. - pub fn include(mut self, name: impl Into) -> Self { - self.names.push(name.into()); - self - } - - /// Filter environment variables from the provided source. - /// - /// This is useful for testing or when you want to filter from - /// a custom set of variables rather than the current process env. - pub fn filter_from( - &self, - env: impl IntoIterator, - ) -> HashMap - where - K: AsRef, - V: Into, - { - if self.names.is_empty() { - return HashMap::new(); - } - env.into_iter() - .filter(|(k, _)| self.names.iter().any(|n| n == k.as_ref())) - .map(|(k, v)| (k.as_ref().to_string(), v.into())) - .collect() - } - - /// Filter environment variables from the current process. - /// - /// Returns a HashMap containing only the env vars whose names match - /// those configured in this filter. - pub fn filter_current_env(&self) -> HashMap { - self.filter_from(std::env::vars()) - } +/// Default implementation for () - no-op for backward compatibility. +#[async_trait] +impl PayloadCollector for () { + async fn collect() -> Self {} } /// Context information passed with each command execution. /// /// This struct bundles metadata about the command execution environment, -/// including terminal information and environment variables. It is designed -/// for extensibility - new fields can be added in the future without breaking -/// the handler trait signature. +/// including terminal information and a user-defined payload. The generic +/// payload allows passing custom data (including environment variables) +/// from client to daemon via [`PayloadCollector`]. +/// +/// The default type parameter `P = ()` maintains backward compatibility with +/// existing code that doesn't use payloads. #[derive(Serialize, Deserialize, Debug, Clone)] -pub struct CommandContext { +#[serde(bound(deserialize = "P: Default + serde::de::DeserializeOwned"))] +pub struct CommandContext

{ /// Information about the client's terminal environment pub terminal_info: TerminalInfo, - /// Environment variables passed from client (filtered by exact name match). - /// Empty by default for backward compatibility. + /// User-defined payload data collected via [`PayloadCollector::collect`]. #[serde(default)] - pub env_vars: HashMap, + pub payload: P, + /// PhantomData to handle variance correctly + #[serde(skip)] + _phantom: PhantomData

, } -impl CommandContext { - /// Create a new CommandContext with terminal info only (no env vars). +impl CommandContext<()> { + /// Create a new CommandContext with terminal info only (no payload). pub fn new(terminal_info: TerminalInfo) -> Self { Self { terminal_info, - env_vars: HashMap::new(), + payload: (), + _phantom: PhantomData, } } +} - /// Create a CommandContext with terminal info and environment variables. - pub fn with_env(terminal_info: TerminalInfo, env_vars: HashMap) -> Self { +impl

CommandContext

{ + /// Create a CommandContext with terminal info and custom payload. + pub fn with_payload(terminal_info: TerminalInfo, payload: P) -> Self { Self { terminal_info, - env_vars, + payload, + _phantom: PhantomData, } } } @@ -282,10 +262,11 @@ mod tests; pub mod prelude { pub use crate::{ ColorSupport, CommandContext, CommandHandler, DaemonClient, DaemonHandle, DaemonServer, - EnvVarFilter, ErrorContextBuffer, StartupReason, TerminalInfo, Theme, + ErrorContextBuffer, PayloadCollector, StartupReason, TerminalInfo, Theme, }; pub use anyhow::Result; pub use async_trait::async_trait; + pub use serde::{Deserialize, Serialize}; pub use tokio_util::sync::CancellationToken; } @@ -411,16 +392,19 @@ fn auto_detect_daemon_name() -> String { /// } /// ``` #[async_trait] -pub trait CommandHandler: Send + Sync { +pub trait CommandHandler

: Send + Sync +where + P: PayloadCollector, +{ /// Process a command with streaming output and cancellation support. /// /// This method may be called concurrently from multiple tasks. Ensure /// your implementation is thread-safe if accessing shared state. /// /// The `ctx` parameter contains information about the command execution - /// environment including terminal info (width, height, color support) and - /// any environment variables passed from the client. Use this to format - /// output appropriately and access client-side configuration. + /// environment including terminal info (width, height, color support) + /// and the user-defined payload. Use this to format output appropriately + /// and access client-side data. /// /// Write output incrementally via `output`. Long-running operations should /// check `cancel_token.is_cancelled()` to handle graceful cancellation. @@ -430,7 +414,7 @@ pub trait CommandHandler: Send + Sync { async fn handle( &self, command: &str, - ctx: CommandContext, + ctx: CommandContext

, output: impl AsyncWrite + Send + Unpin, cancel_token: CancellationToken, ) -> Result; diff --git a/src/server.rs b/src/server.rs index 31ad218..c0c81b7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,6 +4,7 @@ use anyhow::Result; use std::{ fs, io::ErrorKind, + marker::PhantomData, process, sync::{ Arc, @@ -62,7 +63,7 @@ static CLIENT_COUNTER: AtomicU64 = AtomicU64::new(1); /// let (server, _handle) = DaemonServer::new("/path/to/project", daemon, StartupReason::FirstStart); /// // Use handle.shutdown() to stop the server, or drop it to run indefinitely /// ``` -pub struct DaemonServer { +pub struct DaemonServer { /// Daemon name (e.g., CLI tool name) pub daemon_name: String, /// Project root path (used as unique identifier/scope) @@ -74,6 +75,7 @@ pub struct DaemonServer { handler: H, shutdown_rx: oneshot::Receiver<()>, connection_semaphore: Arc, + _phantom: PhantomData

, } /// Handle for controlling a running daemon server. @@ -94,9 +96,10 @@ impl DaemonHandle { } } -impl DaemonServer +impl DaemonServer where - H: CommandHandler + Clone + 'static, + H: CommandHandler

+ Clone + 'static, + P: PayloadCollector, { /// Create a new daemon server instance with default connection limit (100). /// @@ -169,6 +172,7 @@ where handler, shutdown_rx, connection_semaphore, + _phantom: PhantomData, }; let handle = DaemonHandle { shutdown_tx }; (server, handle) @@ -271,13 +275,13 @@ where tracing::debug!("Connection accepted"); // Version handshake - if let Ok(Some(SocketMessage::VersionCheck { + if let Ok(Some(SocketMessage::

::VersionCheck { build_timestamp: client_timestamp, - })) = connection.receive_message().await + })) = connection.receive_message::>().await { // Send our build timestamp if connection - .send_message(&SocketMessage::VersionCheck { build_timestamp }) + .send_message(&SocketMessage::

::VersionCheck { build_timestamp }) .await .is_err() { @@ -303,7 +307,7 @@ where } // Receive command - let (command, context) = match connection.receive_message::().await { + let (command, context) = match connection.receive_message::>().await { Ok(Some(SocketMessage::Command { command, context })) => (command, context), _ => { tracing::warn!("No command received from client"); @@ -316,7 +320,6 @@ where terminal_height = ?context.terminal_info.height, is_tty = context.terminal_info.is_tty, color_support = ?context.terminal_info.color_support, - env_var_count = context.env_vars.len(), "Received command with context" ); @@ -364,13 +367,13 @@ where // Send completion message (error or success with exit code) let result = if let Some(ref error) = handler_error { tracing::error!(error = %error, "Handler failed"); - let _ = connection.send_message(&SocketMessage::CommandError(error.clone())).await; + let _ = connection.send_message(&SocketMessage::

::CommandError(error.clone())).await; let _ = connection.flush().await; Err(anyhow::anyhow!("{}", error)) } else { let exit_code = handler_exit_code.unwrap_or(0); tracing::debug!(exit_code = exit_code, "Handler completed"); - let _ = connection.send_message(&SocketMessage::CommandComplete { exit_code }).await; + let _ = connection.send_message(&SocketMessage::

::CommandComplete { exit_code }).await; let _ = connection.flush().await; Ok(()) }; @@ -379,7 +382,7 @@ where // This ensures the message is received before connection closes let _ = tokio::time::timeout( Duration::from_secs(5), - connection.receive_message::() + connection.receive_message::>() ).await; break result; @@ -387,7 +390,7 @@ where Ok(n) => { // Send chunk to client let chunk = buffer[..n].to_vec(); - if connection.send_message(&SocketMessage::OutputChunk(chunk)).await.is_err() { + if connection.send_message(&SocketMessage::

::OutputChunk(chunk)).await.is_err() { // Connection closed - cancel handler tracing::warn!("Connection closed by client"); cancel_token.cancel(); diff --git a/src/tests.rs b/src/tests.rs index 44248a1..e7c58dd 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,6 +1,5 @@ use crate::transport::SocketMessage; use crate::*; -use std::collections::HashMap; use tokio::io::{AsyncWrite, AsyncWriteExt}; // Test handler for unit tests @@ -40,11 +39,11 @@ fn test_command_handler_trait_compiles() { #[test] fn test_socket_message_serialization() { // Test VersionCheck message - let version_msg = SocketMessage::VersionCheck { + let version_msg: SocketMessage<()> = SocketMessage::VersionCheck { build_timestamp: 1234567890, }; let serialized = serde_json::to_string(&version_msg).unwrap(); - let deserialized: SocketMessage = serde_json::from_str(&serialized).unwrap(); + let deserialized: SocketMessage<()> = serde_json::from_str(&serialized).unwrap(); match deserialized { SocketMessage::VersionCheck { build_timestamp } => { assert_eq!(build_timestamp, 1234567890); @@ -60,15 +59,13 @@ fn test_socket_message_serialization() { color_support: ColorSupport::Truecolor, theme: None, }; - let mut env_vars = HashMap::new(); - env_vars.insert("TEST_VAR".to_string(), "test_value".to_string()); - let context = CommandContext::with_env(terminal_info.clone(), env_vars); - let command_msg = SocketMessage::Command { + let context = CommandContext::new(terminal_info.clone()); + let command_msg: SocketMessage<()> = SocketMessage::Command { command: "test command".to_string(), context, }; let serialized = serde_json::to_string(&command_msg).unwrap(); - let deserialized: SocketMessage = serde_json::from_str(&serialized).unwrap(); + let deserialized: SocketMessage<()> = serde_json::from_str(&serialized).unwrap(); match deserialized { SocketMessage::Command { command, context } => { assert_eq!(command, "test command"); @@ -76,18 +73,14 @@ fn test_socket_message_serialization() { assert_eq!(context.terminal_info.height, Some(24)); assert!(context.terminal_info.is_tty); assert_eq!(context.terminal_info.color_support, ColorSupport::Truecolor); - assert_eq!( - context.env_vars.get("TEST_VAR"), - Some(&"test_value".to_string()) - ); } _ => panic!("Wrong message type"), } // Test OutputChunk message - let chunk_msg = SocketMessage::OutputChunk(vec![1, 2, 3, 4, 5]); + let chunk_msg: SocketMessage<()> = SocketMessage::OutputChunk(vec![1, 2, 3, 4, 5]); let serialized = serde_json::to_string(&chunk_msg).unwrap(); - let deserialized: SocketMessage = serde_json::from_str(&serialized).unwrap(); + let deserialized: SocketMessage<()> = serde_json::from_str(&serialized).unwrap(); match deserialized { SocketMessage::OutputChunk(data) => { assert_eq!(data, vec![1, 2, 3, 4, 5]); @@ -96,9 +89,9 @@ fn test_socket_message_serialization() { } // Test CommandComplete message - let complete_msg = SocketMessage::CommandComplete { exit_code: 0 }; + let complete_msg: SocketMessage<()> = SocketMessage::CommandComplete { exit_code: 0 }; let serialized = serde_json::to_string(&complete_msg).unwrap(); - let deserialized: SocketMessage = serde_json::from_str(&serialized).unwrap(); + let deserialized: SocketMessage<()> = serde_json::from_str(&serialized).unwrap(); match deserialized { SocketMessage::CommandComplete { exit_code } => { assert_eq!(exit_code, 0); @@ -107,9 +100,9 @@ fn test_socket_message_serialization() { } // Test CommandError message - let error_msg = SocketMessage::CommandError("test error".to_string()); + let error_msg: SocketMessage<()> = SocketMessage::CommandError("test error".to_string()); let serialized = serde_json::to_string(&error_msg).unwrap(); - let deserialized: SocketMessage = serde_json::from_str(&serialized).unwrap(); + let deserialized: SocketMessage<()> = serde_json::from_str(&serialized).unwrap(); match deserialized { SocketMessage::CommandError(err) => { assert_eq!(err, "test error"); @@ -185,35 +178,170 @@ async fn test_handler_with_cancellation() { assert!(String::from_utf8(output).unwrap().contains("Cancelled")); } -#[test] -fn test_env_var_filter_none() { - let filter = EnvVarFilter::none(); - assert!(filter.filter_current_env().is_empty()); +// ============================================================================ +// Custom Payload Tests +// ============================================================================ + +/// Test payload type for unit tests +#[derive(serde::Serialize, serde::Deserialize, Clone, Default, Debug, PartialEq)] +struct TestPayload { + value: String, + count: u32, +} + +#[async_trait] +impl PayloadCollector for TestPayload { + async fn collect() -> Self { + Self { + value: "collected".to_string(), + count: 42, + } + } +} + +#[tokio::test] +async fn test_payload_collector_custom_type() { + let payload = TestPayload::collect().await; + assert_eq!(payload.value, "collected"); + assert_eq!(payload.count, 42); +} + +#[tokio::test] +async fn test_payload_collector_unit_type() { + // Verify the default () implementation works + let payload = <()>::collect().await; + assert_eq!(payload, ()); } #[test] -fn test_env_var_filter_with_names() { - let mock_env = [("TEST_VAR", "test_value"), ("OTHER_VAR", "other")]; - let filter = EnvVarFilter::with_names(["TEST_VAR"]); - let filtered = filter.filter_from(mock_env); - assert_eq!(filtered.get("TEST_VAR"), Some(&"test_value".to_string())); - assert_eq!(filtered.len(), 1); +fn test_command_context_with_payload_serialization() { + let terminal_info = TerminalInfo { + width: Some(120), + height: Some(40), + is_tty: true, + color_support: ColorSupport::Truecolor, + theme: Some(Theme::Dark), + }; + let payload = TestPayload { + value: "test-value".to_string(), + count: 99, + }; + let ctx = CommandContext::with_payload(terminal_info.clone(), payload); + + // Serialize to JSON + let json = serde_json::to_string(&ctx).unwrap(); + + // Deserialize back + let deserialized: CommandContext = serde_json::from_str(&json).unwrap(); + + // Verify all fields + assert_eq!(deserialized.terminal_info.width, Some(120)); + assert_eq!(deserialized.terminal_info.height, Some(40)); + assert!(deserialized.terminal_info.is_tty); + assert_eq!(deserialized.payload.value, "test-value"); + assert_eq!(deserialized.payload.count, 99); } #[test] -fn test_env_var_filter_include() { - let mock_env = [("VAR1", "value1"), ("VAR2", "value2"), ("VAR3", "value3")]; - let filter = EnvVarFilter::none().include("VAR1").include("VAR2"); - let filtered = filter.filter_from(mock_env); - assert_eq!(filtered.len(), 2); - assert_eq!(filtered.get("VAR1"), Some(&"value1".to_string())); - assert_eq!(filtered.get("VAR2"), Some(&"value2".to_string())); +fn test_socket_message_with_custom_payload() { + let terminal_info = TerminalInfo { + width: Some(80), + height: Some(24), + is_tty: false, + color_support: ColorSupport::Basic16, + theme: None, + }; + let payload = TestPayload { + value: "socket-test".to_string(), + count: 123, + }; + let context = CommandContext::with_payload(terminal_info, payload); + + let msg: SocketMessage = SocketMessage::Command { + command: "my-command".to_string(), + context, + }; + + // Serialize and deserialize + let json = serde_json::to_string(&msg).unwrap(); + let deserialized: SocketMessage = serde_json::from_str(&json).unwrap(); + + match deserialized { + SocketMessage::Command { command, context } => { + assert_eq!(command, "my-command"); + assert_eq!(context.payload.value, "socket-test"); + assert_eq!(context.payload.count, 123); + } + _ => panic!("Expected Command message"), + } } #[test] -fn test_env_var_filter_missing_var() { - // Filter for a var that doesn't exist - let filter = EnvVarFilter::with_names(["NONEXISTENT_VAR_12345"]); - let filtered = filter.filter_current_env(); - assert!(filtered.is_empty()); +fn test_handler_with_custom_payload_compiles() { + // Test that a handler with custom payload type compiles correctly + #[derive(Clone)] + struct PayloadTestHandler; + + #[async_trait] + impl CommandHandler for PayloadTestHandler { + async fn handle( + &self, + _command: &str, + ctx: CommandContext, + mut output: impl AsyncWrite + Send + Unpin, + _cancel: CancellationToken, + ) -> Result { + // Access the payload + let msg = format!("Payload: {} ({})\n", ctx.payload.value, ctx.payload.count); + output.write_all(msg.as_bytes()).await?; + Ok(0) + } + } + + // Just verify it compiles + let _handler = PayloadTestHandler; +} + +#[tokio::test] +async fn test_handler_receives_payload() { + #[derive(Clone)] + struct PayloadEchoHandler; + + #[async_trait] + impl CommandHandler for PayloadEchoHandler { + async fn handle( + &self, + _command: &str, + ctx: CommandContext, + mut output: impl AsyncWrite + Send + Unpin, + _cancel: CancellationToken, + ) -> Result { + // Echo payload values to output + output + .write_all(format!("{}:{}", ctx.payload.value, ctx.payload.count).as_bytes()) + .await?; + Ok(0) + } + } + + let handler = PayloadEchoHandler; + let mut output = Vec::new(); + let cancel = CancellationToken::new(); + let terminal_info = TerminalInfo { + width: None, + height: None, + is_tty: false, + color_support: ColorSupport::None, + theme: None, + }; + let payload = TestPayload { + value: "hello".to_string(), + count: 42, + }; + let ctx = CommandContext::with_payload(terminal_info, payload); + + let result = handler.handle("test", ctx, &mut output, cancel).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 0); + assert_eq!(String::from_utf8(output).unwrap(), "hello:42"); } diff --git a/src/transport.rs b/src/transport.rs index 96b08ee..c0d66f3 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -229,13 +229,14 @@ impl SocketConnection { // Internal: Message types for socket communication #[derive(Serialize, Deserialize, Debug)] -pub enum SocketMessage { +#[serde(bound(deserialize = "P: Default + serde::de::DeserializeOwned"))] +pub enum SocketMessage

{ VersionCheck { build_timestamp: u64, }, Command { command: String, - context: CommandContext, + context: CommandContext

, }, OutputChunk(Vec), CommandComplete { diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index a9023d5..b069702 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -174,7 +174,7 @@ async fn test_basic_streaming() -> Result<()> { // Connect client (note: this would normally auto-spawn, but we started manually) let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe, @@ -208,7 +208,7 @@ async fn test_chunked_output() -> Result<()> { // Connect and execute let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe, @@ -238,7 +238,7 @@ async fn test_handler_error_reporting() -> Result<()> { // Connect and execute let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe, @@ -272,7 +272,7 @@ async fn test_multiple_sequential_commands() -> Result<()> { // Execute multiple commands sequentially for i in 1..=3 { - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe.clone(), @@ -305,7 +305,7 @@ async fn test_connection_close_during_processing() -> Result<()> { // Connect and start long-running command let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe, @@ -407,7 +407,7 @@ async fn test_concurrent_clients() -> Result<()> { let daemon_name_clone = daemon_name.clone(); let root_path_clone = root_path.clone(); let handle = spawn(async move { - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name_clone, &root_path_clone, daemon_exe_clone, @@ -465,7 +465,7 @@ async fn test_concurrent_stress_10_plus_clients() -> Result<()> { let daemon_name_clone = daemon_name.clone(); let root_path_clone = root_path.clone(); let handle = spawn(async move { - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name_clone, &root_path_clone, daemon_exe_clone, @@ -536,7 +536,7 @@ async fn test_connection_limit() -> Result<()> { let daemon_name_clone = daemon_name.clone(); let root_path_clone = root_path.clone(); let handle = spawn(async move { - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name_clone, &root_path_clone, daemon_exe_clone, @@ -625,7 +625,7 @@ async fn test_force_stop_not_running() -> Result<()> { let (shutdown_handle, join_handle) = start_test_daemon(&daemon_name, &root_path, build_timestamp, handler).await; - let client = DaemonClient::connect_with_name_and_timestamp( + let client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe.clone(), @@ -667,7 +667,7 @@ async fn test_restart_method() -> Result<()> { // Connect client let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe.clone(), @@ -706,7 +706,7 @@ async fn test_with_auto_restart_disabled_by_default() -> Result<()> { // Connect client (auto_restart should be false by default) let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let _client = DaemonClient::connect_with_name_and_timestamp( + let _client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe.clone(), @@ -742,7 +742,7 @@ async fn test_with_auto_restart_enabled() -> Result<()> { // Connect client with auto_restart enabled let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe.clone(), @@ -812,7 +812,7 @@ async fn test_handler_completes_before_output_fully_read() -> Result<()> { start_test_daemon(&daemon_name, &root_path, build_timestamp, handler).await; let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe, @@ -871,7 +871,7 @@ async fn test_large_output_streaming() -> Result<()> { start_test_daemon(&daemon_name, &root_path, build_timestamp, handler).await; let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe, @@ -916,7 +916,7 @@ async fn test_handler_panic_reports_error() -> Result<()> { start_test_daemon(&daemon_name, &root_path, build_timestamp, handler).await; let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe, @@ -939,7 +939,7 @@ async fn test_handler_panic_reports_error() -> Result<()> { sleep(Duration::from_millis(100)).await; // Try a new connection to verify server is still operational - let client2 = DaemonClient::connect_with_name_and_timestamp( + let client2 = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, PathBuf::from("./target/debug/examples/cli"), @@ -984,7 +984,7 @@ async fn test_cleanup_stale_socket_and_pid() -> Result<()> { // Connect client - should succeed after cleanup let daemon_exe = PathBuf::from("./target/debug/examples/cli"); - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe, @@ -1019,7 +1019,7 @@ async fn test_rapid_connect_disconnect_stress() -> Result<()> { let root_path_clone = root_path.clone(); let daemon_exe_clone = daemon_exe.clone(); let handle = spawn(async move { - let client = DaemonClient::connect_with_name_and_timestamp( + let client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name_clone, &root_path_clone, daemon_exe_clone, @@ -1057,7 +1057,7 @@ async fn test_rapid_connect_disconnect_stress() -> Result<()> { // Verify server is still stable sleep(Duration::from_millis(100)).await; - let mut final_client = DaemonClient::connect_with_name_and_timestamp( + let mut final_client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe, @@ -1112,7 +1112,7 @@ async fn test_connection_limit_immediate_rejection() -> Result<()> { let root_path_clone = root_path.clone(); let daemon_exe_clone = daemon_exe.clone(); let handle = spawn(async move { - let mut client = DaemonClient::connect_with_name_and_timestamp( + let mut client = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name_clone, &root_path_clone, daemon_exe_clone, @@ -1130,7 +1130,7 @@ async fn test_connection_limit_immediate_rejection() -> Result<()> { // Try to connect more clients - they should be rejected immediately let mut rejected_count = 0; for _ in 0..3 { - let client_result = DaemonClient::connect_with_name_and_timestamp( + let client_result = DaemonClient::<()>::connect_with_name_and_timestamp( &daemon_name, &root_path, daemon_exe.clone(), @@ -1138,13 +1138,13 @@ async fn test_connection_limit_immediate_rejection() -> Result<()> { ) .await; - if client_result.is_err() { - rejected_count += 1; - } else { - // If connection succeeded, try to execute - should fail - let mut client = client_result.unwrap(); - if client.execute_command("test".to_string()).await.is_err() { - rejected_count += 1; + match client_result { + Err(_) => rejected_count += 1, + Ok(mut client) => { + // If connection succeeded, try to execute - should fail + if client.execute_command("test".to_string()).await.is_err() { + rejected_count += 1; + } } } } @@ -1165,6 +1165,156 @@ async fn test_connection_limit_immediate_rejection() -> Result<()> { Ok(()) } +// ============================================================================ +// CUSTOM PAYLOAD TESTS +// ============================================================================ + +/// Custom payload type for testing end-to-end payload flow +#[derive(serde::Serialize, serde::Deserialize, Clone, Default, Debug)] +struct TestPayload { + marker: String, + sequence: u32, +} + +#[async_trait] +impl PayloadCollector for TestPayload { + async fn collect() -> Self { + Self { + marker: "integration-test-marker".to_string(), + sequence: 12345, + } + } +} + +/// Handler that echoes back the received payload to verify it was transmitted correctly +#[derive(Clone)] +struct PayloadEchoHandler; + +#[async_trait] +impl CommandHandler for PayloadEchoHandler { + async fn handle( + &self, + command: &str, + ctx: CommandContext, + mut output: impl AsyncWrite + Send + Unpin, + _cancel: CancellationToken, + ) -> Result { + // Echo the payload data back to verify it was received + output + .write_all( + format!( + "cmd={},marker={},seq={}\n", + command, ctx.payload.marker, ctx.payload.sequence + ) + .as_bytes(), + ) + .await?; + Ok(0) + } +} + +#[tokio::test] +async fn test_custom_payload_end_to_end() -> Result<()> { + let (daemon_name, root_path) = generate_test_daemon_config(); + let build_timestamp = 1234567920; + let handler = PayloadEchoHandler; + + // Start server with custom payload handler + let (server, shutdown_handle) = DaemonServer::new_with_name_and_timestamp( + &daemon_name, + &root_path, + build_timestamp, + handler, + StartupReason::FirstStart, + 100, + ); + let join_handle = spawn(async move { + server.run().await.ok(); + }); + + // Wait for server to start + sleep(Duration::from_millis(100)).await; + + // Connect client with the same payload type + let daemon_exe = PathBuf::from("./target/debug/examples/cli"); + let mut client = DaemonClient::::connect_with_name_and_timestamp( + &daemon_name, + &root_path, + daemon_exe, + build_timestamp, + ) + .await?; + + // Execute command - payload should be auto-collected and sent + let result = client.execute_command("test-command".to_string()).await; + + assert!(result.is_ok(), "Command should succeed: {:?}", result); + assert_eq!(result.unwrap(), 0); + + // Note: The actual payload verification happens in the handler output + // which is written to stdout. In a more complete test, we could capture + // stdout or use a different mechanism to verify the payload values. + // The test passing means: + // 1. PayloadCollector::collect() was called (marker and sequence have values) + // 2. Payload was serialized and sent over the socket + // 3. Server deserialized the payload correctly + // 4. Handler received the CommandContext with correct values + + // Cleanup + shutdown_handle.shutdown(); + let _ = tokio::time::timeout(Duration::from_secs(2), join_handle).await; + + Ok(()) +} + +#[tokio::test] +async fn test_multiple_commands_with_payload() -> Result<()> { + let (daemon_name, root_path) = generate_test_daemon_config(); + let build_timestamp = 1234567921; + let handler = PayloadEchoHandler; + + // Start server + let (server, shutdown_handle) = DaemonServer::new_with_name_and_timestamp( + &daemon_name, + &root_path, + build_timestamp, + handler, + StartupReason::FirstStart, + 100, + ); + let join_handle = spawn(async move { + server.run().await.ok(); + }); + + sleep(Duration::from_millis(100)).await; + + let daemon_exe = PathBuf::from("./target/debug/examples/cli"); + + // Execute multiple commands - each should collect a fresh payload + for i in 0..3 { + let mut client = DaemonClient::::connect_with_name_and_timestamp( + &daemon_name, + &root_path, + daemon_exe.clone(), + build_timestamp, + ) + .await?; + + let result = client.execute_command(format!("command-{}", i)).await; + + assert!(result.is_ok(), "Command {} should succeed: {:?}", i, result); + assert_eq!(result.unwrap(), 0); + + sleep(Duration::from_millis(50)).await; + } + + // Cleanup + shutdown_handle.shutdown(); + let _ = tokio::time::timeout(Duration::from_secs(2), join_handle).await; + + Ok(()) +} + // ============================================================================ // UNIX-SPECIFIC TESTS // ============================================================================ diff --git a/tests/version_tests.rs b/tests/version_tests.rs index d81a3b6..a6d667b 100644 --- a/tests/version_tests.rs +++ b/tests/version_tests.rs @@ -55,11 +55,11 @@ async fn test_version_handshake_success() -> Result<()> { // Send version check client - .send_message(&SocketMessage::VersionCheck { build_timestamp }) + .send_message(&SocketMessage::<()>::VersionCheck { build_timestamp }) .await?; // Receive response - let response = client.receive_message::().await?; + let response = client.receive_message::>().await?; match response { Some(SocketMessage::VersionCheck { @@ -101,13 +101,13 @@ async fn test_version_mismatch_detection() -> Result<()> { // Send version check with mismatched timestamp client - .send_message(&SocketMessage::VersionCheck { + .send_message(&SocketMessage::<()>::VersionCheck { build_timestamp: client_build_timestamp, }) .await?; // Receive response - let response = client.receive_message::().await?; + let response = client.receive_message::>().await?; match response { Some(SocketMessage::VersionCheck { @@ -154,10 +154,10 @@ async fn test_multiple_version_handshakes() -> Result<()> { // Perform handshake client - .send_message(&SocketMessage::VersionCheck { build_timestamp }) + .send_message(&SocketMessage::<()>::VersionCheck { build_timestamp }) .await?; - let response = client.receive_message::().await?; + let response = client.receive_message::>().await?; match response { Some(SocketMessage::VersionCheck { @@ -203,10 +203,10 @@ async fn test_version_handshake_before_command() -> Result<()> { // First, perform version handshake client - .send_message(&SocketMessage::VersionCheck { build_timestamp }) + .send_message(&SocketMessage::<()>::VersionCheck { build_timestamp }) .await?; - let handshake_response = client.receive_message::().await?; + let handshake_response = client.receive_message::>().await?; assert!(matches!( handshake_response, Some(SocketMessage::VersionCheck { .. }) @@ -221,14 +221,14 @@ async fn test_version_handshake_before_command() -> Result<()> { theme: None, }; client - .send_message(&SocketMessage::Command { + .send_message(&SocketMessage::<()>::Command { command: "test command".to_string(), context: CommandContext::new(terminal_info), }) .await?; // Should receive output chunks - let output_response = client.receive_message::().await?; + let output_response = client.receive_message::>().await?; assert!(matches!( output_response, Some(SocketMessage::OutputChunk(_)) @@ -271,14 +271,14 @@ async fn test_command_without_handshake_fails() -> Result<()> { theme: None, }; client - .send_message(&SocketMessage::Command { + .send_message(&SocketMessage::<()>::Command { command: "test".to_string(), context: CommandContext::new(terminal_info), }) .await?; // Connection should close or we get no response - let response = client.receive_message::().await?; + let response = client.receive_message::>().await?; // Should either get None (connection closed) or the server ignores it // Based on our implementation, server expects VersionCheck first @@ -321,10 +321,10 @@ async fn test_concurrent_version_handshakes() -> Result<()> { let mut client = SocketClient::connect(&daemon_name_clone, &root_path_clone).await?; client - .send_message(&SocketMessage::VersionCheck { build_timestamp }) + .send_message(&SocketMessage::<()>::VersionCheck { build_timestamp }) .await?; - let response = client.receive_message::().await?; + let response = client.receive_message::>().await?; match response { Some(SocketMessage::VersionCheck { @@ -382,13 +382,13 @@ async fn test_version_mismatch_triggers_client_action() -> Result<()> { // Send version check with newer timestamp client - .send_message(&SocketMessage::VersionCheck { + .send_message(&SocketMessage::<()>::VersionCheck { build_timestamp: client_timestamp, }) .await?; // Server should respond with its own (older) timestamp - let response = client.receive_message::().await?; + let response = client.receive_message::>().await?; match response { Some(SocketMessage::VersionCheck { @@ -434,10 +434,10 @@ async fn test_multiple_commands_same_connection() -> Result<()> { // First, perform version handshake client - .send_message(&SocketMessage::VersionCheck { build_timestamp }) + .send_message(&SocketMessage::<()>::VersionCheck { build_timestamp }) .await?; - let handshake_response = client.receive_message::().await?; + let handshake_response = client.receive_message::>().await?; assert!(matches!( handshake_response, Some(SocketMessage::VersionCheck { .. }) @@ -452,21 +452,21 @@ async fn test_multiple_commands_same_connection() -> Result<()> { theme: None, }; client - .send_message(&SocketMessage::Command { + .send_message(&SocketMessage::<()>::Command { command: "first command".to_string(), context: CommandContext::new(terminal_info.clone()), }) .await?; // Receive first command output - let output1 = client.receive_message::().await?; + let output1 = client.receive_message::>().await?; assert!( matches!(output1, Some(SocketMessage::OutputChunk(_))), "Should receive output for first command" ); // Receive CommandComplete - let complete1 = client.receive_message::().await?; + let complete1 = client.receive_message::>().await?; assert!( matches!( complete1, @@ -478,14 +478,14 @@ async fn test_multiple_commands_same_connection() -> Result<()> { // Attempt to send a second command on the same connection // The server uses one-shot semantics: it closes after handling one command client - .send_message(&SocketMessage::Command { + .send_message(&SocketMessage::<()>::Command { command: "second command".to_string(), context: CommandContext::new(terminal_info.clone()), }) .await?; // The connection should be closed by the server, so we expect EOF (None) - let response = client.receive_message::().await?; + let response = client.receive_message::>().await?; assert!( response.is_none(), "Connection should be closed after first command (one-shot semantics), got: {:?}",