From fcde83ff1866e420fb80f69eeb93fa14433e8ff8 Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:23:54 -0800 Subject: [PATCH 01/12] Add connection limiter module --- src/connection.rs | 214 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 src/connection.rs diff --git a/src/connection.rs b/src/connection.rs new file mode 100644 index 0000000..38102c0 --- /dev/null +++ b/src/connection.rs @@ -0,0 +1,214 @@ +//! Connection management and limiting. +//! +//! Provides connection tracking and limits to prevent resource exhaustion. + +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::Semaphore; +use tracing::{debug, warn}; + +/// Configuration for connection limits. +#[derive(Debug, Clone)] +pub struct ConnectionConfig { + /// Maximum number of concurrent connections. + pub max_connections: usize, + /// Maximum number of connections per client IP. + pub max_connections_per_ip: usize, +} + +impl Default for ConnectionConfig { + fn default() -> Self { + Self { + max_connections: 10_000, + max_connections_per_ip: 100, + } + } +} + +impl ConnectionConfig { + /// Creates a new connection configuration. + pub fn new(max_connections: usize) -> Self { + Self { + max_connections, + ..Default::default() + } + } + + /// Sets the maximum connections per IP. + pub fn with_max_per_ip(mut self, max: usize) -> Self { + self.max_connections_per_ip = max; + self + } +} + +/// Connection limiter to prevent resource exhaustion. +/// +/// Uses a semaphore to limit the total number of concurrent connections. +#[derive(Debug)] +pub struct ConnectionLimiter { + /// Semaphore for limiting total connections. + semaphore: Arc, + /// Current number of active connections. + active_connections: AtomicUsize, + /// Total connections accepted. + total_accepted: AtomicUsize, + /// Total connections rejected due to limits. + total_rejected: AtomicUsize, + /// Configuration. 
+ config: ConnectionConfig, +} + +impl ConnectionLimiter { + /// Creates a new connection limiter with the given configuration. + pub fn new(config: ConnectionConfig) -> Self { + Self { + semaphore: Arc::new(Semaphore::new(config.max_connections)), + active_connections: AtomicUsize::new(0), + total_accepted: AtomicUsize::new(0), + total_rejected: AtomicUsize::new(0), + config, + } + } + + /// Creates a connection limiter with default configuration. + pub fn with_defaults() -> Self { + Self::new(ConnectionConfig::default()) + } + + /// Attempts to acquire a connection permit. + /// + /// Returns `Some(ConnectionGuard)` if a connection is allowed, + /// or `None` if the limit has been reached. + pub fn try_acquire(&self) -> Option> { + match self.semaphore.clone().try_acquire_owned() { + Ok(permit) => { + self.active_connections.fetch_add(1, Ordering::Relaxed); + self.total_accepted.fetch_add(1, Ordering::Relaxed); + debug!( + active = self.active_connections.load(Ordering::Relaxed), + "connection acquired" + ); + Some(ConnectionGuard { + _permit: permit, + active_counter: &self.active_connections, + }) + } + Err(_) => { + self.total_rejected.fetch_add(1, Ordering::Relaxed); + warn!( + max = self.config.max_connections, + "connection limit reached, rejecting" + ); + None + } + } + } + + /// Returns the number of active connections. + pub fn active_connections(&self) -> usize { + self.active_connections.load(Ordering::Relaxed) + } + + /// Returns the total number of accepted connections. + pub fn total_accepted(&self) -> usize { + self.total_accepted.load(Ordering::Relaxed) + } + + /// Returns the total number of rejected connections. + pub fn total_rejected(&self) -> usize { + self.total_rejected.load(Ordering::Relaxed) + } + + /// Returns the maximum allowed connections. + pub fn max_connections(&self) -> usize { + self.config.max_connections + } + + /// Returns connection statistics. 
+ pub fn stats(&self) -> ConnectionStats { + ConnectionStats { + active: self.active_connections.load(Ordering::Relaxed), + total_accepted: self.total_accepted.load(Ordering::Relaxed), + total_rejected: self.total_rejected.load(Ordering::Relaxed), + max_connections: self.config.max_connections, + } + } +} + +/// Guard that releases a connection permit when dropped. +pub struct ConnectionGuard<'a> { + _permit: tokio::sync::OwnedSemaphorePermit, + active_counter: &'a AtomicUsize, +} + +impl<'a> Drop for ConnectionGuard<'a> { + fn drop(&mut self) { + self.active_counter.fetch_sub(1, Ordering::Relaxed); + debug!( + active = self.active_counter.load(Ordering::Relaxed), + "connection released" + ); + } +} + +/// Statistics about connection usage. +#[derive(Debug, Clone)] +pub struct ConnectionStats { + /// Current number of active connections. + pub active: usize, + /// Total number of accepted connections. + pub total_accepted: usize, + /// Total number of rejected connections due to limits. + pub total_rejected: usize, + /// Maximum allowed connections. 
+ pub max_connections: usize, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_connection_limiter_basic() { + let limiter = ConnectionLimiter::new(ConnectionConfig::new(2)); + + let _guard1 = limiter.try_acquire(); + let _guard2 = limiter.try_acquire(); + assert!(_guard1.is_some()); + assert!(_guard2.is_some()); + assert_eq!(limiter.active_connections(), 2); + + // Should reject - at limit + assert!(limiter.try_acquire().is_none()); + assert_eq!(limiter.total_rejected(), 1); + } + + #[test] + fn test_connection_guard_release() { + let limiter = ConnectionLimiter::new(ConnectionConfig::new(1)); + + { + let _guard = limiter.try_acquire(); + assert_eq!(limiter.active_connections(), 1); + } + + // Guard dropped, should be able to acquire again + assert_eq!(limiter.active_connections(), 0); + assert!(limiter.try_acquire().is_some()); + } + + #[test] + fn test_connection_stats() { + let limiter = ConnectionLimiter::new(ConnectionConfig::new(2)); + + let _guard1 = limiter.try_acquire(); + let _guard2 = limiter.try_acquire(); + let _ = limiter.try_acquire(); // Rejected + + let stats = limiter.stats(); + assert_eq!(stats.active, 2); + assert_eq!(stats.total_accepted, 2); + assert_eq!(stats.total_rejected, 1); + assert_eq!(stats.max_connections, 2); + } +} From 13fe07a358f733f3872dd98b997f756fecc19c9b Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:23:59 -0800 Subject: [PATCH 02/12] Add regex caching to router --- src/router.rs | 41 +++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/src/router.rs b/src/router.rs index dd1ca94..6c801be 100644 --- a/src/router.rs +++ b/src/router.rs @@ -3,11 +3,43 @@ //! Provides flexible routing rules for directing traffic to different //! upstream clusters based on request attributes. 
+use once_cell::sync::Lazy; +use parking_lot::RwLock; use regex::Regex; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::sync::Arc; use tracing::{debug, warn}; +/// Global regex cache to avoid recompiling patterns on every request. +static REGEX_CACHE: Lazy>>> = + Lazy::new(|| RwLock::new(HashMap::new())); + +/// Gets or compiles a regex pattern, caching the result. +fn get_or_compile_regex(pattern: &str) -> Option> { + // Fast path: check if already cached + { + let cache = REGEX_CACHE.read(); + if let Some(regex) = cache.get(pattern) { + return Some(Arc::clone(regex)); + } + } + + // Slow path: compile and cache + match Regex::new(pattern) { + Ok(regex) => { + let regex = Arc::new(regex); + let mut cache = REGEX_CACHE.write(); + cache.insert(pattern.to_string(), Arc::clone(®ex)); + Some(regex) + } + Err(e) => { + warn!(pattern = %pattern, error = %e, "invalid regex pattern"); + None + } + } +} + /// Route matching priority (higher = evaluated first). #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum RoutePriority { @@ -67,13 +99,12 @@ impl HeaderMatch { .and_then(|v| v.to_str().ok()) .is_some_and(|v| v.contains(value.as_str())), HeaderMatch::Regex { name, pattern } => { - if let Ok(regex) = Regex::new(pattern) { + if let Some(regex) = get_or_compile_regex(pattern) { headers .get(name) .and_then(|v| v.to_str().ok()) .is_some_and(|v| regex.is_match(v)) } else { - warn!(pattern = %pattern, "invalid regex pattern"); false } } @@ -121,10 +152,9 @@ impl PathMatch { PathMatch::Exact { path: expected } => path == expected, PathMatch::Prefix { prefix } => path.starts_with(prefix), PathMatch::Regex { pattern } => { - if let Ok(regex) = Regex::new(pattern) { + if let Some(regex) = get_or_compile_regex(pattern) { regex.is_match(path) } else { - warn!(pattern = %pattern, "invalid regex pattern"); false } } @@ -339,8 +369,7 @@ impl Router { /// Sorts routes by priority (highest first). 
fn sort_routes(&mut self) { - self.routes - .sort_by_key(|r| std::cmp::Reverse(r.priority())); + self.routes.sort_by_key(|r| std::cmp::Reverse(r.priority())); } /// Finds the matching route for a request. From fdc0a9e58f143d095a5996d110a73f536d157b99 Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:24:03 -0800 Subject: [PATCH 03/12] Enhance metrics with RED methodology --- src/metrics.rs | 251 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 248 insertions(+), 3 deletions(-) diff --git a/src/metrics.rs b/src/metrics.rs index efb2c01..150fcc4 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1,13 +1,20 @@ //! Prometheus metrics collection and export. +//! +//! Provides comprehensive observability following RED methodology: +//! - **Rate**: Request rate and throughput +//! - **Errors**: Error counts and error rates +//! - **Duration**: Request latency histograms use once_cell::sync::Lazy; use prometheus_client::encoding::text::encode; use prometheus_client::encoding::EncodeLabelSet; use prometheus_client::metrics::counter::Counter; use prometheus_client::metrics::family::Family; +use prometheus_client::metrics::gauge::Gauge; use prometheus_client::metrics::histogram::{exponential_buckets, Histogram}; use prometheus_client::registry::Registry; use std::io; +use std::sync::atomic::AtomicI64; use std::sync::{Arc, Mutex}; /// Labels for HTTP request metrics. @@ -21,6 +28,29 @@ pub struct HttpLabels { pub upstream: String, } +/// Labels for circuit breaker metrics. +#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct CircuitBreakerLabels { + /// Upstream or endpoint name + pub upstream: String, + /// Circuit breaker state (closed, open, half_open) + pub state: String, +} + +/// Labels for rate limiter metrics. +#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct RateLimitLabels { + /// Type of rate limit (global, per_client) + pub limit_type: String, +} + +/// Labels for connection metrics. 
+#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct ConnectionLabels { + /// Connection state or type + pub state: String, +} + /// Global metrics registry. /// /// Initialized once at startup and shared across all tasks. @@ -28,11 +58,24 @@ static METRICS: Lazy>> = Lazy::new(|| Arc::new(Mutex::new(Met /// Metrics collector for the proxy. /// -/// Tracks request counts, latencies, and upstream health. +/// Tracks request counts, latencies, circuit breaker states, rate limits, +/// and connection statistics following RED methodology. pub struct Metrics { registry: Registry, + // Request metrics (RED) requests_total: Family, request_duration_seconds: Family, + requests_in_flight: Gauge, + // Error metrics + errors_total: Family, + // Circuit breaker metrics + circuit_breaker_state: Family, + circuit_breaker_trips_total: Family, + // Rate limiter metrics + rate_limit_rejections_total: Family, + // Connection metrics + connections_total: Family, + connections_active: Gauge, } impl Metrics { @@ -40,6 +83,7 @@ impl Metrics { fn new() -> Self { let mut registry = Registry::default(); + // Request metrics let requests_total = Family::::default(); registry.register( "http_requests_total", @@ -49,7 +93,8 @@ impl Metrics { let request_duration_seconds = Family::::new_with_constructor(|| { - Histogram::new(exponential_buckets(0.001, 2.0, 10)) + // Buckets: 1ms, 2ms, 4ms, 8ms, 16ms, 32ms, 64ms, 128ms, 256ms, 512ms, 1s, 2s, 4s + Histogram::new(exponential_buckets(0.001, 2.0, 13)) }); registry.register( "http_request_duration_seconds", @@ -57,10 +102,70 @@ impl Metrics { request_duration_seconds.clone(), ); + let requests_in_flight = Gauge::::default(); + registry.register( + "http_requests_in_flight", + "Number of HTTP requests currently being processed", + requests_in_flight.clone(), + ); + + // Error metrics + let errors_total = Family::::default(); + registry.register( + "http_errors_total", + "Total number of HTTP errors (4xx and 5xx responses)", + 
errors_total.clone(), + ); + + // Circuit breaker metrics + let circuit_breaker_state = Family::::default(); + registry.register( + "circuit_breaker_state", + "Current state of circuit breakers (0=closed, 1=open, 2=half_open)", + circuit_breaker_state.clone(), + ); + + let circuit_breaker_trips_total = Family::::default(); + registry.register( + "circuit_breaker_trips_total", + "Total number of times circuit breakers have tripped", + circuit_breaker_trips_total.clone(), + ); + + // Rate limiter metrics + let rate_limit_rejections_total = Family::::default(); + registry.register( + "rate_limit_rejections_total", + "Total number of requests rejected due to rate limiting", + rate_limit_rejections_total.clone(), + ); + + // Connection metrics + let connections_total = Family::::default(); + registry.register( + "connections_total", + "Total number of connections by state", + connections_total.clone(), + ); + + let connections_active = Gauge::::default(); + registry.register( + "connections_active", + "Number of currently active connections", + connections_active.clone(), + ); + Self { registry, requests_total, request_duration_seconds, + requests_in_flight, + errors_total, + circuit_breaker_state, + circuit_breaker_trips_total, + rate_limit_rejections_total, + connections_total, + connections_active, } } @@ -69,7 +174,7 @@ impl Metrics { /// # Arguments /// /// * `method` - HTTP method (e.g., "GET", "POST") - /// * `status` - HTTP status code (e.g., "200", "404") + /// * `status` - HTTP status code (e.g., 200, 404) /// * `upstream` - Upstream server address /// * `duration_secs` - Request duration in seconds pub fn record_request(method: &str, status: u16, upstream: &str, duration_secs: f64) { @@ -85,6 +190,99 @@ impl Metrics { .request_duration_seconds .get_or_create(&labels) .observe(duration_secs); + + // Track errors (4xx and 5xx) + if status >= 400 { + metrics.errors_total.get_or_create(&labels).inc(); + } + } + } + + /// Increments the in-flight request counter. 
+ pub fn inc_requests_in_flight() { + if let Ok(metrics) = METRICS.lock() { + metrics.requests_in_flight.inc(); + } + } + + /// Decrements the in-flight request counter. + pub fn dec_requests_in_flight() { + if let Ok(metrics) = METRICS.lock() { + metrics.requests_in_flight.dec(); + } + } + + /// Records a circuit breaker state change. + /// + /// # Arguments + /// + /// * `upstream` - The upstream endpoint name + /// * `state` - The new state (closed, open, half_open) + /// * `is_trip` - Whether this is a trip (closed -> open) + pub fn record_circuit_breaker_state(upstream: &str, state: &str, is_trip: bool) { + let labels = CircuitBreakerLabels { + upstream: upstream.to_string(), + state: state.to_string(), + }; + + if let Ok(metrics) = METRICS.lock() { + let state_value = match state { + "closed" => 0, + "open" => 1, + "half_open" => 2, + _ => 0, + }; + metrics + .circuit_breaker_state + .get_or_create(&labels) + .set(state_value); + + if is_trip { + metrics + .circuit_breaker_trips_total + .get_or_create(&labels) + .inc(); + } + } + } + + /// Records a rate limit rejection. + /// + /// # Arguments + /// + /// * `limit_type` - The type of limit (global, per_client) + pub fn record_rate_limit_rejection(limit_type: &str) { + let labels = RateLimitLabels { + limit_type: limit_type.to_string(), + }; + + if let Ok(metrics) = METRICS.lock() { + metrics + .rate_limit_rejections_total + .get_or_create(&labels) + .inc(); + } + } + + /// Records a connection event. + /// + /// # Arguments + /// + /// * `state` - The connection state (accepted, rejected, closed) + pub fn record_connection(state: &str) { + let labels = ConnectionLabels { + state: state.to_string(), + }; + + if let Ok(metrics) = METRICS.lock() { + metrics.connections_total.get_or_create(&labels).inc(); + } + } + + /// Sets the number of active connections. 
+ pub fn set_active_connections(count: i64) { + if let Ok(metrics) = METRICS.lock() { + metrics.connections_active.set(count); } } @@ -120,6 +318,53 @@ mod tests { assert!(encoded.contains("http_request_duration_seconds")); } + #[test] + fn test_record_error() { + Metrics::record_request("GET", 500, "http://upstream:8080", 0.05); + Metrics::record_request("GET", 404, "http://upstream:8080", 0.1); + + let encoded = Metrics::encode().unwrap(); + assert!(encoded.contains("http_errors_total")); + } + + #[test] + fn test_requests_in_flight() { + Metrics::inc_requests_in_flight(); + Metrics::inc_requests_in_flight(); + Metrics::dec_requests_in_flight(); + + let encoded = Metrics::encode().unwrap(); + assert!(encoded.contains("http_requests_in_flight")); + } + + #[test] + fn test_circuit_breaker_metrics() { + Metrics::record_circuit_breaker_state("http://upstream:8080", "open", true); + + let encoded = Metrics::encode().unwrap(); + assert!(encoded.contains("circuit_breaker_state")); + assert!(encoded.contains("circuit_breaker_trips_total")); + } + + #[test] + fn test_rate_limit_metrics() { + Metrics::record_rate_limit_rejection("global"); + Metrics::record_rate_limit_rejection("per_client"); + + let encoded = Metrics::encode().unwrap(); + assert!(encoded.contains("rate_limit_rejections_total")); + } + + #[test] + fn test_connection_metrics() { + Metrics::record_connection("accepted"); + Metrics::set_active_connections(5); + + let encoded = Metrics::encode().unwrap(); + assert!(encoded.contains("connections_total")); + assert!(encoded.contains("connections_active")); + } + #[test] fn test_metrics_encoding() { let encoded = Metrics::encode(); From e4a5773473b2ddc58f701da6b98024b4b63968bd Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:24:07 -0800 Subject: [PATCH 04/12] Add configuration validation --- src/config.rs | 235 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 235 insertions(+) diff --git a/src/config.rs b/src/config.rs index 
2c0675a..1bab42d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,8 +2,39 @@ use serde::{Deserialize, Serialize}; use std::env; +use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; +use thiserror::Error; + +/// Configuration validation errors. +#[derive(Error, Debug)] +#[allow(dead_code)] +pub enum ConfigError { + /// Invalid listen address format. + #[error("invalid listen address '{addr}': {reason}")] + InvalidListenAddr { addr: String, reason: String }, + + /// Invalid metrics address format. + #[error("invalid metrics address '{addr}': {reason}")] + InvalidMetricsAddr { addr: String, reason: String }, + + /// Invalid upstream address format. + #[error("invalid upstream address '{addr}': {reason}")] + InvalidUpstreamAddr { addr: String, reason: String }, + + /// No upstream addresses configured. + #[error("at least one upstream address is required")] + NoUpstreamAddrs, + + /// Invalid timeout value. + #[error("invalid timeout value: {reason}")] + InvalidTimeout { reason: String }, + + /// Duplicate listen and metrics addresses. + #[error("listen address and metrics address cannot be the same: {addr}")] + DuplicateAddrs { addr: String }, +} /// Proxy configuration loaded at startup. /// @@ -16,6 +47,7 @@ use std::time::Duration; /// * `PROXY_UPSTREAM_ADDRS` - Comma-separated upstream addresses (default: "http://127.0.0.1:8080") /// * `PROXY_METRICS_ADDR` - Metrics endpoint address (default: "127.0.0.1:9090") /// * `PROXY_REQUEST_TIMEOUT_MS` - Request timeout in milliseconds (default: 30000) +/// * `PROXY_MAX_CONNECTIONS` - Maximum concurrent connections (default: 10000) /// /// # Example /// @@ -37,6 +69,14 @@ pub struct ProxyConfig { /// Request timeout duration. pub request_timeout: Duration, + + /// Maximum concurrent connections. 
+ #[serde(default = "default_max_connections")] + pub max_connections: usize, +} + +fn default_max_connections() -> usize { + 10_000 } impl Default for ProxyConfig { @@ -46,6 +86,7 @@ impl Default for ProxyConfig { upstream_addrs: vec!["http://127.0.0.1:8080".to_string()], metrics_addr: "127.0.0.1:9090".to_string(), request_timeout: Duration::from_secs(30), + max_connections: default_max_connections(), } } } @@ -80,12 +121,103 @@ impl ProxyConfig { .and_then(|s| s.parse::().ok()) .unwrap_or(30000); + let max_connections = env::var("PROXY_MAX_CONNECTIONS") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or_else(default_max_connections); + Self { listen_addr, upstream_addrs, metrics_addr, request_timeout: Duration::from_millis(request_timeout_ms), + max_connections, + } + } + + /// Loads configuration from environment variables and validates it. + /// + /// Returns an error if the configuration is invalid. + #[allow(dead_code)] + pub fn from_env_validated() -> Result { + let config = Self::from_env(); + config.validate()?; + Ok(config) + } + + /// Validates the configuration. 
+ /// + /// # Errors + /// + /// Returns an error if: + /// - Listen address is not a valid socket address + /// - Metrics address is not a valid socket address + /// - No upstream addresses are configured + /// - Upstream addresses have invalid URL format + /// - Listen and metrics addresses are the same + /// - Timeout is zero or too large + #[allow(dead_code)] + pub fn validate(&self) -> Result<(), ConfigError> { + // Validate listen address + self.listen_addr + .parse::() + .map_err(|e| ConfigError::InvalidListenAddr { + addr: self.listen_addr.clone(), + reason: e.to_string(), + })?; + + // Validate metrics address + self.metrics_addr + .parse::() + .map_err(|e| ConfigError::InvalidMetricsAddr { + addr: self.metrics_addr.clone(), + reason: e.to_string(), + })?; + + // Check for duplicate addresses + if self.listen_addr == self.metrics_addr { + return Err(ConfigError::DuplicateAddrs { + addr: self.listen_addr.clone(), + }); + } + + // Validate upstream addresses + if self.upstream_addrs.is_empty() { + return Err(ConfigError::NoUpstreamAddrs); + } + + for addr in &self.upstream_addrs { + // Basic URL validation + if !addr.starts_with("http://") && !addr.starts_with("https://") { + return Err(ConfigError::InvalidUpstreamAddr { + addr: addr.clone(), + reason: "must start with http:// or https://".to_string(), + }); + } + + // Try to parse the URL + if url::Url::parse(addr).is_err() { + return Err(ConfigError::InvalidUpstreamAddr { + addr: addr.clone(), + reason: "invalid URL format".to_string(), + }); + } + } + + // Validate timeout + if self.request_timeout.is_zero() { + return Err(ConfigError::InvalidTimeout { + reason: "timeout must be greater than zero".to_string(), + }); + } + + if self.request_timeout > Duration::from_secs(3600) { + return Err(ConfigError::InvalidTimeout { + reason: "timeout must not exceed 1 hour".to_string(), + }); } + + Ok(()) } /// Wraps the upstream addresses in an `Arc` for shared ownership. 
@@ -110,6 +242,7 @@ mod tests { assert_eq!(config.listen_addr, "127.0.0.1:3000"); assert_eq!(config.upstream_addrs.len(), 1); assert_eq!(config.metrics_addr, "127.0.0.1:9090"); + assert_eq!(config.max_connections, 10_000); } #[test] @@ -125,4 +258,106 @@ mod tests { let config = ProxyConfig::default(); assert_eq!(config.timeout(), Duration::from_secs(30)); } + + #[test] + fn test_validate_valid_config() { + let config = ProxyConfig::default(); + assert!(config.validate().is_ok()); + } + + #[test] + fn test_validate_invalid_listen_addr() { + let config = ProxyConfig { + listen_addr: "invalid".to_string(), + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ConfigError::InvalidListenAddr { .. } + )); + } + + #[test] + fn test_validate_invalid_metrics_addr() { + let config = ProxyConfig { + metrics_addr: "invalid".to_string(), + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ConfigError::InvalidMetricsAddr { .. } + )); + } + + #[test] + fn test_validate_duplicate_addrs() { + let config = ProxyConfig { + listen_addr: "127.0.0.1:3000".to_string(), + metrics_addr: "127.0.0.1:3000".to_string(), + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ConfigError::DuplicateAddrs { .. 
} + )); + } + + #[test] + fn test_validate_no_upstream() { + let config = ProxyConfig { + upstream_addrs: vec![], + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), ConfigError::NoUpstreamAddrs)); + } + + #[test] + fn test_validate_invalid_upstream() { + let config = ProxyConfig { + upstream_addrs: vec!["not-a-url".to_string()], + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ConfigError::InvalidUpstreamAddr { .. } + )); + } + + #[test] + fn test_validate_zero_timeout() { + let config = ProxyConfig { + request_timeout: Duration::ZERO, + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ConfigError::InvalidTimeout { .. } + )); + } + + #[test] + fn test_validate_excessive_timeout() { + let config = ProxyConfig { + request_timeout: Duration::from_secs(7200), + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ConfigError::InvalidTimeout { .. } + )); + } } From cb8e54734c15b741e087828c2f6759a67c618cc8 Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:24:11 -0800 Subject: [PATCH 05/12] Add readiness endpoint to admin service --- src/admin.rs | 170 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 160 insertions(+), 10 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index a64af58..9c4a3c7 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -1,49 +1,125 @@ -//! Admin endpoints for health checks and metrics. +//! Admin endpoints for health checks, readiness, and metrics. 
use crate::metrics::Metrics; use http::{Request, Response, StatusCode}; use http_body_util::{combinators::BoxBody, BodyExt, Full}; use hyper::body::{Bytes, Incoming}; +use serde::Serialize; use std::convert::Infallible; use std::future::Future; use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use std::task::{Context, Poll}; use tower::Service; use tracing::{debug, warn}; +/// Readiness state that can be shared across the application. +#[derive(Debug)] +pub struct ReadinessState { + ready: AtomicBool, +} + +impl ReadinessState { + /// Creates a new readiness state, initially not ready. + pub fn new() -> Self { + Self { + ready: AtomicBool::new(false), + } + } + + /// Sets the readiness state to ready. + pub fn set_ready(&self) { + self.ready.store(true, Ordering::SeqCst); + } + + /// Sets the readiness state to not ready. + pub fn set_not_ready(&self) { + self.ready.store(false, Ordering::SeqCst); + } + + /// Checks if the service is ready. + pub fn is_ready(&self) -> bool { + self.ready.load(Ordering::SeqCst) + } +} + +impl Default for ReadinessState { + fn default() -> Self { + Self::new() + } +} + +/// Health check response format. +#[derive(Debug, Serialize)] +struct HealthResponse { + status: &'static str, +} + +/// Readiness check response format. +#[derive(Debug, Serialize)] +struct ReadinessResponse { + ready: bool, + #[serde(skip_serializing_if = "Option::is_none")] + reason: Option, +} + /// Admin service for health checks and metrics endpoints. 
/// /// Serves: -/// - `/health` - Health check endpoint returning 200 OK +/// - `/health` - Liveness check endpoint returning 200 OK +/// - `/ready` - Readiness check endpoint (200 if ready, 503 if not) /// - `/metrics` - Prometheus metrics in text format /// /// # Example /// /// ```no_run /// use rust_servicemesh::admin::AdminService; +/// use std::sync::Arc; /// /// let service = AdminService::new(); +/// // Or with custom readiness state: +/// // let service = AdminService::with_readiness(Arc::new(ReadinessState::new())); /// ``` #[derive(Clone)] -pub struct AdminService; +pub struct AdminService { + readiness: Arc, +} impl AdminService { - /// Creates a new admin service. + /// Creates a new admin service with default readiness (starts as ready). pub fn new() -> Self { - Self + let readiness = Arc::new(ReadinessState::new()); + readiness.set_ready(); // Default to ready for backwards compatibility + Self { readiness } } - /// Handles admin requests for health and metrics endpoints. + /// Creates an admin service with custom readiness state. + pub fn with_readiness(readiness: Arc) -> Self { + Self { readiness } + } + + /// Returns the readiness state reference. + pub fn readiness(&self) -> &Arc { + &self.readiness + } + + /// Handles admin requests for health, readiness, and metrics endpoints. async fn handle_request( + readiness: Arc, req: Request, ) -> Result>, Infallible> { let path = req.uri().path(); match path { - "/health" => { + "/health" | "/healthz" => { debug!("health check requested"); Ok(Self::health_response()) } + "/ready" | "/readyz" => { + debug!("readiness check requested"); + Ok(Self::readiness_response(&readiness)) + } "/metrics" => { debug!("metrics requested"); match Metrics::encode() { @@ -61,12 +137,54 @@ impl AdminService { } } - /// Creates a health check response. + /// Creates a health check response (liveness probe). 
fn health_response() -> Response> { + let response = HealthResponse { status: "healthy" }; + let body = serde_json::to_string(&response) + .unwrap_or_else(|_| r#"{"status":"healthy"}"#.to_string()); + Response::builder() .status(StatusCode::OK) + .header("Content-Type", "application/json") .body( - Full::new(Bytes::from("healthy")) + Full::new(Bytes::from(body)) + .map_err(|never| match never {}) + .boxed(), + ) + .unwrap_or_else(|_| { + Response::new( + Full::new(Bytes::new()) + .map_err(|never| match never {}) + .boxed(), + ) + }) + } + + /// Creates a readiness check response. + fn readiness_response(readiness: &ReadinessState) -> Response> { + let is_ready = readiness.is_ready(); + let response = ReadinessResponse { + ready: is_ready, + reason: if is_ready { + None + } else { + Some("service not ready".to_string()) + }, + }; + let body = serde_json::to_string(&response) + .unwrap_or_else(|_| format!(r#"{{"ready":{}}}"#, is_ready)); + + let status = if is_ready { + StatusCode::OK + } else { + StatusCode::SERVICE_UNAVAILABLE + }; + + Response::builder() + .status(status) + .header("Content-Type", "application/json") + .body( + Full::new(Bytes::from(body)) .map_err(|never| match never {}) .boxed(), ) @@ -134,7 +252,8 @@ impl Service> for AdminService { } fn call(&mut self, req: Request) -> Self::Future { - Box::pin(Self::handle_request(req)) + let readiness = Arc::clone(&self.readiness); + Box::pin(Self::handle_request(readiness, req)) } } @@ -146,6 +265,37 @@ mod tests { fn test_health_response() { let response = AdminService::health_response(); assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("Content-Type").unwrap(), + "application/json" + ); + } + + #[test] + fn test_readiness_response_ready() { + let state = ReadinessState::new(); + state.set_ready(); + let response = AdminService::readiness_response(&state); + assert_eq!(response.status(), StatusCode::OK); + } + + #[test] + fn test_readiness_response_not_ready() { + let 
state = ReadinessState::new(); + let response = AdminService::readiness_response(&state); + assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE); + } + + #[test] + fn test_readiness_state() { + let state = ReadinessState::new(); + assert!(!state.is_ready()); + + state.set_ready(); + assert!(state.is_ready()); + + state.set_not_ready(); + assert!(!state.is_ready()); } #[test] From 4cc8f8fb07a53b1e5b2e7e0931e501633b62601a Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:24:15 -0800 Subject: [PATCH 06/12] Add X-Forwarded-For header support --- src/ratelimit.rs | 188 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 187 insertions(+), 1 deletion(-) diff --git a/src/ratelimit.rs b/src/ratelimit.rs index 9c28e17..78fa7b5 100644 --- a/src/ratelimit.rs +++ b/src/ratelimit.rs @@ -2,13 +2,17 @@ //! //! Provides configurable rate limiting for incoming requests with support //! for multiple strategies: per-client, global, and per-route limiting. +//! +//! Supports X-Forwarded-For header parsing for clients behind proxies. use dashmap::DashMap; +use http::HeaderMap; use parking_lot::Mutex; use std::net::IpAddr; +use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, Instant}; -use tracing::debug; +use tracing::{debug, warn}; /// Configuration for rate limiting. #[derive(Debug, Clone)] @@ -21,6 +25,10 @@ pub struct RateLimitConfig { pub per_client: bool, /// Time-to-live for client rate limit entries. pub client_ttl: Duration, + /// Whether to trust X-Forwarded-For headers. + pub trust_forwarded_for: bool, + /// Maximum number of IPs to consider from X-Forwarded-For chain. 
+ pub max_forwarded_ips: usize, } impl Default for RateLimitConfig { @@ -30,6 +38,8 @@ impl Default for RateLimitConfig { burst_size: 50, per_client: true, client_ttl: Duration::from_secs(300), + trust_forwarded_for: false, + max_forwarded_ips: 1, } } } @@ -55,6 +65,25 @@ impl RateLimitConfig { self.client_ttl = ttl; self } + + /// Enables trusting X-Forwarded-For headers. + /// + /// **Security Warning**: Only enable this if the proxy is behind a trusted + /// load balancer that sets this header. Untrusted clients can spoof this + /// header to bypass rate limiting. + pub fn with_trust_forwarded_for(mut self, trust: bool) -> Self { + self.trust_forwarded_for = trust; + self + } + + /// Sets the maximum number of IPs to consider from X-Forwarded-For. + /// + /// Intended semantics: 1 means only the rightmost (closest to proxy) IP; higher + /// values are for multi-proxy setups. `extract_client_ip` does not apply this yet. + pub fn with_max_forwarded_ips(mut self, max: usize) -> Self { + self.max_forwarded_ips = max.max(1); + self + } } /// Token bucket for rate limiting. @@ -172,6 +201,83 @@ impl RateLimiter { Self::new(RateLimitConfig::default()) } + /// Extracts the client IP from request headers. + /// + /// If `trust_forwarded_for` is enabled, attempts to parse the + /// X-Forwarded-For header. Falls back to the provided socket address. + /// + /// # Arguments + /// + /// * `headers` - HTTP request headers + /// * `socket_addr` - The direct socket connection address + pub fn extract_client_ip( + &self, + headers: &HeaderMap, + socket_addr: Option, + ) -> Option { + if self.config.trust_forwarded_for { + if let Some(forwarded) = headers.get("x-forwarded-for") { + if let Ok(value) = forwarded.to_str() { + // X-Forwarded-For format: client, proxy1, proxy2, ...
+ // Split the chain into trimmed entries; only the leftmost one is used below + let ips: Vec<&str> = value.split(',').map(|s| s.trim()).collect(); + + // Take the first IP (the claimed original client); this entry can be spoofed. + // In production with trusted proxies, you might want the last one + // before your own proxy (ips.len() - max_forwarded_ips) + if let Some(ip_str) = ips.first() { + match IpAddr::from_str(ip_str) { + Ok(ip) => { + debug!(ip = %ip, "using X-Forwarded-For client IP"); + return Some(ip); + } + Err(e) => { + warn!( + header = %value, + error = %e, + "invalid IP in X-Forwarded-For header" + ); + } + } + } + } + } + + // Also check X-Real-IP header (used by nginx) + if let Some(real_ip) = headers.get("x-real-ip") { + if let Ok(value) = real_ip.to_str() { + match IpAddr::from_str(value.trim()) { + Ok(ip) => { + debug!(ip = %ip, "using X-Real-IP client IP"); + return Some(ip); + } + Err(e) => { + warn!( + header = %value, + error = %e, + "invalid IP in X-Real-IP header" + ); + } + } + } + } + } + + socket_addr + } + + /// Checks if a request should be allowed using headers to determine client IP. + /// + /// Returns `Ok(())` if allowed, `Err(RateLimitInfo)` if rate limited. + pub fn check_with_headers( + &self, + headers: &HeaderMap, + socket_addr: Option, + ) -> Result<(), RateLimitInfo> { + let client_ip = self.extract_client_ip(headers, socket_addr); + self.check(client_ip) + } + /// Checks if a request should be allowed. /// /// Returns `Ok(())` if allowed, `Err(RateLimitInfo)` if rate limited.
@@ -314,6 +420,7 @@ impl Clone for RateLimitLayer { #[cfg(test)] mod tests { use super::*; + use http::HeaderMap; use std::net::Ipv4Addr; #[test] @@ -401,4 +508,83 @@ mod tests { assert_eq!(stats.burst_size, 50); assert_eq!(stats.client_count, 0); } + + #[test] + fn test_extract_client_ip_no_trust() { + let config = RateLimitConfig::new(100, 50).with_trust_forwarded_for(false); + let limiter = RateLimiter::new(config); + + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-for", "1.2.3.4".parse().unwrap()); + + let socket_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)); + let result = limiter.extract_client_ip(&headers, Some(socket_ip)); + + // Should ignore X-Forwarded-For when trust is disabled + assert_eq!(result, Some(socket_ip)); + } + + #[test] + fn test_extract_client_ip_trust_enabled() { + let config = RateLimitConfig::new(100, 50).with_trust_forwarded_for(true); + let limiter = RateLimiter::new(config); + + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-for", "1.2.3.4, 5.6.7.8".parse().unwrap()); + + let socket_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)); + let result = limiter.extract_client_ip(&headers, Some(socket_ip)); + + // Should use the first IP from X-Forwarded-For + assert_eq!(result, Some(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)))); + } + + #[test] + fn test_extract_client_ip_x_real_ip() { + let config = RateLimitConfig::new(100, 50).with_trust_forwarded_for(true); + let limiter = RateLimiter::new(config); + + let mut headers = HeaderMap::new(); + headers.insert("x-real-ip", "10.0.0.1".parse().unwrap()); + + let socket_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)); + let result = limiter.extract_client_ip(&headers, Some(socket_ip)); + + // Should use X-Real-IP + assert_eq!(result, Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))); + } + + #[test] + fn test_extract_client_ip_invalid_header() { + let config = RateLimitConfig::new(100, 50).with_trust_forwarded_for(true); + let limiter = RateLimiter::new(config); + + let 
mut headers = HeaderMap::new(); + headers.insert("x-forwarded-for", "not-an-ip".parse().unwrap()); + + let socket_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)); + let result = limiter.extract_client_ip(&headers, Some(socket_ip)); + + // Should fall back to socket address when header is invalid + assert_eq!(result, Some(socket_ip)); + } + + #[test] + fn test_check_with_headers() { + let config = RateLimitConfig::new(100, 5) + .with_per_client(true) + .with_trust_forwarded_for(true); + let limiter = RateLimiter::new(config); + + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-for", "1.2.3.4".parse().unwrap()); + + // Should allow requests up to burst size + for _ in 0..5 { + assert!(limiter.check_with_headers(&headers, None).is_ok()); + } + + // Should be rate limited + assert!(limiter.check_with_headers(&headers, None).is_err()); + } } From bfa336fedeeb66aaf583135becbb9c81a0693a43 Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:24:19 -0800 Subject: [PATCH 07/12] Update lib.rs module exports --- src/lib.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 6516412..c516e65 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,9 +63,11 @@ //! 
- `PROXY_REQUEST_TIMEOUT_MS`: Request timeout in milliseconds (default: 30000) pub mod admin; +pub use admin::ReadinessState; pub mod admin_listener; pub mod circuit_breaker; pub mod config; +pub mod connection; pub mod error; pub mod listener; pub mod metrics; @@ -78,7 +80,8 @@ pub mod transport; // Re-export commonly used types pub use circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, State as CircuitBreakerState}; -pub use config::ProxyConfig; +pub use config::{ConfigError, ProxyConfig}; +pub use connection::{ConnectionConfig, ConnectionLimiter, ConnectionStats}; pub use error::{ProxyError, Result}; pub use listener::Listener; pub use protocol::{HttpProtocol, TlsConfig}; From 61bc8ab89fb9a075e0d79455b607e47a760a27c0 Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:24:23 -0800 Subject: [PATCH 08/12] Add connection module declaration --- src/main.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main.rs b/src/main.rs index 4169337..1af3478 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,8 @@ mod admin_listener; mod circuit_breaker; mod config; #[allow(dead_code)] +mod connection; +#[allow(dead_code)] mod error; #[allow(dead_code)] mod listener; From 518945bb1eafb341e112649e8677bcc3c43c973e Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:24:27 -0800 Subject: [PATCH 09/12] Add url dependency --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index f62c17f..c86f59a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,6 +52,7 @@ parking_lot = "0.12" arc-swap = "1.7" regex = "1.10" rand = "0.8" +url = "2.5" [dev-dependencies] tokio-test = "0.4" From c08ff919fd078fc70a709057a79a8cb77df52f43 Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:24:31 -0800 Subject: [PATCH 10/12] Remove unused import from benchmarks --- benches/proxy_benchmark.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/benches/proxy_benchmark.rs b/benches/proxy_benchmark.rs index acf91cf..307991d 100644 --- 
a/benches/proxy_benchmark.rs +++ b/benches/proxy_benchmark.rs @@ -6,7 +6,6 @@ use rust_servicemesh::ratelimit::{RateLimitConfig, RateLimiter}; use rust_servicemesh::retry::{RetryConfig, RetryPolicy}; use rust_servicemesh::router::{PathMatch, Route, Router}; use std::net::{IpAddr, Ipv4Addr}; -use std::time::Duration; fn bench_circuit_breaker(c: &mut Criterion) { let rt = tokio::runtime::Runtime::new().unwrap(); From 194504bfd8fd1a23c789ff5fa4f0905e39eb2ea9 Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 15:24:36 -0800 Subject: [PATCH 11/12] Format retry.rs --- src/retry.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/retry.rs b/src/retry.rs index e39fe92..88cbe74 100644 --- a/src/retry.rs +++ b/src/retry.rs @@ -369,9 +369,7 @@ mod tests { #[tokio::test] async fn test_retry_executor_success() { let mut executor = RetryExecutor::with_defaults(); - let result = executor - .execute(|| async { Ok::(42) }) - .await; + let result = executor.execute(|| async { Ok::(42) }).await; assert!(result.is_ok()); assert_eq!(result.unwrap(), 42); From 9e1630bc307a09d83353fde98c56f92ba424ceed Mon Sep 17 00:00:00 2001 From: Hugh Date: Wed, 24 Dec 2025 21:48:37 -0800 Subject: [PATCH 12/12] Enhance .gitignore with security patterns --- .gitignore | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/.gitignore b/.gitignore index e13ea56..efd26d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,32 @@ /target/ Cargo.lock + +# Editor and IDE *.swp *.swo *~ .DS_Store +.idea/ +.vscode/ +*.iml + +# Environment and secrets +.env +.env.* +*.pem +*.key +*.p12 +*.pfx +secrets/ +credentials/ + +# Logs +*.log +logs/ + +# Coverage +*.profraw +*.profdata +coverage/ +lcov.info +tarpaulin-report.html