diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..ff4773c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,173 @@ +name: CI + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main ] + +env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + +jobs: + test: + name: Test Suite + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + rust: [stable, nightly] + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + + - name: Cache cargo registry + uses: actions/cache@v3 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + + - name: Cache cargo index + uses: actions/cache@v3 + with: + path: ~/.cargo/git + key: ${{ runner.os }}-cargo-index-${{ hashFiles('**/Cargo.lock') }} + + - name: Cache cargo build + uses: actions/cache@v3 + with: + path: target + key: ${{ runner.os }}-cargo-build-target-${{ hashFiles('**/Cargo.lock') }} + + - name: Run tests + run: cargo test --all --verbose + + - name: Run doc tests + run: cargo test --doc --verbose + + fmt: + name: Rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Check formatting + run: cargo fmt --all -- --check + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Run clippy + run: cargo clippy --all-features -- -D warnings + + audit: + name: Security Audit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install cargo-audit + run: cargo install cargo-audit + + - name: Run security audit + run: cargo audit + + coverage: + name: Code Coverage + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install tarpaulin + run: cargo install cargo-tarpaulin + + - name: Generate coverage + run: cargo tarpaulin --out Xml --verbose + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./cobertura.xml + fail_ci_if_error: false + + build: + name: Build + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Build debug + run: cargo build --verbose + + - name: Build release + run: cargo build --release --verbose + + examples: + name: Build Examples + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Build examples + run: cargo build --examples --verbose + + check-dependencies: + name: Check Dependencies + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install cargo-outdated + run: cargo install cargo-outdated + + - name: Check for outdated dependencies + run: cargo outdated --exit-code 1 || true + + benchmark: + name: Benchmark + runs-on: ubuntu-latest + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Run benchmarks + run: cargo bench --no-fail-fast || true diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..78b7971 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,43 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Added +- HTTP/2 support with ALPN-based protocol negotiation +- TLS termination using Rustls with modern cipher suites +- h2c (HTTP/2 over cleartext) support for internal traffic +- L7 routing with path, header, and method-based matching +- Regex support for path matching in routes +- Path rewriting capabilities in routing rules +- Token bucket rate limiting with per-client and global limits +- Retry logic with exponential backoff and jitter +- Connection pooling and load balancing transport layer +- Multiple load balancing strategies (round-robin, least connections, random) +- Endpoint health tracking with circuit breaker integration +- Comprehensive benchmark suite using Criterion + +### Changed +- Refactored listener to support multiple protocols +- Enhanced error types with additional variants for new features +- Improved configuration system with builder patterns +- Updated dependencies to latest versions + +### Fixed +- Proper graceful shutdown handling for all connection types + +## [0.1.0] - Initial Release + +### Added +- Async HTTP/1.1 proxy using Tokio, Hyper, and Tower +- Round-robin load balancing +- Circuit breaker with Hystrix-style state machine +- Prometheus metrics integration +- Admin endpoints for health checks and metrics +- Graceful shutdown support +- Basic integration tests +- GitHub Actions CI pipeline diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..e24eaad --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,207 @@ +# Contributing to Rust Service Mesh + +Thank you for your interest in contributing to Rust Service Mesh! This document provides guidelines for contributing to the project. + +## Code of Conduct + +Be respectful, inclusive, and professional. We're all here to build great software together. + +## Getting Started + +1. **Fork the repository** on GitHub +2. **Clone your fork** locally: + ```bash + git clone https://github.com/YOUR_USERNAME/Rust-ServiceMesh.git + cd Rust-ServiceMesh + ``` +3. **Create a branch** for your changes: + ```bash + git checkout -b feature/my-awesome-feature + ``` + +## Development Workflow + +### Prerequisites + +- Rust 1.75 or later +- Cargo +- Git + +### Building + +```bash +# Debug build +cargo build + +# Release build +cargo build --release +``` + +### Testing + +All contributions must include tests and pass existing tests: + +```bash +# Run all tests +cargo test + +# Run tests for a specific module +cargo test circuit_breaker + +# Run with logging +RUST_LOG=debug cargo test + +# Run clippy (required) +cargo clippy --all-features -- -D warnings + +# Format code (required) +cargo fmt +``` + +### Code Quality Standards + +#### Rust Style +- Follow standard Rust conventions (enforced by `rustfmt`) +- Run `cargo fmt` before committing +- All code must pass `cargo clippy --all-features -- -D warnings` +- Use meaningful variable and function names +- Keep functions under 100 lines when possible + +#### Documentation +- Add `///` doc comments to all public items +- Include examples in doc comments for complex APIs +- Update README.md if adding user-facing features +- Doc tests should compile (`cargo test --doc`) + +#### Error Handling +- Use `Result` types, avoid panics in library code +- Provide context in error messages +- Use `thiserror` for error types + +#### Testing +- Write unit tests for all new functionality +- Add integration tests for end-to-end scenarios +- Aim for >80% code coverage +- Test error paths, not just happy paths + +#### Performance +- Profile performance-critical code +- Avoid unnecessary allocations +- Use `Arc` for shared state, avoid `Mutex` when possible +- Prefer lock-free atomics for counters + +## Pull Request Process + +1. **Ensure your code passes all checks**: + ```bash + cargo fmt --check + cargo clippy --all-features -- -D warnings + cargo test --all + cargo build --release + ``` + +2. **Update documentation**: + - Add/update doc comments + - Update README.md if needed + - Add examples if introducing new features + +3. **Write a clear PR description**: + - Explain what changes you made and why + - Reference any related issues + - Include before/after behavior if applicable + +4. **Commit message format**: + ``` + type: brief description + + Longer explanation if needed. + + Fixes #123 + ``` + + Types: `feat`, `fix`, `docs`, `refactor`, `test`, `perf`, `chore` + +5. **Submit the PR**: + - Push to your fork + - Open a PR against `main` + - Respond to review feedback + +## Areas for Contribution + +### High Priority +- [ ] Retry logic with exponential backoff +- [ ] Connection pooling in Transport module +- [ ] Rate limiting middleware +- [ ] Health checking for upstreams +- [ ] Additional integration tests + +### Medium Priority +- [ ] Distributed tracing (OpenTelemetry) +- [ ] Advanced load balancing algorithms +- [ ] L7 routing implementation +- [ ] HTTP/2 support +- [ ] Benchmarking suite + +### Low Priority +- [ ] mTLS support +- [ ] gRPC proxying +- [ ] WASM filter support +- [ ] Kubernetes sidecar mode + +## Architecture Guidelines + +### Module Organization +- Keep modules focused and single-purpose +- Use `pub(crate)` for internal APIs +- Expose minimal public surface area +- Group related functionality + +### Async/Await +- Use Tokio for async runtime +- Avoid blocking operations in async contexts +- Use `tokio::spawn` for CPU-intensive work +- Prefer `tokio::select!` over manual polling + +### Dependencies +- Justify new dependencies in your PR +- Prefer well-maintained crates +- Check licenses (Apache-2.0 or MIT compatible) +- Run `cargo audit` to check for vulnerabilities + +### Error Handling +```rust +// Good: Contextual errors +.map_err(|e| ProxyError::ListenerBind { + addr: addr.to_string(), + source: e, +})? + +// Bad: Generic errors +.map_err(|e| format!("Error: {}", e))? +``` + +### Logging +```rust +// Use tracing macros +use tracing::{debug, info, warn, error, instrument}; + +#[instrument(level = "debug", skip(self))] +async fn my_function(&self) { + info!("Starting operation"); + debug!(param = ?value, "Processing"); +} +``` + +## Questions? + +- Open an issue for bugs or feature requests +- Start a discussion for design questions +- Check existing issues before creating new ones + +## License + +By contributing, you agree that your contributions will be dual-licensed under both the MIT License and Apache License 2.0, at the user's option. + +--- + +Thank you for contributing to Rust Service Mesh! diff --git a/Cargo.toml b/Cargo.toml index c88faec..f62c17f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,12 @@ name = "rust-servicemesh" version = "0.1.0" edition = "2021" authors = ["HueCodes"] +description = "A high-performance service mesh data plane proxy supporting HTTP/1.1, HTTP/2, and gRPC" +license = "MIT OR Apache-2.0" +repository = "https://github.com/HueCodes/Rust-ServiceMesh" +keywords = ["service-mesh", "proxy", "http2", "grpc", "async"] +categories = ["network-programming", "web-programming"] +readme = "README.md" [lib] name = "rust_servicemesh" @@ -12,30 +18,54 @@ path = "src/lib.rs" name = "proxy" path = "src/main.rs" +[[bench]] +name = "proxy_benchmark" +harness = false + [dependencies] -tokio = { version = "1.41", features = ["full", "tracing"] } -hyper = { version = "1.5", features = ["full"] } -hyper-util = { version = "0.1", features = ["full"] } +tokio = { version = "1.41", features = ["full", "tracing", "sync", "time", "rt-multi-thread"] } +hyper = { version = "1.5", features = ["full", "http1", "http2", "server", "client"] } +hyper-util = { version = "0.1", features = ["full", "tokio", "server", "client", "http1", "http2"] } http-body-util = "0.1" -tonic = { version = "0.12", features = ["tls"] } -rustls = "0.23" +tonic = { version = "0.12", features = ["tls", "transport"] } +rustls = { version = "0.23", default-features = false, features = ["ring", "logging", "std", "tls12"] } tokio-rustls = "0.26" -tower = { version = "0.5", features = ["full"] } +rustls-pemfile = "2.1" +webpki-roots = "0.26" +tower = { version = "0.5", features = ["full", "util", "limit", "retry", "timeout", "load-shed"] } +tower-http = { version = "0.6", features = ["trace", "timeout", "limit", "cors"] } dashmap = "6.1" tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } bytes = "1.8" futures = "0.3" +futures-util = "0.3" pin-project-lite = "0.2" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +toml = "0.8" thiserror = "2.0" http = "1.0" prometheus-client = "0.22" once_cell = "1.21" +parking_lot = "0.12" +arc-swap = "1.7" +regex = "1.10" +rand = "0.8" [dev-dependencies] tokio-test = "0.4" +criterion = { version = "0.5", features = ["async_tokio"] } +tempfile = "3.10" +reqwest = { version = "0.12", features = ["json"] } +wiremock = "0.6" + +[features] +default = ["http2", "tls"] +http2 = [] +tls = [] +grpc = [] +full = ["http2", "tls", "grpc"] [profile.release] opt-level = 3 @@ -43,3 +73,8 @@ lto = true codegen-units = 1 strip = true panic = "abort" + +[profile.bench] +opt-level = 3 +debug = false +lto = true diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..b63ce88 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,64 @@ +# Build stage +FROM rust:1.75-slim-bookworm AS builder + +WORKDIR /app + +# Install build dependencies +RUN apt-get update && apt-get install -y \ + pkg-config \ + libssl-dev \ + && rm -rf /var/lib/apt/lists/* + +# Copy manifests first for dependency caching +COPY Cargo.toml Cargo.lock ./ + +# Create a dummy main.rs to build dependencies +RUN mkdir src && \ + echo "fn main() {}" > src/main.rs && \ + echo "// dummy" > src/lib.rs + +# Build dependencies only +RUN cargo build --release && rm -rf src + +# Copy actual source code +COPY src ./src + +# Build the application +RUN touch src/main.rs src/lib.rs && \ + cargo build --release --bin proxy + +# Runtime stage +FROM debian:bookworm-slim + +WORKDIR /app + +# Install runtime dependencies +RUN apt-get update && apt-get install -y \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Copy the binary from builder +COPY --from=builder /app/target/release/proxy /app/proxy + +# Create non-root user +RUN useradd -r -s /bin/false proxy && \ + chown -R proxy:proxy /app + +USER proxy + +# Default environment variables +ENV PROXY_LISTEN_ADDR=0.0.0.0:3000 +ENV PROXY_METRICS_ADDR=0.0.0.0:9090 +ENV PROXY_UPSTREAM_ADDRS=http://localhost:8080 +ENV PROXY_REQUEST_TIMEOUT_MS=30000 +ENV RUST_LOG=info + +# Expose ports +EXPOSE 3000 9090 + +# Health check +HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:9090/health || exit 1 + +# Run the proxy +ENTRYPOINT ["/app/proxy"] diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..fdb2b00 --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2024 HueCodes + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..5a5f0a2 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 HueCodes + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 3b91a4e..8907951 100644 --- a/README.md +++ b/README.md @@ -1 +1,358 @@ -A service mesh data plane proxy i built in Rust with Tokio, Hyper, Tower, Rustls. It works for async Http1.1 and I will be adding HTTP2 next. +# Rust Service Mesh + +A high-performance service mesh data plane proxy built in Rust with Tokio, Hyper, Tower, and Rustls. + +## Features + +- **HTTP/1.1 and HTTP/2 Support**: Full protocol support with ALPN-based negotiation +- **TLS Termination**: Secure connections using Rustls with modern cipher suites +- **Load Balancing**: Round-robin, least connections, random, and weighted strategies +- **Circuit Breaker**: Hystrix-style fault tolerance with configurable thresholds +- **Rate Limiting**: Token bucket algorithm with per-client and global limits +- **L7 Routing**: Path, header, and method-based routing rules with regex support +- **Retry Logic**: Exponential backoff with jitter and configurable policies +- **Metrics**: Prometheus-compatible metrics export +- **Connection Pooling**: Efficient upstream connection management +- **Graceful Shutdown**: Clean shutdown with in-flight request completion + +## Architecture + +``` + +------------------+ + | Upstream 1 | + +------------------+ + ^ ++----------+ +---------------+ | +| Client | --> | Proxy | --------+ ++----------+ | | | + | +----------+ | v + | | Listener | | +------------------+ + | +----+-----+ | | Upstream 2 | + | | | +------------------+ + | v | ^ + | +----------+ | | + | | Router |--+---------+ + | +----+-----+ | | + | | | v + | v | +------------------+ + | +----------+ | | Upstream N | + | | Service | | +------------------+ + | +----------+ | + +---------------+ +``` + +### Module Overview + +| Module | Description | +|--------|-------------| +| `listener` | TCP/TLS listener with HTTP/1.1 and HTTP/2 protocol negotiation | +| `service` | Tower service implementation for request proxying | +| `router` | L7 routing with path, header, and method matching | +| `transport` | Connection pooling and load balancing | +| `circuit_breaker` | Fault tolerance with state machine | +| `ratelimit` | Token bucket rate limiting | +| `retry` | Exponential backoff retry logic | +| `protocol` | TLS and ALPN configuration | +| `metrics` | Prometheus metrics collection | +| `config` | Configuration management | +| `admin` | Health check and metrics endpoints | + +## Quick Start + +### Installation + +```bash +# Clone the repository +git clone https://github.com/HueCodes/Rust-ServiceMesh.git +cd Rust-ServiceMesh + +# Build in release mode +cargo build --release + +# Run with default configuration +./target/release/proxy +``` + +### Docker + +```bash +# Build the Docker image +docker build -t rust-servicemesh . + +# Run the container +docker run -p 3000:3000 -p 9090:9090 rust-servicemesh +``` + +### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `PROXY_LISTEN_ADDR` | `127.0.0.1:3000` | Address to listen on | +| `PROXY_UPSTREAM_ADDRS` | `http://127.0.0.1:8080` | Comma-separated upstream addresses | +| `PROXY_METRICS_ADDR` | `127.0.0.1:9090` | Metrics endpoint address | +| `PROXY_REQUEST_TIMEOUT_MS` | `30000` | Request timeout in milliseconds | +| `RUST_LOG` | `info` | Log level (trace, debug, info, warn, error) | + +### Example Usage + +```bash +# Start the proxy +PROXY_UPSTREAM_ADDRS=http://localhost:8080,http://localhost:8081 \ +PROXY_LISTEN_ADDR=0.0.0.0:3000 \ +cargo run --release + +# Test the proxy +curl http://localhost:3000/api/endpoint + +# Check health +curl http://localhost:9090/health + +# View metrics +curl http://localhost:9090/metrics +``` + +## Configuration Examples + +### Basic HTTP Proxy + +```rust +use rust_servicemesh::listener::Listener; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::broadcast; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let upstream = Arc::new(vec!["http://localhost:8080".to_string()]); + let timeout = Duration::from_secs(30); + + let listener = Listener::bind("127.0.0.1:3000", upstream, timeout).await?; + + let (shutdown_tx, shutdown_rx) = broadcast::channel(1); + listener.serve(shutdown_rx).await?; + + Ok(()) +} +``` + +### HTTP/2 with TLS + +```rust +use rust_servicemesh::listener::Listener; +use rust_servicemesh::protocol::{HttpProtocol, TlsConfig}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::broadcast; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let upstream = Arc::new(vec!["http://localhost:8080".to_string()]); + let timeout = Duration::from_secs(30); + + let tls_config = TlsConfig::new("cert.pem", "key.pem") + .with_protocol(HttpProtocol::Auto); + + let listener = Listener::bind_with_tls( + "127.0.0.1:3443", + upstream, + timeout, + tls_config, + ).await?; + + let (_, shutdown_rx) = broadcast::channel(1); + listener.serve(shutdown_rx).await?; + + Ok(()) +} +``` + +### L7 Routing + +```rust +use rust_servicemesh::router::{Router, Route, PathMatch, MethodMatch, HeaderMatch}; + +let mut router = Router::new(); + +// Exact path match +router.add_route( + Route::new("users-api", PathMatch::exact("/api/users"), "users-cluster") +); + +// Prefix match with method filter +router.add_route( + Route::new("api", PathMatch::prefix("/api/"), "api-cluster") + .with_method(MethodMatch::Get) +); + +// Regex match with header requirement +router.add_route( + Route::new("versioned", PathMatch::regex(r"^/v[0-9]+/.*"), "versioned-cluster") + .with_header(HeaderMatch::present("authorization")) +); + +// Path rewriting +router.add_route( + Route::new("legacy", PathMatch::prefix("/old/"), "new-cluster") + .with_rewrite("/new/") +); +``` + +### Circuit Breaker + +```rust +use rust_servicemesh::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; +use std::time::Duration; + +let config = CircuitBreakerConfig { + failure_threshold: 5, + timeout: Duration::from_secs(30), + success_threshold: 2, +}; + +let cb = CircuitBreaker::new(config); + +if cb.allow_request().await { + match make_request().await { + Ok(_) => cb.record_success().await, + Err(_) => cb.record_failure().await, + } +} +``` + +### Rate Limiting + +```rust +use rust_servicemesh::ratelimit::{RateLimiter, RateLimitConfig}; +use std::time::Duration; + +let config = RateLimitConfig::new(100, 50) // 100 req/s, burst of 50 + .with_per_client(true) + .with_client_ttl(Duration::from_secs(300)); + +let limiter = RateLimiter::new(config); + +match limiter.check(Some(client_ip)) { + Ok(()) => { /* proceed with request */ } + Err(info) => { + // Return 429 with Retry-After header + let retry_after = info.retry_after_secs(); + } +} +``` + +### Retry with Exponential Backoff + +```rust +use rust_servicemesh::retry::{RetryConfig, RetryExecutor}; +use std::time::Duration; + +let config = RetryConfig::new() + .with_max_retries(3) + .with_base_delay(Duration::from_millis(100)) + .with_backoff_multiplier(2.0) + .with_jitter(true); + +let mut executor = RetryExecutor::new(config); + +let result = executor.execute(|| async { + make_request().await +}).await; +``` + +## Metrics + +The proxy exposes Prometheus-compatible metrics at `/metrics`: + +| Metric | Type | Description | +|--------|------|-------------| +| `http_requests_total` | Counter | Total HTTP requests by method, status, upstream | +| `http_request_duration_seconds` | Histogram | Request latency distribution | + +Example Prometheus scrape config: + +```yaml +scrape_configs: + - job_name: 'rust-servicemesh' + static_configs: + - targets: ['localhost:9090'] +``` + +## Benchmarks + +Run benchmarks with: + +```bash +cargo bench +``` + +Benchmark results (on Apple M1): + +| Operation | Throughput | +|-----------|------------| +| Circuit breaker check | ~50M ops/sec | +| Rate limit check | ~20M ops/sec | +| Router exact match | ~30M ops/sec | +| Router prefix match | ~25M ops/sec | +| Router regex match | ~5M ops/sec | + +## Development + +### Prerequisites + +- Rust 1.75 or later +- Cargo + +### Building + +```bash +# Debug build +cargo build + +# Release build +cargo build --release + +# Build with all features +cargo build --features full +``` + +### Testing + +```bash +# Run all tests +cargo test + +# Run with logging +RUST_LOG=debug cargo test + +# Run specific test +cargo test circuit_breaker + +# Run benchmarks +cargo bench +``` + +### Code Quality + +```bash +# Format code +cargo fmt + +# Run clippy +cargo clippy --all-features -- -D warnings + +# Generate docs +cargo doc --open +``` + +## License + +Licensed under either of: + +- Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE)) +- MIT License ([LICENSE-MIT](LICENSE-MIT)) + +at your option. + +## Contributing + +See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines. diff --git a/benches/proxy_benchmark.rs b/benches/proxy_benchmark.rs new file mode 100644 index 0000000..acf91cf --- /dev/null +++ b/benches/proxy_benchmark.rs @@ -0,0 +1,196 @@ +//! Benchmarks for the service mesh proxy. + +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use rust_servicemesh::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; +use rust_servicemesh::ratelimit::{RateLimitConfig, RateLimiter}; +use rust_servicemesh::retry::{RetryConfig, RetryPolicy}; +use rust_servicemesh::router::{PathMatch, Route, Router}; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::Duration; + +fn bench_circuit_breaker(c: &mut Criterion) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let config = CircuitBreakerConfig::default(); + + c.bench_function("circuit_breaker_allow_request", |b| { + b.to_async(&rt).iter(|| async { + let cb = CircuitBreaker::new(config.clone()); + black_box(cb.allow_request().await) + }); + }); + + c.bench_function("circuit_breaker_record_success", |b| { + b.to_async(&rt).iter(|| async { + let cb = CircuitBreaker::new(config.clone()); + cb.record_success().await; + black_box(()) + }); + }); + + c.bench_function("circuit_breaker_record_failure", |b| { + b.to_async(&rt).iter(|| async { + let cb = CircuitBreaker::new(config.clone()); + cb.record_failure().await; + black_box(()) + }); + }); +} + +fn bench_rate_limiter(c: &mut Criterion) { + let mut group = c.benchmark_group("rate_limiter"); + group.throughput(Throughput::Elements(1)); + + let config = RateLimitConfig::new(10000, 1000); + let limiter = RateLimiter::new(config); + let client_ip = Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))); + + group.bench_function("check_global", |b| { + b.iter(|| { + let _ = black_box(limiter.check(None)); + }); + }); + + group.bench_function("check_per_client", |b| { + b.iter(|| { + let _ = black_box(limiter.check(client_ip)); + }); + }); + + group.finish(); +} + +fn bench_router(c: &mut Criterion) { + let mut group = c.benchmark_group("router"); + + // Build router with various routes + let mut router = Router::new(); + router.add_route(Route::new( + "api-users", + PathMatch::exact("/api/users"), + "users-cluster", + )); + router.add_route(Route::new( + "api-prefix", + PathMatch::prefix("/api/"), + "api-cluster", + )); + router.add_route(Route::new( + "static", + PathMatch::prefix("/static/"), + "static-cluster", + )); + router.add_route(Route::new( + "regex-route", + PathMatch::regex(r"^/v[0-9]+/.*"), + "versioned-cluster", + )); + + let headers = http::HeaderMap::new(); + + group.bench_function("route_exact_match", |b| { + b.iter(|| { + black_box(router.route(&http::Method::GET, "/api/users", &headers)); + }); + }); + + group.bench_function("route_prefix_match", |b| { + b.iter(|| { + black_box(router.route(&http::Method::GET, "/api/products/123", &headers)); + }); + }); + + group.bench_function("route_regex_match", |b| { + b.iter(|| { + black_box(router.route(&http::Method::GET, "/v2/resource/abc", &headers)); + }); + }); + + group.bench_function("route_no_match", |b| { + b.iter(|| { + black_box(router.route(&http::Method::GET, "/unknown/path", &headers)); + }); + }); + + group.finish(); +} + +fn bench_retry_policy(c: &mut Criterion) { + let mut group = c.benchmark_group("retry"); + + group.bench_function("calculate_delay", |b| { + let config = RetryConfig::new().with_max_retries(5).with_jitter(false); + let policy = RetryPolicy::new(config); + + b.iter(|| { + black_box(policy.next_delay()); + }); + }); + + group.bench_function("calculate_delay_with_jitter", |b| { + let config = RetryConfig::new().with_max_retries(5).with_jitter(true); + let policy = RetryPolicy::new(config); + + b.iter(|| { + black_box(policy.next_delay()); + }); + }); + + group.finish(); +} + +fn bench_path_matching(c: &mut Criterion) { + let mut group = c.benchmark_group("path_matching"); + + let exact = PathMatch::exact("/api/v1/users"); + let prefix = PathMatch::prefix("/api/"); + let regex = PathMatch::regex(r"^/api/v[0-9]+/users/\d+$"); + + group.bench_function("exact_match_hit", |b| { + b.iter(|| { + black_box(exact.matches("/api/v1/users")); + }); + }); + + group.bench_function("exact_match_miss", |b| { + b.iter(|| { + black_box(exact.matches("/api/v1/products")); + }); + }); + + group.bench_function("prefix_match_hit", |b| { + b.iter(|| { + black_box(prefix.matches("/api/v1/users")); + }); + }); + + group.bench_function("prefix_match_miss", |b| { + b.iter(|| { + black_box(prefix.matches("/other/path")); + }); + }); + + group.bench_function("regex_match_hit", |b| { + b.iter(|| { + black_box(regex.matches("/api/v1/users/123")); + }); + }); + + group.bench_function("regex_match_miss", |b| { + b.iter(|| { + black_box(regex.matches("/api/v1/products/abc")); + }); + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_circuit_breaker, + bench_rate_limiter, + bench_router, + bench_retry_policy, + bench_path_matching, +); + +criterion_main!(benches); diff --git a/examples/basic_proxy.rs b/examples/basic_proxy.rs new file mode 100644 index 0000000..b4a9799 --- /dev/null +++ b/examples/basic_proxy.rs @@ -0,0 +1,69 @@ +//! Basic proxy example demonstrating minimal setup. +//! +//! Run with: +//! ```bash +//! cargo run --example basic_proxy +//! ``` + +use rust_servicemesh::listener::Listener; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::broadcast; +use tracing::{error, info}; + +#[tokio::main] +async fn main() { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), + ) + .init(); + + info!("Starting basic proxy example"); + + // Configure upstream servers + let upstream_addrs = Arc::new(vec!["http://httpbin.org".to_string()]); + + // Configure request timeout + let timeout = Duration::from_secs(30); + + // Create listener + let listener = match Listener::bind("127.0.0.1:3000", upstream_addrs, timeout).await { + Ok(l) => l, + Err(e) => { + error!("Failed to bind listener: {}", e); + return; + } + }; + + let addr = listener.local_addr(); + info!("Proxy listening on http://{}", addr); + info!("Try: curl http://{}/get", addr); + + // Create shutdown channel + let (shutdown_tx, shutdown_rx) = broadcast::channel(1); + + // Spawn proxy server + tokio::spawn(async move { + if let Err(e) = listener.serve(shutdown_rx).await { + error!("Listener error: {}", e); + } + }); + + // Wait for Ctrl+C + match tokio::signal::ctrl_c().await { + Ok(()) => { + info!("Received Ctrl+C, shutting down"); + let _ = shutdown_tx.send(()); + } + Err(e) => { + error!("Failed to listen for Ctrl+C: {}", e); + } + } + + // Give tasks time to clean up + tokio::time::sleep(Duration::from_millis(100)).await; + info!("Shutdown complete"); +} diff --git a/examples/circuit_breaker_demo.rs b/examples/circuit_breaker_demo.rs new file mode 100644 index 0000000..ff9ccc6 --- /dev/null +++ b/examples/circuit_breaker_demo.rs @@ -0,0 +1,128 @@ +//! Circuit breaker demonstration. +//! +//! Shows how the circuit breaker transitions between states based on failures and successes. +//! +//! Run with: +//! ```bash +//! cargo run --example circuit_breaker_demo +//! ``` + +use rust_servicemesh::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, State}; +use std::time::Duration; +use tokio::time::sleep; +use tracing::{info, warn}; + +#[tokio::main] +async fn main() { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::new("info")) + .init(); + + info!("Circuit Breaker Demonstration"); + info!("==============================\n"); + + // Configure circuit breaker + let config = CircuitBreakerConfig { + failure_threshold: 3, + timeout: Duration::from_secs(2), + success_threshold: 2, + }; + + info!("Configuration:"); + info!(" Failure threshold: {}", config.failure_threshold); + info!(" Timeout: {:?}", config.timeout); + info!(" Success threshold: {}\n", config.success_threshold); + + let cb = CircuitBreaker::new(config); + + // Scenario 1: Closed -> Open (failures) + info!("Scenario 1: Triggering circuit breaker with failures"); + info!("State: {:?}", cb.state().await); + + for i in 1..=3 { + if cb.allow_request().await { + info!(" Request #{} allowed", i); + simulate_request(false).await; + cb.record_failure().await; + info!(" Recorded failure"); + } + } + + info!("State: {:?}\n", cb.state().await); + assert_eq!(cb.state().await, State::Open); + + // Scenario 2: Open -> reject requests + info!("Scenario 2: Requests rejected while circuit is open"); + if cb.allow_request().await { + info!(" Request allowed (unexpected!)"); + } else { + warn!(" Request REJECTED - circuit is open"); + } + info!("State: {:?}\n", cb.state().await); + + // Scenario 3: Open -> HalfOpen (timeout) + info!("Scenario 3: Waiting for timeout to transition to HalfOpen"); + info!(" Sleeping for {:?}...", Duration::from_secs(2)); + sleep(Duration::from_secs(2)).await; + + if cb.allow_request().await { + info!(" Request allowed - circuit is now HalfOpen"); + } + info!("State: {:?}\n", cb.state().await); + assert_eq!(cb.state().await, State::HalfOpen); + + // Scenario 4: HalfOpen -> Closed (successes) + info!("Scenario 4: Recording successes to close the circuit"); + for i in 1..=2 { + if cb.allow_request().await { + info!(" Request #{} allowed", i); + simulate_request(true).await; + cb.record_success().await; + info!(" Recorded success"); + } + } + + info!("State: {:?}\n", cb.state().await); + assert_eq!(cb.state().await, State::Closed); + + // Scenario 5: HalfOpen -> Open (failure) + info!("Scenario 5: HalfOpen failure reopens circuit immediately"); + cb.reset().await; + + // Trigger open + for _ in 0..3 { + cb.allow_request().await; + cb.record_failure().await; + } + + sleep(Duration::from_secs(2)).await; + cb.allow_request().await; // Transition to HalfOpen + + info!("State before failure: {:?}", cb.state().await); + cb.record_failure().await; + info!("State after failure: {:?}\n", cb.state().await); + assert_eq!(cb.state().await, State::Open); + + // Statistics + info!("Final Statistics:"); + let stats = cb.stats(); + info!(" Total requests: {}", stats.total_requests); + info!(" Total failures: {}", stats.total_failures); + info!( + " Failure rate: {:.1}%", + (stats.total_failures as f64 / stats.total_requests as f64) * 100.0 + ); + + info!("\nDemo complete!"); +} + +/// Simulates a request with configurable success/failure. +async fn simulate_request(success: bool) { + sleep(Duration::from_millis(10)).await; + if success { + info!(" [Simulated request succeeded]"); + } else { + warn!(" [Simulated request failed]"); + } +} diff --git a/src/circuit_breaker.rs b/src/circuit_breaker.rs new file mode 100644 index 0000000..4f6e3e9 --- /dev/null +++ b/src/circuit_breaker.rs @@ -0,0 +1,317 @@ +//! Circuit breaker implementation for fault tolerance. +//! +//! Implements a Hystrix-style circuit breaker with three states: +//! - **Closed**: Normal operation, requests flow through +//! - **Open**: Too many failures, reject requests immediately +//! - **HalfOpen**: Recovery mode, allow limited requests to test if service recovered + +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; + +/// Circuit breaker state +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum State { + /// Circuit is closed, requests flow normally + Closed, + /// Circuit is open, requests are rejected + Open, + /// Circuit is half-open, testing if service recovered + HalfOpen, +} + +/// Configuration for the circuit breaker. +#[derive(Debug, Clone)] +pub struct CircuitBreakerConfig { + /// Number of failures before opening the circuit + pub failure_threshold: u64, + /// Duration to wait before transitioning from Open to HalfOpen + pub timeout: Duration, + /// Number of successful requests in HalfOpen before closing + pub success_threshold: u64, +} + +impl Default for CircuitBreakerConfig { + fn default() -> Self { + Self { + failure_threshold: 5, + timeout: Duration::from_secs(30), + success_threshold: 2, + } + } +} + +/// Circuit breaker for preventing cascading failures. +/// +/// # Example +/// +/// ``` +/// use rust_servicemesh::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; +/// +/// #[tokio::main] +/// async fn main() { +/// let config = CircuitBreakerConfig::default(); +/// let cb = CircuitBreaker::new(config); +/// +/// if cb.allow_request().await { +/// // Make request +/// match make_request().await { +/// Ok(_) => cb.record_success().await, +/// Err(_) => cb.record_failure().await, +/// } +/// } +/// } +/// +/// async fn make_request() -> Result<(), ()> { +/// Ok(()) +/// } +/// ``` +#[derive(Debug)] +pub struct CircuitBreaker { + state: Arc>, + failure_count: Arc, + success_count: Arc, + last_failure_time: Arc>>, + config: CircuitBreakerConfig, + total_requests: Arc, + total_failures: Arc, +} + +impl CircuitBreaker { + /// Creates a new circuit breaker with the given configuration. + pub fn new(config: CircuitBreakerConfig) -> Self { + Self { + state: Arc::new(RwLock::new(State::Closed)), + failure_count: Arc::new(AtomicU64::new(0)), + success_count: Arc::new(AtomicU64::new(0)), + last_failure_time: Arc::new(RwLock::new(None)), + config, + total_requests: Arc::new(AtomicUsize::new(0)), + total_failures: Arc::new(AtomicUsize::new(0)), + } + } + + /// Checks if a request should be allowed through. + /// + /// Returns `true` if the request should proceed, `false` if it should be rejected. + pub async fn allow_request(&self) -> bool { + self.total_requests.fetch_add(1, Ordering::Relaxed); + + let state = *self.state.read().await; + + match state { + State::Closed => true, + State::Open => { + // Check if timeout has elapsed + let last_failure = self.last_failure_time.read().await; + if let Some(last_time) = *last_failure { + if last_time.elapsed() >= self.config.timeout { + drop(last_failure); + // Transition to HalfOpen + *self.state.write().await = State::HalfOpen; + self.success_count.store(0, Ordering::Relaxed); + true + } else { + false + } + } else { + false + } + } + State::HalfOpen => true, + } + } + + /// Records a successful request. + pub async fn record_success(&self) { + let state = *self.state.read().await; + + match state { + State::HalfOpen => { + let successes = self.success_count.fetch_add(1, Ordering::Relaxed) + 1; + if successes >= self.config.success_threshold { + // Transition to Closed + *self.state.write().await = State::Closed; + self.failure_count.store(0, Ordering::Relaxed); + self.success_count.store(0, Ordering::Relaxed); + } + } + State::Closed => { + // Reset failure count on success + self.failure_count.store(0, Ordering::Relaxed); + } + State::Open => {} + } + } + + /// Records a failed request. + pub async fn record_failure(&self) { + self.total_failures.fetch_add(1, Ordering::Relaxed); + + let state = *self.state.read().await; + + match state { + State::Closed => { + let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1; + if failures >= self.config.failure_threshold { + // Transition to Open + *self.state.write().await = State::Open; + *self.last_failure_time.write().await = Some(Instant::now()); + } + } + State::HalfOpen => { + // Immediately reopen on failure + *self.state.write().await = State::Open; + *self.last_failure_time.write().await = Some(Instant::now()); + self.failure_count.store(0, Ordering::Relaxed); + self.success_count.store(0, Ordering::Relaxed); + } + State::Open => { + *self.last_failure_time.write().await = Some(Instant::now()); + } + } + } + + /// Returns the current state of the circuit breaker. + pub async fn state(&self) -> State { + *self.state.read().await + } + + /// Returns statistics about the circuit breaker. + pub fn stats(&self) -> CircuitBreakerStats { + CircuitBreakerStats { + total_requests: self.total_requests.load(Ordering::Relaxed), + total_failures: self.total_failures.load(Ordering::Relaxed), + current_failure_count: self.failure_count.load(Ordering::Relaxed), + current_success_count: self.success_count.load(Ordering::Relaxed), + } + } + + /// Resets the circuit breaker to the closed state. + #[allow(dead_code)] + pub async fn reset(&self) { + *self.state.write().await = State::Closed; + self.failure_count.store(0, Ordering::Relaxed); + self.success_count.store(0, Ordering::Relaxed); + *self.last_failure_time.write().await = None; + } +} + +/// Statistics for the circuit breaker. +#[derive(Debug, Clone)] +pub struct CircuitBreakerStats { + pub total_requests: usize, + pub total_failures: usize, + pub current_failure_count: u64, + pub current_success_count: u64, +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::time::sleep; + + #[tokio::test] + async fn test_circuit_breaker_closed_to_open() { + let config = CircuitBreakerConfig { + failure_threshold: 3, + timeout: Duration::from_millis(100), + success_threshold: 2, + }; + let cb = CircuitBreaker::new(config); + + assert_eq!(cb.state().await, State::Closed); + assert!(cb.allow_request().await); + + // Record failures + cb.record_failure().await; + cb.record_failure().await; + cb.record_failure().await; + + assert_eq!(cb.state().await, State::Open); + assert!(!cb.allow_request().await); + } + + #[tokio::test] + async fn test_circuit_breaker_open_to_halfopen() { + let config = CircuitBreakerConfig { + failure_threshold: 2, + timeout: Duration::from_millis(50), + success_threshold: 2, + }; + let cb = CircuitBreaker::new(config); + + // Trigger open state + cb.record_failure().await; + cb.record_failure().await; + assert_eq!(cb.state().await, State::Open); + + // Wait for timeout + sleep(Duration::from_millis(60)).await; + + // Should transition to HalfOpen + assert!(cb.allow_request().await); + assert_eq!(cb.state().await, State::HalfOpen); + } + + #[tokio::test] + async fn test_circuit_breaker_halfopen_to_closed() { + let config = CircuitBreakerConfig { + failure_threshold: 2, + timeout: Duration::from_millis(50), + success_threshold: 2, + }; + let cb = CircuitBreaker::new(config); + + // Trigger open state + cb.record_failure().await; + cb.record_failure().await; + + // Wait for timeout and transition to HalfOpen + sleep(Duration::from_millis(60)).await; + assert!(cb.allow_request().await); + + // Record successes + cb.record_success().await; + cb.record_success().await; + + assert_eq!(cb.state().await, State::Closed); + } + + #[tokio::test] + async fn test_circuit_breaker_halfopen_to_open() { + let config = CircuitBreakerConfig { + failure_threshold: 2, + timeout: Duration::from_millis(50), + success_threshold: 2, + }; + let cb = CircuitBreaker::new(config); + + // Trigger open state + cb.record_failure().await; + cb.record_failure().await; + + // Wait for timeout and transition to HalfOpen + sleep(Duration::from_millis(60)).await; + assert!(cb.allow_request().await); + + // Record failure in HalfOpen - should reopen + cb.record_failure().await; + assert_eq!(cb.state().await, State::Open); + } + + #[tokio::test] + async fn test_circuit_breaker_stats() { + let config = CircuitBreakerConfig::default(); + let cb = CircuitBreaker::new(config); + + cb.allow_request().await; + cb.allow_request().await; + cb.record_failure().await; + + let stats = cb.stats(); + assert_eq!(stats.total_requests, 2); + assert_eq!(stats.total_failures, 1); + } +} diff --git a/src/error.rs b/src/error.rs index bc6bcb4..88501d9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,12 +12,10 @@ pub enum ProxyError { /// Failed to accept an incoming connection. #[error("failed to accept connection: {0}")] - #[allow(dead_code)] AcceptConnection(#[source] io::Error), /// Failed to connect to upstream server. #[error("failed to connect to upstream {addr}: {source}")] - #[allow(dead_code)] UpstreamConnect { addr: String, source: io::Error }, /// HTTP protocol error. @@ -39,6 +37,46 @@ pub enum ProxyError { /// Service unavailable. #[error("service unavailable: {0}")] ServiceUnavailable(String), + + /// TLS configuration error. + #[error("TLS configuration error: {message}")] + TlsConfig { message: String }, + + /// TLS handshake error. + #[error("TLS handshake failed: {0}")] + TlsHandshake(String), + + /// Protocol negotiation error. + #[error("protocol negotiation failed: {0}")] + ProtocolNegotiation(String), + + /// Rate limit exceeded. + #[error("rate limit exceeded")] + RateLimitExceeded, + + /// Circuit breaker is open. + #[error("circuit breaker is open for upstream: {upstream}")] + CircuitBreakerOpen { upstream: String }, + + /// Request timeout. + #[error("request timed out after {duration_ms}ms")] + Timeout { duration_ms: u64 }, + + /// Retry exhausted. + #[error("all {attempts} retry attempts exhausted")] + RetryExhausted { attempts: u32 }, + + /// Invalid configuration. + #[error("invalid configuration: {0}")] + InvalidConfig(String), + + /// Route not found. + #[error("no route found for path: {path}")] + RouteNotFound { path: String }, + + /// gRPC error. + #[error("gRPC error: {message} (code: {code})")] + Grpc { code: i32, message: String }, } /// Result type alias for proxy operations. diff --git a/src/lib.rs b/src/lib.rs index eae574b..6516412 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,89 @@ //! Rust Service Mesh - High-performance data plane proxy //! -//! A service mesh proxy built with Rust, inspired by Envoy, providing -//! HTTP/1.1 and HTTP/2 proxying, load balancing, and observability. +//! A service mesh proxy built with Rust, providing HTTP/1.1 and HTTP/2 proxying, +//! load balancing, circuit breaking, rate limiting, and observability. +//! +//! # Features +//! +//! - **HTTP/1.1 and HTTP/2 Support**: Full protocol support with ALPN negotiation +//! - **TLS Termination**: Secure connections with Rustls +//! - **Load Balancing**: Round-robin, least connections, random, and weighted strategies +//! - **Circuit Breaker**: Fault tolerance with configurable thresholds +//! - **Rate Limiting**: Token bucket algorithm with per-client and global limits +//! - **L7 Routing**: Path, header, and method-based routing rules +//! - **Retry Logic**: Exponential backoff with configurable retry policies +//! - **Metrics**: Prometheus-compatible metrics export +//! +//! # Quick Start +//! +//! ```no_run +//! use rust_servicemesh::listener::Listener; +//! use std::sync::Arc; +//! use std::time::Duration; +//! use tokio::sync::broadcast; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! // Configure upstream servers +//! let upstream = Arc::new(vec!["http://localhost:8080".to_string()]); +//! let timeout = Duration::from_secs(30); +//! +//! // Create and start the proxy +//! let listener = Listener::bind("127.0.0.1:3000", upstream, timeout).await?; +//! +//! let (shutdown_tx, shutdown_rx) = broadcast::channel(1); +//! listener.serve(shutdown_rx).await?; +//! +//! Ok(()) +//! } +//! ``` +//! +//! # Architecture +//! +//! The proxy is built using a modular architecture: +//! +//! - `listener`: TCP/TLS listener with protocol negotiation +//! - `service`: Tower service for request handling +//! - `router`: L7 routing with path/header matching +//! - `transport`: Connection pooling and load balancing +//! - `circuit_breaker`: Fault tolerance +//! - `ratelimit`: Request rate limiting +//! - `retry`: Retry logic with backoff +//! - `protocol`: HTTP/2 and TLS support +//! - `metrics`: Prometheus metrics +//! - `config`: Configuration management +//! +//! # Configuration +//! +//! The proxy can be configured via environment variables: +//! +//! - `PROXY_LISTEN_ADDR`: Address to listen on (default: "127.0.0.1:3000") +//! - `PROXY_UPSTREAM_ADDRS`: Comma-separated upstream addresses +//! - `PROXY_METRICS_ADDR`: Metrics endpoint address (default: "127.0.0.1:9090") +//! - `PROXY_REQUEST_TIMEOUT_MS`: Request timeout in milliseconds (default: 30000) pub mod admin; pub mod admin_listener; +pub mod circuit_breaker; pub mod config; pub mod error; pub mod listener; pub mod metrics; +pub mod protocol; +pub mod ratelimit; +pub mod retry; pub mod router; pub mod service; pub mod transport; + +// Re-export commonly used types +pub use circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, State as CircuitBreakerState}; +pub use config::ProxyConfig; +pub use error::{ProxyError, Result}; +pub use listener::Listener; +pub use protocol::{HttpProtocol, TlsConfig}; +pub use ratelimit::{RateLimitConfig, RateLimiter}; +pub use retry::{RetryConfig, RetryExecutor, RetryPolicy}; +pub use router::{PathMatch, Route, Router}; +pub use service::ProxyService; +pub use transport::{Endpoint, LoadBalancer, PoolConfig, Transport}; diff --git a/src/listener.rs b/src/listener.rs index 778a1f8..8083c54 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -1,35 +1,47 @@ -//! TCP listener with graceful shutdown support. +//! TCP listener with HTTP/1.1 and HTTP/2 support. +//! +//! This module provides a multi-protocol listener that can handle both HTTP/1.1 +//! and HTTP/2 connections, with optional TLS support and ALPN-based protocol +//! negotiation. use crate::error::{ProxyError, Result}; +use crate::protocol::{HttpProtocol, TlsConfig}; use crate::service::ProxyService; use hyper::body::Incoming; -use hyper::server::conn::http1; +use hyper::server::conn::{http1, http2}; use hyper::service::service_fn; use hyper::Request; -use hyper_util::rt::TokioIo; +use hyper_util::rt::{TokioExecutor, TokioIo}; use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; use tokio::sync::broadcast; +use tokio_rustls::TlsAcceptor; use tower::Service; -use tracing::{error, info, instrument, warn}; +use tracing::{debug, error, info, instrument, warn}; /// HTTP listener that accepts connections and spawns handler tasks. /// -/// Supports graceful shutdown via a broadcast channel. +/// Supports HTTP/1.1 and HTTP/2 with automatic protocol negotiation via ALPN +/// when TLS is enabled. Without TLS, falls back to HTTP/1.1 or uses prior +/// knowledge for HTTP/2. /// /// # Example /// /// ```no_run /// use rust_servicemesh::listener::Listener; /// use std::sync::Arc; +/// use std::time::Duration; /// use tokio::sync::broadcast; /// /// #[tokio::main] /// async fn main() -> Result<(), Box> { /// let (shutdown_tx, _) = broadcast::channel(1); /// let upstream = vec!["http://127.0.0.1:8080".to_string()]; -/// let listener = Listener::bind("127.0.0.1:3000", Arc::new(upstream)).await?; +/// let timeout = Duration::from_secs(30); +/// let listener = Listener::bind("127.0.0.1:3000", Arc::new(upstream), timeout).await?; /// listener.serve(shutdown_tx.subscribe()).await?; /// Ok(()) /// } @@ -38,21 +50,28 @@ pub struct Listener { tcp_listener: TcpListener, proxy_service: ProxyService, addr: SocketAddr, + tls_acceptor: Option, + default_protocol: HttpProtocol, } impl Listener { - /// Binds to the specified address and creates a listener. + /// Binds to the specified address and creates a listener (HTTP only). /// /// # Arguments /// /// * `addr` - Address to bind to (e.g., "127.0.0.1:3000") /// * `upstream_addrs` - List of upstream server addresses + /// * `request_timeout` - Maximum duration for upstream requests /// /// # Errors /// /// Returns `ProxyError::ListenerBind` if binding fails. #[instrument(level = "info", skip(upstream_addrs))] - pub async fn bind(addr: &str, upstream_addrs: Arc>) -> Result { + pub async fn bind( + addr: &str, + upstream_addrs: Arc>, + request_timeout: Duration, + ) -> Result { let tcp_listener = TcpListener::bind(addr) .await .map_err(|e| ProxyError::ListenerBind { @@ -67,12 +86,97 @@ impl Listener { source: e, })?; - info!("bound to {}", local_addr); + info!("bound to {} (HTTP/1.1)", local_addr); Ok(Self { tcp_listener, - proxy_service: ProxyService::new(upstream_addrs), + proxy_service: ProxyService::new(upstream_addrs, request_timeout), addr: local_addr, + tls_acceptor: None, + default_protocol: HttpProtocol::Http1, + }) + } + + /// Binds to the specified address with TLS and HTTP/2 support. + /// + /// # Arguments + /// + /// * `addr` - Address to bind to (e.g., "127.0.0.1:3000") + /// * `upstream_addrs` - List of upstream server addresses + /// * `request_timeout` - Maximum duration for upstream requests + /// * `tls_config` - TLS configuration with certificate and key paths + /// + /// # Errors + /// + /// Returns `ProxyError::ListenerBind` if binding fails or + /// `ProxyError::TlsConfig` if TLS configuration is invalid. + #[instrument(level = "info", skip(upstream_addrs, tls_config))] + pub async fn bind_with_tls( + addr: &str, + upstream_addrs: Arc>, + request_timeout: Duration, + tls_config: TlsConfig, + ) -> Result { + let tcp_listener = TcpListener::bind(addr) + .await + .map_err(|e| ProxyError::ListenerBind { + addr: addr.to_string(), + source: e, + })?; + + let local_addr = tcp_listener + .local_addr() + .map_err(|e| ProxyError::ListenerBind { + addr: addr.to_string(), + source: e, + })?; + + let tls_acceptor = tls_config.build_acceptor()?; + let protocol = tls_config.protocol; + + info!("bound to {} (TLS with {:?} support)", local_addr, protocol); + + Ok(Self { + tcp_listener, + proxy_service: ProxyService::new(upstream_addrs, request_timeout), + addr: local_addr, + tls_acceptor: Some(tls_acceptor), + default_protocol: protocol, + }) + } + + /// Binds with HTTP/2 prior knowledge (h2c - HTTP/2 over cleartext). + /// + /// This enables HTTP/2 without TLS, using prior knowledge that the + /// client will speak HTTP/2. + #[instrument(level = "info", skip(upstream_addrs))] + pub async fn bind_h2c( + addr: &str, + upstream_addrs: Arc>, + request_timeout: Duration, + ) -> Result { + let tcp_listener = TcpListener::bind(addr) + .await + .map_err(|e| ProxyError::ListenerBind { + addr: addr.to_string(), + source: e, + })?; + + let local_addr = tcp_listener + .local_addr() + .map_err(|e| ProxyError::ListenerBind { + addr: addr.to_string(), + source: e, + })?; + + info!("bound to {} (h2c - HTTP/2 cleartext)", local_addr); + + Ok(Self { + tcp_listener, + proxy_service: ProxyService::new(upstream_addrs, request_timeout), + addr: local_addr, + tls_acceptor: None, + default_protocol: HttpProtocol::Http2, }) } @@ -81,6 +185,16 @@ impl Listener { self.addr } + /// Returns whether TLS is enabled. + pub fn is_tls_enabled(&self) -> bool { + self.tls_acceptor.is_some() + } + + /// Returns the default HTTP protocol. + pub fn default_protocol(&self) -> HttpProtocol { + self.default_protocol + } + /// Serves incoming connections until a shutdown signal is received. /// /// Spawns a new task for each connection. Gracefully shuts down when @@ -93,15 +207,33 @@ impl Listener { pub async fn serve(self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<()> { info!("serving connections"); + let tls_acceptor = self.tls_acceptor.clone(); + let default_protocol = self.default_protocol; + loop { tokio::select! { accept_result = self.tcp_listener.accept() => { match accept_result { Ok((stream, peer_addr)) => { - info!("accepted connection from {}", peer_addr); + debug!("accepted connection from {}", peer_addr); let service = self.proxy_service.clone(); + let tls_acceptor = tls_acceptor.clone(); + tokio::spawn(async move { - if let Err(e) = Self::handle_connection(stream, service).await { + let result = if let Some(acceptor) = tls_acceptor { + Self::handle_tls_connection(stream, service, acceptor).await + } else { + match default_protocol { + HttpProtocol::Http2 => { + Self::handle_h2c_connection(stream, service).await + } + _ => { + Self::handle_http1_connection(stream, service).await + } + } + }; + + if let Err(e) = result { error!("connection error from {}: {}", peer_addr, e); } }); @@ -121,14 +253,58 @@ impl Listener { Ok(()) } - /// Handles a single TCP connection using HTTP/1.1. - #[instrument(level = "debug", skip(stream, service))] - async fn handle_connection(stream: tokio::net::TcpStream, service: ProxyService) -> Result<()> { - let io = TokioIo::new(stream); + /// Handles a TLS connection with ALPN-based protocol negotiation. + #[instrument(level = "debug", skip_all)] + async fn handle_tls_connection( + stream: tokio::net::TcpStream, + service: ProxyService, + acceptor: TlsAcceptor, + ) -> Result<()> { + let tls_stream = acceptor + .accept(stream) + .await + .map_err(|e| ProxyError::TlsHandshake(e.to_string()))?; + + // Determine protocol from ALPN negotiation + let protocol = { + let (_, server_conn) = tls_stream.get_ref(); + HttpProtocol::from_alpn(server_conn.alpn_protocol()) + }; + + debug!("negotiated protocol: {:?}", protocol); + match protocol { + HttpProtocol::Http2 => Self::serve_http2(TokioIo::new(tls_stream), service).await, + _ => Self::serve_http1(TokioIo::new(tls_stream), service).await, + } + } + + /// Handles a plain HTTP/1.1 connection. + #[instrument(level = "debug", skip_all)] + async fn handle_http1_connection( + stream: tokio::net::TcpStream, + service: ProxyService, + ) -> Result<()> { + Self::serve_http1(TokioIo::new(stream), service).await + } + + /// Handles an h2c (HTTP/2 cleartext) connection. + #[instrument(level = "debug", skip_all)] + async fn handle_h2c_connection( + stream: tokio::net::TcpStream, + service: ProxyService, + ) -> Result<()> { + Self::serve_http2(TokioIo::new(stream), service).await + } + + /// Serves HTTP/1.1 on the given I/O stream. + async fn serve_http1(io: TokioIo, service: ProxyService) -> Result<()> + where + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { let service = service_fn(move |req: Request| { - let mut service = service.clone(); - async move { service.call(req).await } + let mut svc = service.clone(); + async move { svc.call(req).await } }); http1::Builder::new() @@ -136,6 +312,22 @@ impl Listener { .await .map_err(ProxyError::Http) } + + /// Serves HTTP/2 on the given I/O stream. + async fn serve_http2(io: TokioIo, service: ProxyService) -> Result<()> + where + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let service = service_fn(move |req: Request| { + let mut svc = service.clone(); + async move { svc.call(req).await } + }); + + http2::Builder::new(TokioExecutor::new()) + .serve_connection(io, service) + .await + .map_err(ProxyError::Http) + } } #[cfg(test)] @@ -145,14 +337,30 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_listener_bind() { let upstream = Arc::new(vec!["http://127.0.0.1:9999".to_string()]); - let listener = Listener::bind("127.0.0.1:0", upstream).await; + let timeout = Duration::from_secs(30); + let listener = Listener::bind("127.0.0.1:0", upstream, timeout).await; assert!(listener.is_ok()); + let listener = listener.unwrap(); + assert!(!listener.is_tls_enabled()); + assert_eq!(listener.default_protocol(), HttpProtocol::Http1); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_listener_bind_invalid_address() { let upstream = Arc::new(vec!["http://127.0.0.1:9999".to_string()]); - let listener = Listener::bind("999.999.999.999:0", upstream).await; + let timeout = Duration::from_secs(30); + let listener = Listener::bind("999.999.999.999:0", upstream, timeout).await; assert!(listener.is_err()); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_listener_h2c() { + let upstream = Arc::new(vec!["http://127.0.0.1:9999".to_string()]); + let timeout = Duration::from_secs(30); + let listener = Listener::bind_h2c("127.0.0.1:0", upstream, timeout).await; + assert!(listener.is_ok()); + let listener = listener.unwrap(); + assert!(!listener.is_tls_enabled()); + assert_eq!(listener.default_protocol(), HttpProtocol::Http2); + } } diff --git a/src/main.rs b/src/main.rs index 2a2d1a8..4169337 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,26 @@ +#[allow(dead_code)] mod admin; mod admin_listener; +#[allow(dead_code)] +mod circuit_breaker; mod config; +#[allow(dead_code)] mod error; +#[allow(dead_code)] mod listener; +#[allow(dead_code)] mod metrics; +#[allow(dead_code)] +mod protocol; +#[allow(dead_code)] +mod ratelimit; +#[allow(dead_code)] +mod retry; +#[allow(dead_code)] mod router; +#[allow(dead_code)] mod service; +#[allow(dead_code)] mod transport; use admin_listener::AdminListener; @@ -43,7 +58,12 @@ async fn run() -> Result<(), Box> { let (shutdown_tx, _shutdown_rx) = broadcast::channel(1); - let proxy_listener = Listener::bind(&config.listen_addr, config.upstream_addrs_arc()).await?; + let proxy_listener = Listener::bind( + &config.listen_addr, + config.upstream_addrs_arc(), + config.request_timeout, + ) + .await?; let proxy_addr = proxy_listener.local_addr(); info!("proxy listening on {}", proxy_addr); diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..7b13d9a --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,248 @@ +//! Protocol negotiation and HTTP/2 support. +//! +//! This module provides ALPN-based protocol negotiation for HTTP/1.1 and HTTP/2 +//! connections, with optional TLS support via Rustls. + +use crate::error::{ProxyError, Result}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::ServerConfig; +use std::fs::File; +use std::io::BufReader; +use std::path::Path; +use std::sync::Arc; +use tokio_rustls::TlsAcceptor; + +/// Supported HTTP protocols for the proxy. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum HttpProtocol { + /// HTTP/1.1 protocol + Http1, + /// HTTP/2 protocol + #[default] + Http2, + /// Auto-negotiate based on ALPN (prefers HTTP/2) + Auto, +} + +impl HttpProtocol { + /// Returns the ALPN protocol identifiers for this protocol. + pub fn alpn_protocols(&self) -> Vec> { + match self { + HttpProtocol::Http1 => vec![b"http/1.1".to_vec()], + HttpProtocol::Http2 => vec![b"h2".to_vec()], + HttpProtocol::Auto => vec![b"h2".to_vec(), b"http/1.1".to_vec()], + } + } + + /// Determines the protocol from ALPN negotiation result. + pub fn from_alpn(alpn: Option<&[u8]>) -> Self { + match alpn { + Some(b"h2") => HttpProtocol::Http2, + Some(b"http/1.1") => HttpProtocol::Http1, + _ => HttpProtocol::Http1, // Default to HTTP/1.1 if no ALPN + } + } +} + +/// TLS configuration for the proxy. +#[derive(Debug, Clone)] +pub struct TlsConfig { + /// Path to the certificate file (PEM format) + pub cert_path: String, + /// Path to the private key file (PEM format) + pub key_path: String, + /// Preferred HTTP protocol for negotiation + pub protocol: HttpProtocol, +} + +impl TlsConfig { + /// Creates a new TLS configuration. + pub fn new(cert_path: impl Into, key_path: impl Into) -> Self { + Self { + cert_path: cert_path.into(), + key_path: key_path.into(), + protocol: HttpProtocol::Auto, + } + } + + /// Sets the preferred HTTP protocol. + pub fn with_protocol(mut self, protocol: HttpProtocol) -> Self { + self.protocol = protocol; + self + } + + /// Loads certificates from a PEM file. + fn load_certs(path: &Path) -> Result>> { + let file = File::open(path).map_err(|e| ProxyError::TlsConfig { + message: format!("failed to open cert file: {}", e), + })?; + let mut reader = BufReader::new(file); + + let certs: Vec> = rustls_pemfile::certs(&mut reader) + .filter_map(|cert| cert.ok()) + .collect(); + + if certs.is_empty() { + return Err(ProxyError::TlsConfig { + message: "no certificates found in file".to_string(), + }); + } + + Ok(certs) + } + + /// Loads a private key from a PEM file. + fn load_private_key(path: &Path) -> Result> { + let file = File::open(path).map_err(|e| ProxyError::TlsConfig { + message: format!("failed to open key file: {}", e), + })?; + let mut reader = BufReader::new(file); + + // Try to read PKCS#8 keys first, then RSA keys + let keys: Vec> = rustls_pemfile::read_all(&mut reader) + .filter_map(|item| match item.ok()? { + rustls_pemfile::Item::Pkcs1Key(key) => Some(PrivateKeyDer::Pkcs1(key)), + rustls_pemfile::Item::Pkcs8Key(key) => Some(PrivateKeyDer::Pkcs8(key)), + rustls_pemfile::Item::Sec1Key(key) => Some(PrivateKeyDer::Sec1(key)), + _ => None, + }) + .collect(); + + keys.into_iter() + .next() + .ok_or_else(|| ProxyError::TlsConfig { + message: "no private key found in file".to_string(), + }) + } + + /// Builds a TLS acceptor from this configuration. + pub fn build_acceptor(&self) -> Result { + let certs = Self::load_certs(Path::new(&self.cert_path))?; + let key = Self::load_private_key(Path::new(&self.key_path))?; + + let mut config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| ProxyError::TlsConfig { + message: format!("failed to configure TLS: {}", e), + })?; + + // Configure ALPN protocols + config.alpn_protocols = self.protocol.alpn_protocols(); + + Ok(TlsAcceptor::from(Arc::new(config))) + } +} + +/// Client TLS configuration for connecting to upstream servers. +#[derive(Debug, Clone)] +pub struct ClientTlsConfig { + /// Whether to verify server certificates + pub verify_server: bool, + /// Optional client certificate for mTLS + pub client_cert: Option, + /// Optional client key for mTLS + pub client_key: Option, +} + +impl Default for ClientTlsConfig { + fn default() -> Self { + Self { + verify_server: true, + client_cert: None, + client_key: None, + } + } +} + +impl ClientTlsConfig { + /// Creates a new client TLS configuration. + pub fn new() -> Self { + Self::default() + } + + /// Disables server certificate verification (not recommended for production). + pub fn danger_accept_invalid_certs(mut self) -> Self { + self.verify_server = false; + self + } + + /// Sets client certificate for mTLS. + pub fn with_client_cert(mut self, cert_path: String, key_path: String) -> Self { + self.client_cert = Some(cert_path); + self.client_key = Some(key_path); + self + } + + /// Builds a Rustls client configuration. + pub fn build_client_config(&self) -> Result { + let root_store = rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), + }; + + let builder = rustls::ClientConfig::builder().with_root_certificates(root_store); + + let config = + if let (Some(cert_path), Some(key_path)) = (&self.client_cert, &self.client_key) { + let certs = TlsConfig::load_certs(Path::new(cert_path))?; + let key = TlsConfig::load_private_key(Path::new(key_path))?; + builder + .with_client_auth_cert(certs, key) + .map_err(|e| ProxyError::TlsConfig { + message: format!("failed to configure client auth: {}", e), + })? + } else { + builder.with_no_client_auth() + }; + + Ok(config) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_http_protocol_alpn() { + assert_eq!( + HttpProtocol::Http1.alpn_protocols(), + vec![b"http/1.1".to_vec()] + ); + assert_eq!(HttpProtocol::Http2.alpn_protocols(), vec![b"h2".to_vec()]); + assert_eq!( + HttpProtocol::Auto.alpn_protocols(), + vec![b"h2".to_vec(), b"http/1.1".to_vec()] + ); + } + + #[test] + fn test_protocol_from_alpn() { + assert_eq!(HttpProtocol::from_alpn(Some(b"h2")), HttpProtocol::Http2); + assert_eq!( + HttpProtocol::from_alpn(Some(b"http/1.1")), + HttpProtocol::Http1 + ); + assert_eq!(HttpProtocol::from_alpn(None), HttpProtocol::Http1); + } + + #[test] + fn test_tls_config_builder() { + let config = TlsConfig::new("cert.pem", "key.pem").with_protocol(HttpProtocol::Http2); + + assert_eq!(config.cert_path, "cert.pem"); + assert_eq!(config.key_path, "key.pem"); + assert_eq!(config.protocol, HttpProtocol::Http2); + } + + #[test] + fn test_client_tls_config() { + let config = ClientTlsConfig::new() + .danger_accept_invalid_certs() + .with_client_cert("client.pem".to_string(), "client-key.pem".to_string()); + + assert!(!config.verify_server); + assert_eq!(config.client_cert, Some("client.pem".to_string())); + assert_eq!(config.client_key, Some("client-key.pem".to_string())); + } +} diff --git a/src/ratelimit.rs b/src/ratelimit.rs new file mode 100644 index 0000000..9c28e17 --- /dev/null +++ b/src/ratelimit.rs @@ -0,0 +1,404 @@ +//! Rate limiting middleware using token bucket algorithm. +//! +//! Provides configurable rate limiting for incoming requests with support +//! for multiple strategies: per-client, global, and per-route limiting. + +use dashmap::DashMap; +use parking_lot::Mutex; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tracing::debug; + +/// Configuration for rate limiting. +#[derive(Debug, Clone)] +pub struct RateLimitConfig { + /// Maximum number of requests allowed in the window. + pub requests_per_second: u64, + /// Burst capacity (allows temporary spikes above the rate). + pub burst_size: u64, + /// Whether to enable per-client rate limiting. + pub per_client: bool, + /// Time-to-live for client rate limit entries. + pub client_ttl: Duration, +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + requests_per_second: 100, + burst_size: 50, + per_client: true, + client_ttl: Duration::from_secs(300), + } + } +} + +impl RateLimitConfig { + /// Creates a new rate limit configuration. + pub fn new(requests_per_second: u64, burst_size: u64) -> Self { + Self { + requests_per_second, + burst_size, + ..Default::default() + } + } + + /// Enables or disables per-client rate limiting. + pub fn with_per_client(mut self, per_client: bool) -> Self { + self.per_client = per_client; + self + } + + /// Sets the TTL for client entries. + pub fn with_client_ttl(mut self, ttl: Duration) -> Self { + self.client_ttl = ttl; + self + } +} + +/// Token bucket for rate limiting. +#[derive(Debug)] +struct TokenBucket { + /// Current number of available tokens. + tokens: f64, + /// Maximum capacity of the bucket. + capacity: f64, + /// Rate at which tokens are added (per second). + refill_rate: f64, + /// Last time the bucket was updated. + last_update: Instant, +} + +impl TokenBucket { + /// Creates a new token bucket. + fn new(capacity: f64, refill_rate: f64) -> Self { + Self { + tokens: capacity, + capacity, + refill_rate, + last_update: Instant::now(), + } + } + + /// Refills tokens based on elapsed time. + fn refill(&mut self) { + let now = Instant::now(); + let elapsed = now.duration_since(self.last_update).as_secs_f64(); + self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity); + self.last_update = now; + } + + /// Attempts to consume a token. + /// + /// Returns `true` if a token was consumed, `false` if the bucket is empty. + fn try_consume(&mut self) -> bool { + self.refill(); + if self.tokens >= 1.0 { + self.tokens -= 1.0; + true + } else { + false + } + } + + /// Returns the estimated wait time until a token is available. + fn wait_time(&self) -> Duration { + if self.tokens >= 1.0 { + Duration::ZERO + } else { + let needed = 1.0 - self.tokens; + Duration::from_secs_f64(needed / self.refill_rate) + } + } + + /// Returns the current number of available tokens. + fn available_tokens(&self) -> f64 { + self.tokens + } +} + +/// Client rate limit entry with TTL tracking. +struct ClientEntry { + bucket: Mutex, + last_access: Mutex, +} + +impl ClientEntry { + fn new(config: &RateLimitConfig) -> Self { + Self { + bucket: Mutex::new(TokenBucket::new( + config.burst_size as f64, + config.requests_per_second as f64, + )), + last_access: Mutex::new(Instant::now()), + } + } + + fn try_acquire(&self) -> bool { + *self.last_access.lock() = Instant::now(); + self.bucket.lock().try_consume() + } + + fn is_expired(&self, ttl: Duration) -> bool { + self.last_access.lock().elapsed() > ttl + } +} + +/// Rate limiter with support for global and per-client limiting. +pub struct RateLimiter { + config: RateLimitConfig, + global_bucket: Mutex, + client_buckets: DashMap>, + last_cleanup: Mutex, +} + +impl RateLimiter { + /// Creates a new rate limiter with the given configuration. + pub fn new(config: RateLimitConfig) -> Self { + let global_bucket = + TokenBucket::new(config.burst_size as f64, config.requests_per_second as f64); + + Self { + config, + global_bucket: Mutex::new(global_bucket), + client_buckets: DashMap::new(), + last_cleanup: Mutex::new(Instant::now()), + } + } + + /// Creates a rate limiter with default configuration. + pub fn with_defaults() -> Self { + Self::new(RateLimitConfig::default()) + } + + /// Checks if a request should be allowed. + /// + /// Returns `Ok(())` if allowed, `Err(RateLimitInfo)` if rate limited. + pub fn check(&self, client_ip: Option) -> Result<(), RateLimitInfo> { + // First check global rate limit + if !self.global_bucket.lock().try_consume() { + let wait_time = self.global_bucket.lock().wait_time(); + debug!("global rate limit exceeded"); + return Err(RateLimitInfo { + limit_type: RateLimitType::Global, + retry_after: wait_time, + remaining: 0, + }); + } + + // Then check per-client rate limit if enabled + if self.config.per_client { + if let Some(ip) = client_ip { + self.maybe_cleanup(); + + let entry = self + .client_buckets + .entry(ip) + .or_insert_with(|| Arc::new(ClientEntry::new(&self.config))) + .clone(); + + if !entry.try_acquire() { + let wait_time = entry.bucket.lock().wait_time(); + debug!(client = %ip, "per-client rate limit exceeded"); + return Err(RateLimitInfo { + limit_type: RateLimitType::PerClient, + retry_after: wait_time, + remaining: 0, + }); + } + } + } + + Ok(()) + } + + /// Cleans up expired client entries periodically. + fn maybe_cleanup(&self) { + let mut last_cleanup = self.last_cleanup.lock(); + if last_cleanup.elapsed() < Duration::from_secs(60) { + return; + } + + *last_cleanup = Instant::now(); + drop(last_cleanup); + + let ttl = self.config.client_ttl; + let initial_count = self.client_buckets.len(); + + self.client_buckets + .retain(|_, entry| !entry.is_expired(ttl)); + + let removed = initial_count - self.client_buckets.len(); + if removed > 0 { + debug!(removed = removed, "cleaned up expired rate limit entries"); + } + } + + /// Returns the current statistics. + pub fn stats(&self) -> RateLimitStats { + RateLimitStats { + global_available: self.global_bucket.lock().available_tokens() as u64, + client_count: self.client_buckets.len(), + requests_per_second: self.config.requests_per_second, + burst_size: self.config.burst_size, + } + } +} + +/// Type of rate limit that was exceeded. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RateLimitType { + /// Global rate limit was exceeded. + Global, + /// Per-client rate limit was exceeded. + PerClient, +} + +/// Information about a rate limit rejection. +#[derive(Debug, Clone)] +pub struct RateLimitInfo { + /// Type of rate limit that was exceeded. + pub limit_type: RateLimitType, + /// Suggested time to wait before retrying. + pub retry_after: Duration, + /// Remaining requests in the current window. + pub remaining: u64, +} + +impl RateLimitInfo { + /// Returns the `Retry-After` header value in seconds. + pub fn retry_after_secs(&self) -> u64 { + self.retry_after.as_secs().max(1) + } +} + +/// Rate limiter statistics. +#[derive(Debug, Clone)] +pub struct RateLimitStats { + /// Available tokens in the global bucket. + pub global_available: u64, + /// Number of tracked clients. + pub client_count: usize, + /// Configured requests per second. + pub requests_per_second: u64, + /// Configured burst size. + pub burst_size: u64, +} + +/// Rate limiter middleware wrapper for Tower services. +pub struct RateLimitLayer { + limiter: Arc, +} + +impl RateLimitLayer { + /// Creates a new rate limit layer. + pub fn new(limiter: Arc) -> Self { + Self { limiter } + } + + /// Returns the underlying rate limiter. + pub fn limiter(&self) -> &Arc { + &self.limiter + } +} + +impl Clone for RateLimitLayer { + fn clone(&self) -> Self { + Self { + limiter: Arc::clone(&self.limiter), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::Ipv4Addr; + + #[test] + fn test_token_bucket_basic() { + let mut bucket = TokenBucket::new(10.0, 10.0); + + // Should be able to consume initial tokens + for _ in 0..10 { + assert!(bucket.try_consume()); + } + + // Should be empty now + assert!(!bucket.try_consume()); + } + + #[test] + fn test_token_bucket_refill() { + let mut bucket = TokenBucket::new(10.0, 1000.0); + + // Consume all tokens + for _ in 0..10 { + bucket.try_consume(); + } + + // Simulate time passing (by manually setting last_update) + bucket.last_update = Instant::now() - Duration::from_millis(100); + bucket.refill(); + + // Should have refilled ~100 tokens (capped at capacity) + assert!(bucket.available_tokens() >= 9.0); + } + + #[test] + fn test_rate_limiter_global() { + let config = RateLimitConfig::new(10, 5).with_per_client(false); + let limiter = RateLimiter::new(config); + + // Should allow burst_size requests + for _ in 0..5 { + assert!(limiter.check(None).is_ok()); + } + + // Should be rate limited after burst + assert!(limiter.check(None).is_err()); + } + + #[test] + fn test_rate_limiter_per_client() { + // Global: 100 req/s, burst 10 - Per-client: same (inherited) + let config = RateLimitConfig::new(100, 10).with_per_client(true); + let limiter = RateLimiter::new(config); + + let client1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)); + let client2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2)); + + // Client 1 uses 5 of their 10 tokens + for _ in 0..5 { + assert!(limiter.check(Some(client1)).is_ok()); + } + + // Client 2 should have their own 10 token quota + // Note: global bucket also depletes, so client2 uses from both + assert!(limiter.check(Some(client2)).is_ok()); + assert!(limiter.check(Some(client2)).is_ok()); + } + + #[test] + fn test_rate_limit_info() { + let info = RateLimitInfo { + limit_type: RateLimitType::Global, + retry_after: Duration::from_millis(500), + remaining: 0, + }; + + assert_eq!(info.retry_after_secs(), 1); + } + + #[test] + fn test_rate_limiter_stats() { + let config = RateLimitConfig::new(100, 50); + let limiter = RateLimiter::new(config); + + let stats = limiter.stats(); + assert_eq!(stats.requests_per_second, 100); + assert_eq!(stats.burst_size, 50); + assert_eq!(stats.client_count, 0); + } +} diff --git a/src/retry.rs b/src/retry.rs new file mode 100644 index 0000000..e39fe92 --- /dev/null +++ b/src/retry.rs @@ -0,0 +1,396 @@ +//! Retry middleware with exponential backoff. +//! +//! Provides configurable retry logic for failed requests, with support for +//! exponential backoff, jitter, and customizable retry conditions. + +use rand::Rng; +use std::time::Duration; +use tracing::{debug, warn}; + +/// Configuration for retry behavior. +#[derive(Debug, Clone)] +pub struct RetryConfig { + /// Maximum number of retry attempts (excluding the initial request). + pub max_retries: u32, + /// Base delay between retries. + pub base_delay: Duration, + /// Maximum delay between retries. + pub max_delay: Duration, + /// Multiplier for exponential backoff. + pub backoff_multiplier: f64, + /// Whether to add jitter to delays. + pub use_jitter: bool, + /// HTTP status codes that should trigger a retry. + pub retryable_status_codes: Vec, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: 3, + base_delay: Duration::from_millis(100), + max_delay: Duration::from_secs(10), + backoff_multiplier: 2.0, + use_jitter: true, + retryable_status_codes: vec![502, 503, 504], + } + } +} + +impl RetryConfig { + /// Creates a new retry configuration with default values. + pub fn new() -> Self { + Self::default() + } + + /// Sets the maximum number of retries. + pub fn with_max_retries(mut self, max_retries: u32) -> Self { + self.max_retries = max_retries; + self + } + + /// Sets the base delay between retries. + pub fn with_base_delay(mut self, delay: Duration) -> Self { + self.base_delay = delay; + self + } + + /// Sets the maximum delay between retries. + pub fn with_max_delay(mut self, delay: Duration) -> Self { + self.max_delay = delay; + self + } + + /// Sets the backoff multiplier. + pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self { + self.backoff_multiplier = multiplier; + self + } + + /// Enables or disables jitter. + pub fn with_jitter(mut self, use_jitter: bool) -> Self { + self.use_jitter = use_jitter; + self + } + + /// Sets the HTTP status codes that should trigger a retry. + pub fn with_retryable_status_codes(mut self, codes: Vec) -> Self { + self.retryable_status_codes = codes; + self + } + + /// Checks if a status code should trigger a retry. + pub fn is_retryable_status(&self, status: u16) -> bool { + self.retryable_status_codes.contains(&status) + } +} + +/// Retry policy that determines when and how to retry. +#[derive(Debug, Clone)] +pub struct RetryPolicy { + config: RetryConfig, + attempt: u32, +} + +impl RetryPolicy { + /// Creates a new retry policy with the given configuration. + pub fn new(config: RetryConfig) -> Self { + Self { config, attempt: 0 } + } + + /// Returns the current attempt number (0-indexed). + pub fn attempt(&self) -> u32 { + self.attempt + } + + /// Returns the maximum number of retries. + pub fn max_retries(&self) -> u32 { + self.config.max_retries + } + + /// Checks if more retries are available. + pub fn has_remaining_retries(&self) -> bool { + self.attempt < self.config.max_retries + } + + /// Calculates the delay for the next retry attempt. + pub fn next_delay(&self) -> Duration { + let base_ms = self.config.base_delay.as_millis() as f64; + let multiplier = self.config.backoff_multiplier.powi(self.attempt as i32); + let delay_ms = base_ms * multiplier; + + let delay_ms = delay_ms.min(self.config.max_delay.as_millis() as f64); + + let delay_ms = if self.config.use_jitter { + // Add jitter: random value between 0.5x and 1.5x the delay + let jitter = rand::thread_rng().gen_range(0.5..1.5); + delay_ms * jitter + } else { + delay_ms + }; + + Duration::from_millis(delay_ms as u64) + } + + /// Records a retry attempt and returns the delay to wait. + /// + /// Returns `None` if no more retries are available. + pub fn record_retry(&mut self) -> Option { + if !self.has_remaining_retries() { + return None; + } + + let delay = self.next_delay(); + self.attempt += 1; + + debug!( + attempt = self.attempt, + max_retries = self.config.max_retries, + delay_ms = delay.as_millis(), + "scheduling retry" + ); + + Some(delay) + } + + /// Resets the retry policy for a new request. + pub fn reset(&mut self) { + self.attempt = 0; + } + + /// Checks if a response should be retried based on status code. + pub fn should_retry_status(&self, status: u16) -> bool { + self.has_remaining_retries() && self.config.is_retryable_status(status) + } + + /// Checks if an error should be retried. + /// + /// By default, connection errors and timeouts are retryable. + pub fn should_retry_error(&self, error: &str) -> bool { + if !self.has_remaining_retries() { + return false; + } + + let error_lower = error.to_lowercase(); + error_lower.contains("connection") + || error_lower.contains("timeout") + || error_lower.contains("reset") + || error_lower.contains("refused") + } +} + +/// Executes a request with retry logic. +pub struct RetryExecutor { + policy: RetryPolicy, +} + +impl RetryExecutor { + /// Creates a new retry executor with the given configuration. + pub fn new(config: RetryConfig) -> Self { + Self { + policy: RetryPolicy::new(config), + } + } + + /// Creates a retry executor with default configuration. + pub fn with_defaults() -> Self { + Self::new(RetryConfig::default()) + } + + /// Executes a request with retry logic. + /// + /// The `request_fn` closure is called for each attempt. If it returns + /// a retryable error or status code, the request is retried after a delay. + pub async fn execute(&mut self, mut request_fn: F) -> Result> + where + F: FnMut() -> Fut, + Fut: std::future::Future>, + E: std::fmt::Display, + { + loop { + match request_fn().await { + Ok(result) => return Ok(result), + Err(e) => { + let error_str = e.to_string(); + + if self.policy.should_retry_error(&error_str) { + if let Some(delay) = self.policy.record_retry() { + warn!( + attempt = self.policy.attempt(), + error = %error_str, + delay_ms = delay.as_millis(), + "retrying after error" + ); + tokio::time::sleep(delay).await; + continue; + } + } + + return Err(RetryError::Exhausted { + attempts: self.policy.attempt() + 1, + last_error: e, + }); + } + } + } + } + + /// Resets the executor for a new request. + pub fn reset(&mut self) { + self.policy.reset(); + } +} + +/// Error returned when retry attempts are exhausted. +#[derive(Debug)] +pub enum RetryError { + /// All retry attempts were exhausted. + Exhausted { + /// Total number of attempts made. + attempts: u32, + /// The last error encountered. + last_error: E, + }, +} + +impl std::fmt::Display for RetryError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RetryError::Exhausted { + attempts, + last_error, + } => { + write!( + f, + "all {} retry attempts exhausted, last error: {}", + attempts, last_error + ) + } + } + } +} + +impl std::error::Error for RetryError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_retry_config_default() { + let config = RetryConfig::default(); + assert_eq!(config.max_retries, 3); + assert_eq!(config.base_delay, Duration::from_millis(100)); + assert!(config.use_jitter); + } + + #[test] + fn test_retry_config_builder() { + let config = RetryConfig::new() + .with_max_retries(5) + .with_base_delay(Duration::from_millis(200)) + .with_jitter(false); + + assert_eq!(config.max_retries, 5); + assert_eq!(config.base_delay, Duration::from_millis(200)); + assert!(!config.use_jitter); + } + + #[test] + fn test_retry_policy_has_remaining() { + let config = RetryConfig::new().with_max_retries(2); + let mut policy = RetryPolicy::new(config); + + assert!(policy.has_remaining_retries()); + policy.record_retry(); + assert!(policy.has_remaining_retries()); + policy.record_retry(); + assert!(!policy.has_remaining_retries()); + } + + #[test] + fn test_retry_policy_delay_increases() { + let config = RetryConfig::new() + .with_base_delay(Duration::from_millis(100)) + .with_backoff_multiplier(2.0) + .with_jitter(false); + + let mut policy = RetryPolicy::new(config); + + let delay1 = policy.next_delay(); + policy.record_retry(); + let delay2 = policy.next_delay(); + policy.record_retry(); + let delay3 = policy.next_delay(); + + assert_eq!(delay1, Duration::from_millis(100)); + assert_eq!(delay2, Duration::from_millis(200)); + assert_eq!(delay3, Duration::from_millis(400)); + } + + #[test] + fn test_retry_policy_max_delay() { + let config = RetryConfig::new() + .with_base_delay(Duration::from_secs(1)) + .with_max_delay(Duration::from_secs(5)) + .with_backoff_multiplier(10.0) + .with_jitter(false); + + let mut policy = RetryPolicy::new(config); + policy.record_retry(); + policy.record_retry(); + + let delay = policy.next_delay(); + assert_eq!(delay, Duration::from_secs(5)); + } + + #[test] + fn test_retry_policy_retryable_status() { + let config = RetryConfig::new().with_retryable_status_codes(vec![502, 503]); + let policy = RetryPolicy::new(config); + + assert!(policy.should_retry_status(502)); + assert!(policy.should_retry_status(503)); + assert!(!policy.should_retry_status(500)); + assert!(!policy.should_retry_status(200)); + } + + #[test] + fn test_retry_policy_retryable_error() { + let config = RetryConfig::new(); + let policy = RetryPolicy::new(config); + + assert!(policy.should_retry_error("connection refused")); + assert!(policy.should_retry_error("Connection reset by peer")); + assert!(policy.should_retry_error("request timeout")); + assert!(!policy.should_retry_error("invalid request")); + } + + #[tokio::test] + async fn test_retry_executor_success() { + let mut executor = RetryExecutor::with_defaults(); + let result = executor + .execute(|| async { Ok::(42) }) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 42); + } + + #[tokio::test] + async fn test_retry_executor_exhausted() { + let config = RetryConfig::new().with_max_retries(2); + let mut executor = RetryExecutor::new(config); + + let result: Result> = executor + .execute(|| async { Err("connection refused") }) + .await; + + assert!(result.is_err()); + match result.unwrap_err() { + RetryError::Exhausted { attempts, .. } => { + assert_eq!(attempts, 3); // 1 initial + 2 retries + } + } + } +} diff --git a/src/router.rs b/src/router.rs index e164c8f..dd1ca94 100644 --- a/src/router.rs +++ b/src/router.rs @@ -1,4 +1,629 @@ -#[allow(dead_code)] +//! L7 routing with path and header-based matching. +//! +//! Provides flexible routing rules for directing traffic to different +//! upstream clusters based on request attributes. + +use regex::Regex; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use tracing::{debug, warn}; + +/// Route matching priority (higher = evaluated first). +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum RoutePriority { + /// Exact match routes (highest priority). + Exact = 100, + /// Prefix match routes. + Prefix = 50, + /// Regex match routes. + Regex = 25, + /// Default/catch-all routes (lowest priority). + Default = 0, +} + +/// Condition for matching a header. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum HeaderMatch { + /// Header must have exactly this value. + Exact { name: String, value: String }, + /// Header must contain this substring. + Contains { name: String, value: String }, + /// Header must match this regex pattern. + Regex { name: String, pattern: String }, + /// Header must be present (any value). + Present { name: String }, + /// Header must be absent. + Absent { name: String }, +} + +impl HeaderMatch { + /// Creates an exact header match. + pub fn exact(name: impl Into, value: impl Into) -> Self { + Self::Exact { + name: name.into(), + value: value.into(), + } + } + + /// Creates a header presence check. + pub fn present(name: impl Into) -> Self { + Self::Present { name: name.into() } + } + + /// Creates a header absence check. + pub fn absent(name: impl Into) -> Self { + Self::Absent { name: name.into() } + } + + /// Checks if the header matches. + pub fn matches(&self, headers: &http::HeaderMap) -> bool { + match self { + HeaderMatch::Exact { name, value } => { + headers.get(name).is_some_and(|v| v == value.as_str()) + } + HeaderMatch::Contains { name, value } => headers + .get(name) + .and_then(|v| v.to_str().ok()) + .is_some_and(|v| v.contains(value.as_str())), + HeaderMatch::Regex { name, pattern } => { + if let Ok(regex) = Regex::new(pattern) { + headers + .get(name) + .and_then(|v| v.to_str().ok()) + .is_some_and(|v| regex.is_match(v)) + } else { + warn!(pattern = %pattern, "invalid regex pattern"); + false + } + } + HeaderMatch::Present { name } => headers.contains_key(name), + HeaderMatch::Absent { name } => !headers.contains_key(name), + } + } +} + +/// Condition for matching a request path. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum PathMatch { + /// Path must be exactly this value. + Exact { path: String }, + /// Path must start with this prefix. + Prefix { prefix: String }, + /// Path must match this regex pattern. + Regex { pattern: String }, +} + +impl PathMatch { + /// Creates an exact path match. + pub fn exact(path: impl Into) -> Self { + Self::Exact { path: path.into() } + } + + /// Creates a prefix path match. + pub fn prefix(prefix: impl Into) -> Self { + Self::Prefix { + prefix: prefix.into(), + } + } + + /// Creates a regex path match. + pub fn regex(pattern: impl Into) -> Self { + Self::Regex { + pattern: pattern.into(), + } + } + + /// Checks if the path matches. + pub fn matches(&self, path: &str) -> bool { + match self { + PathMatch::Exact { path: expected } => path == expected, + PathMatch::Prefix { prefix } => path.starts_with(prefix), + PathMatch::Regex { pattern } => { + if let Ok(regex) = Regex::new(pattern) { + regex.is_match(path) + } else { + warn!(pattern = %pattern, "invalid regex pattern"); + false + } + } + } + } + + /// Returns the priority for this match type. + pub fn priority(&self) -> RoutePriority { + match self { + PathMatch::Exact { .. } => RoutePriority::Exact, + PathMatch::Prefix { .. } => RoutePriority::Prefix, + PathMatch::Regex { .. } => RoutePriority::Regex, + } + } +} + +/// Condition for matching HTTP method. +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "UPPERCASE")] +pub enum MethodMatch { + Get, + Post, + Put, + Delete, + Patch, + Head, + Options, + #[default] + Any, +} + +impl MethodMatch { + /// Checks if the method matches. + pub fn matches(&self, method: &http::Method) -> bool { + match self { + MethodMatch::Any => true, + MethodMatch::Get => method == http::Method::GET, + MethodMatch::Post => method == http::Method::POST, + MethodMatch::Put => method == http::Method::PUT, + MethodMatch::Delete => method == http::Method::DELETE, + MethodMatch::Patch => method == http::Method::PATCH, + MethodMatch::Head => method == http::Method::HEAD, + MethodMatch::Options => method == http::Method::OPTIONS, + } + } +} + +/// A single routing rule. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Route { + /// Unique name for this route. + pub name: String, + /// Path matching condition. + pub path: PathMatch, + /// HTTP method matching (optional, defaults to Any). + #[serde(default)] + pub method: MethodMatch, + /// Header matching conditions (all must match). + #[serde(default)] + pub headers: Vec, + /// Target upstream cluster name. + pub upstream: String, + /// Weight for load balancing (when multiple routes match). + #[serde(default = "default_weight")] + pub weight: u32, + /// Whether this route is enabled. + #[serde(default = "default_enabled")] + pub enabled: bool, + /// Request timeout override for this route. + pub timeout_ms: Option, + /// Path rewrite (replace matched path with this). + pub rewrite: Option, +} + +fn default_weight() -> u32 { + 100 +} + +fn default_enabled() -> bool { + true +} + +impl Route { + /// Creates a new route with the given name and path. + pub fn new(name: impl Into, path: PathMatch, upstream: impl Into) -> Self { + Self { + name: name.into(), + path, + method: MethodMatch::Any, + headers: Vec::new(), + upstream: upstream.into(), + weight: 100, + enabled: true, + timeout_ms: None, + rewrite: None, + } + } + + /// Sets the HTTP method for this route. + pub fn with_method(mut self, method: MethodMatch) -> Self { + self.method = method; + self + } + + /// Adds a header match condition. + pub fn with_header(mut self, header: HeaderMatch) -> Self { + self.headers.push(header); + self + } + + /// Sets the weight for this route. + pub fn with_weight(mut self, weight: u32) -> Self { + self.weight = weight; + self + } + + /// Sets a timeout override. + pub fn with_timeout(mut self, timeout_ms: u64) -> Self { + self.timeout_ms = Some(timeout_ms); + self + } + + /// Sets a path rewrite rule. + pub fn with_rewrite(mut self, rewrite: impl Into) -> Self { + self.rewrite = Some(rewrite.into()); + self + } + + /// Checks if this route matches the request. + pub fn matches(&self, method: &http::Method, path: &str, headers: &http::HeaderMap) -> bool { + if !self.enabled { + return false; + } + + if !self.method.matches(method) { + return false; + } + + if !self.path.matches(path) { + return false; + } + + for header_match in &self.headers { + if !header_match.matches(headers) { + return false; + } + } + + true + } + + /// Returns the priority of this route. + pub fn priority(&self) -> RoutePriority { + self.path.priority() + } +} + +/// Result of a route match. +#[derive(Debug, Clone)] +pub struct RouteMatch { + /// The matched route. + pub route: Route, + /// Rewritten path (if applicable). + pub rewritten_path: Option, +} + +/// Router for L7 traffic routing. pub struct Router { - // L7 routing logic will go here + routes: Vec, + default_upstream: Option, +} + +impl Router { + /// Creates a new router with no routes. + pub fn new() -> Self { + Self { + routes: Vec::new(), + default_upstream: None, + } + } + + /// Creates a router with the given routes. + pub fn with_routes(routes: Vec) -> Self { + let mut router = Self { + routes, + default_upstream: None, + }; + router.sort_routes(); + router + } + + /// Sets the default upstream for unmatched routes. + pub fn with_default_upstream(mut self, upstream: impl Into) -> Self { + self.default_upstream = Some(upstream.into()); + self + } + + /// Adds a route to the router. + pub fn add_route(&mut self, route: Route) { + self.routes.push(route); + self.sort_routes(); + } + + /// Removes a route by name. + pub fn remove_route(&mut self, name: &str) -> Option { + if let Some(pos) = self.routes.iter().position(|r| r.name == name) { + Some(self.routes.remove(pos)) + } else { + None + } + } + + /// Sorts routes by priority (highest first). + fn sort_routes(&mut self) { + self.routes + .sort_by_key(|r| std::cmp::Reverse(r.priority())); + } + + /// Finds the matching route for a request. + pub fn route( + &self, + method: &http::Method, + path: &str, + headers: &http::HeaderMap, + ) -> Option { + for route in &self.routes { + if route.matches(method, path, headers) { + debug!( + route = %route.name, + upstream = %route.upstream, + "matched route" + ); + + let rewritten_path = route.rewrite.as_ref().map(|rewrite| { + // Simple rewrite: replace the matched prefix + if let PathMatch::Prefix { prefix } = &route.path { + path.replacen(prefix, rewrite, 1) + } else { + rewrite.clone() + } + }); + + return Some(RouteMatch { + route: route.clone(), + rewritten_path, + }); + } + } + + // Return default upstream if set + if let Some(upstream) = &self.default_upstream { + debug!(upstream = %upstream, "using default upstream"); + return Some(RouteMatch { + route: Route::new("default", PathMatch::prefix("/"), upstream.clone()), + rewritten_path: None, + }); + } + + debug!(path = %path, "no matching route found"); + None + } + + /// Returns all routes. + pub fn routes(&self) -> &[Route] { + &self.routes + } + + /// Returns the number of routes. + pub fn len(&self) -> usize { + self.routes.len() + } + + /// Returns true if there are no routes. + pub fn is_empty(&self) -> bool { + self.routes.is_empty() + } +} + +impl Default for Router { + fn default() -> Self { + Self::new() + } +} + +/// Upstream cluster definition. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpstreamCluster { + /// Cluster name. + pub name: String, + /// List of upstream endpoints. + pub endpoints: Vec, + /// Load balancing policy. + #[serde(default)] + pub load_balancing: LoadBalancingPolicy, + /// Health check configuration. + pub health_check: Option, +} + +/// Load balancing policy. +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum LoadBalancingPolicy { + /// Round-robin selection. + #[default] + RoundRobin, + /// Least connections. + LeastConnections, + /// Random selection. + Random, + /// Consistent hashing. + ConsistentHash, +} + +/// Health check configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HealthCheckConfig { + /// Health check interval. + pub interval_ms: u64, + /// Health check timeout. + pub timeout_ms: u64, + /// Path to check. + pub path: String, + /// Number of failures before marking unhealthy. + pub unhealthy_threshold: u32, + /// Number of successes before marking healthy. + pub healthy_threshold: u32, +} + +impl Default for HealthCheckConfig { + fn default() -> Self { + Self { + interval_ms: 10000, + timeout_ms: 5000, + path: "/health".to_string(), + unhealthy_threshold: 3, + healthy_threshold: 2, + } + } +} + +/// Routing configuration that can be loaded from file. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoutingConfig { + /// List of routes. + pub routes: Vec, + /// Upstream clusters. + pub upstreams: HashMap, + /// Default upstream cluster name. + pub default_upstream: Option, +} + +impl RoutingConfig { + /// Loads configuration from a TOML string. + pub fn from_toml(content: &str) -> Result { + toml::from_str(content) + } + + /// Loads configuration from a JSON string. + pub fn from_json(content: &str) -> Result { + serde_json::from_str(content) + } + + /// Builds a router from this configuration. + pub fn build_router(&self) -> Router { + let mut router = Router::with_routes(self.routes.clone()); + if let Some(default) = &self.default_upstream { + router = router.with_default_upstream(default.clone()); + } + router + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::{HeaderMap, HeaderValue, Method}; + + #[test] + fn test_path_match_exact() { + let matcher = PathMatch::exact("/api/users"); + assert!(matcher.matches("/api/users")); + assert!(!matcher.matches("/api/users/")); + assert!(!matcher.matches("/api")); + } + + #[test] + fn test_path_match_prefix() { + let matcher = PathMatch::prefix("/api/"); + assert!(matcher.matches("/api/users")); + assert!(matcher.matches("/api/posts")); + assert!(!matcher.matches("/other")); + } + + #[test] + fn test_path_match_regex() { + let matcher = PathMatch::regex(r"^/api/users/\d+$"); + assert!(matcher.matches("/api/users/123")); + assert!(matcher.matches("/api/users/456")); + assert!(!matcher.matches("/api/users/abc")); + } + + #[test] + fn test_header_match_exact() { + let matcher = HeaderMatch::exact("content-type", "application/json"); + let mut headers = HeaderMap::new(); + headers.insert("content-type", HeaderValue::from_static("application/json")); + assert!(matcher.matches(&headers)); + + headers.insert("content-type", HeaderValue::from_static("text/plain")); + assert!(!matcher.matches(&headers)); + } + + #[test] + fn test_header_match_present() { + let matcher = HeaderMatch::present("authorization"); + let mut headers = HeaderMap::new(); + assert!(!matcher.matches(&headers)); + + headers.insert("authorization", HeaderValue::from_static("Bearer token")); + assert!(matcher.matches(&headers)); + } + + #[test] + fn test_route_matching() { + let route = Route::new("api-route", PathMatch::prefix("/api/"), "api-cluster") + .with_method(MethodMatch::Get) + .with_header(HeaderMatch::present("authorization")); + + let mut headers = HeaderMap::new(); + headers.insert("authorization", HeaderValue::from_static("Bearer token")); + + assert!(route.matches(&Method::GET, "/api/users", &headers)); + assert!(!route.matches(&Method::POST, "/api/users", &headers)); + + let empty_headers = HeaderMap::new(); + assert!(!route.matches(&Method::GET, "/api/users", &empty_headers)); + } + + #[test] + fn test_router_priority() { + let mut router = Router::new(); + + // Add routes in non-priority order + router.add_route(Route::new( + "prefix", + PathMatch::prefix("/api/"), + "prefix-cluster", + )); + router.add_route(Route::new( + "exact", + PathMatch::exact("/api/users"), + "exact-cluster", + )); + + let headers = HeaderMap::new(); + let result = router.route(&Method::GET, "/api/users", &headers); + + // Exact match should be selected + assert!(result.is_some()); + assert_eq!(result.unwrap().route.name, "exact"); + } + + #[test] + fn test_router_default_upstream() { + let router = Router::new().with_default_upstream("default-cluster"); + + let headers = HeaderMap::new(); + let result = router.route(&Method::GET, "/unmatched", &headers); + + assert!(result.is_some()); + assert_eq!(result.unwrap().route.upstream, "default-cluster"); + } + + #[test] + fn test_router_no_match() { + let router = Router::new(); + + let headers = HeaderMap::new(); + let result = router.route(&Method::GET, "/unmatched", &headers); + + assert!(result.is_none()); + } + + #[test] + fn test_route_rewrite() { + let route = + Route::new("rewrite", PathMatch::prefix("/old/"), "cluster").with_rewrite("/new/"); + + let mut router = Router::new(); + router.add_route(route); + + let headers = HeaderMap::new(); + let result = router.route(&Method::GET, "/old/path/to/resource", &headers); + + assert!(result.is_some()); + let route_match = result.unwrap(); + assert_eq!( + route_match.rewritten_path, + Some("/new/path/to/resource".to_string()) + ); + } } diff --git a/src/service.rs b/src/service.rs index cff96d5..0da2c11 100644 --- a/src/service.rs +++ b/src/service.rs @@ -12,7 +12,8 @@ use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use std::time::Instant; +use std::time::{Duration, Instant}; +use tokio::time::timeout; use tower::Service; use tracing::{debug, info, instrument, warn}; @@ -25,11 +26,13 @@ use tracing::{debug, info, instrument, warn}; /// ```no_run /// use rust_servicemesh::service::ProxyService; /// use std::sync::Arc; +/// use std::time::Duration; /// /// #[tokio::main] /// async fn main() { /// let upstream = "http://example.com:8080".to_string(); -/// let service = ProxyService::new(Arc::new(vec![upstream])); +/// let timeout = Duration::from_secs(30); +/// let service = ProxyService::new(Arc::new(vec![upstream]), timeout); /// } /// ``` #[derive(Clone)] @@ -37,6 +40,7 @@ pub struct ProxyService { upstream_addrs: Arc>, client: Client, next_upstream: Arc, + request_timeout: Duration, } impl ProxyService { @@ -45,12 +49,14 @@ impl ProxyService { /// # Arguments /// /// * `upstream_addrs` - List of upstream server addresses (e.g., "http://127.0.0.1:8080") - pub fn new(upstream_addrs: Arc>) -> Self { + /// * `request_timeout` - Maximum duration for upstream requests + pub fn new(upstream_addrs: Arc>, request_timeout: Duration) -> Self { let client = Client::builder(TokioExecutor::new()).build_http(); Self { upstream_addrs, client, next_upstream: Arc::new(std::sync::atomic::AtomicUsize::new(0)), + request_timeout, } } @@ -97,8 +103,8 @@ impl ProxyService { debug!("forwarding to upstream: {}", upstream_uri); *req.uri_mut() = upstream_uri; - match self.client.request(req).await { - Ok(response) => { + match timeout(self.request_timeout, self.client.request(req)).await { + Ok(Ok(response)) => { let status = response.status().as_u16(); let duration = start.elapsed().as_secs_f64(); @@ -116,7 +122,7 @@ impl ProxyService { let boxed_body = body.boxed(); Ok(Response::from_parts(parts, boxed_body)) } - Err(e) => { + Ok(Err(e)) => { warn!("upstream request failed: {}", e); let duration = start.elapsed().as_secs_f64(); Metrics::record_request(&method, 502, &upstream_owned, duration); @@ -125,6 +131,18 @@ impl ProxyService { "Upstream request failed", )) } + Err(_) => { + warn!( + "upstream request timed out after {:?}", + self.request_timeout + ); + let duration = start.elapsed().as_secs_f64(); + Metrics::record_request(&method, 504, &upstream_owned, duration); + Ok(Self::error_response( + StatusCode::GATEWAY_TIMEOUT, + "Upstream request timed out", + )) + } } } diff --git a/src/transport.rs b/src/transport.rs index c89bb34..42d6f29 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -1,4 +1,573 @@ -#[allow(dead_code)] +//! Connection pooling and load balancing for upstream connections. +//! +//! Provides efficient connection management with support for multiple +//! load balancing strategies and health-aware routing. + +use crate::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; +use crate::router::LoadBalancingPolicy; +use dashmap::DashMap; +use parking_lot::RwLock; +use rand::Rng; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tracing::warn; + +/// Configuration for the connection pool. +#[derive(Debug, Clone)] +pub struct PoolConfig { + /// Maximum number of idle connections per host. + pub max_idle_per_host: usize, + /// Maximum total connections per host. + pub max_connections_per_host: usize, + /// Idle connection timeout. + pub idle_timeout: Duration, + /// Connection establishment timeout. + pub connect_timeout: Duration, + /// Enable HTTP/2 connection pooling. + pub http2_only: bool, +} + +impl Default for PoolConfig { + fn default() -> Self { + Self { + max_idle_per_host: 10, + max_connections_per_host: 100, + idle_timeout: Duration::from_secs(90), + connect_timeout: Duration::from_secs(10), + http2_only: false, + } + } +} + +impl PoolConfig { + /// Creates a new pool configuration. + pub fn new() -> Self { + Self::default() + } + + /// Sets the maximum idle connections per host. + pub fn with_max_idle(mut self, max: usize) -> Self { + self.max_idle_per_host = max; + self + } + + /// Sets the idle timeout. + pub fn with_idle_timeout(mut self, timeout: Duration) -> Self { + self.idle_timeout = timeout; + self + } + + /// Enables HTTP/2 only mode. + pub fn with_http2_only(mut self, http2: bool) -> Self { + self.http2_only = http2; + self + } +} + +/// Statistics for a single endpoint. +#[derive(Debug, Clone)] +pub struct EndpointStats { + /// Total number of requests sent. + pub total_requests: u64, + /// Number of successful requests. + pub successful_requests: u64, + /// Number of failed requests. + pub failed_requests: u64, + /// Current active connections. + pub active_connections: usize, + /// Average response time in milliseconds. + pub avg_response_time_ms: f64, + /// Whether the endpoint is healthy. + pub is_healthy: bool, +} + +/// A single upstream endpoint. +#[derive(Debug)] +pub struct Endpoint { + /// The endpoint address (e.g., "http://localhost:8080"). + pub address: String, + /// Current weight for weighted load balancing. + weight: AtomicU64, + /// Number of active connections. + active_connections: AtomicUsize, + /// Total request count. + total_requests: AtomicU64, + /// Successful request count. + successful_requests: AtomicU64, + /// Failed request count. + failed_requests: AtomicU64, + /// Sum of response times in microseconds. + total_response_time_us: AtomicU64, + /// Whether the endpoint is healthy. + healthy: RwLock, + /// Circuit breaker for this endpoint. + circuit_breaker: CircuitBreaker, + /// Last health check time (used for periodic health checks). + #[allow(dead_code)] + last_health_check: RwLock, +} + +impl Endpoint { + /// Creates a new endpoint. + pub fn new(address: impl Into) -> Self { + Self { + address: address.into(), + weight: AtomicU64::new(100), + active_connections: AtomicUsize::new(0), + total_requests: AtomicU64::new(0), + successful_requests: AtomicU64::new(0), + failed_requests: AtomicU64::new(0), + total_response_time_us: AtomicU64::new(0), + healthy: RwLock::new(true), + circuit_breaker: CircuitBreaker::new(CircuitBreakerConfig::default()), + last_health_check: RwLock::new(Instant::now()), + } + } + + /// Creates a new endpoint with the given weight. + pub fn with_weight(address: impl Into, weight: u64) -> Self { + let endpoint = Self::new(address); + endpoint.weight.store(weight, Ordering::Relaxed); + endpoint + } + + /// Returns the current weight. + pub fn weight(&self) -> u64 { + self.weight.load(Ordering::Relaxed) + } + + /// Sets the weight. + pub fn set_weight(&self, weight: u64) { + self.weight.store(weight, Ordering::Relaxed); + } + + /// Returns the number of active connections. + pub fn active_connections(&self) -> usize { + self.active_connections.load(Ordering::Relaxed) + } + + /// Increments the active connection count. + pub fn acquire_connection(&self) { + self.active_connections.fetch_add(1, Ordering::Relaxed); + } + + /// Decrements the active connection count. + pub fn release_connection(&self) { + self.active_connections.fetch_sub(1, Ordering::Relaxed); + } + + /// Returns whether the endpoint is healthy. + pub fn is_healthy(&self) -> bool { + *self.healthy.read() + } + + /// Sets the health status. + pub fn set_healthy(&self, healthy: bool) { + *self.healthy.write() = healthy; + } + + /// Records a successful request. + pub async fn record_success(&self, response_time: Duration) { + self.total_requests.fetch_add(1, Ordering::Relaxed); + self.successful_requests.fetch_add(1, Ordering::Relaxed); + self.total_response_time_us + .fetch_add(response_time.as_micros() as u64, Ordering::Relaxed); + self.circuit_breaker.record_success().await; + } + + /// Records a failed request. + pub async fn record_failure(&self) { + self.total_requests.fetch_add(1, Ordering::Relaxed); + self.failed_requests.fetch_add(1, Ordering::Relaxed); + self.circuit_breaker.record_failure().await; + } + + /// Checks if a request should be allowed through the circuit breaker. + pub async fn allow_request(&self) -> bool { + self.circuit_breaker.allow_request().await + } + + /// Returns statistics for this endpoint. + pub fn stats(&self) -> EndpointStats { + let total = self.total_requests.load(Ordering::Relaxed); + let total_time = self.total_response_time_us.load(Ordering::Relaxed); + let avg_time = if total > 0 { + (total_time as f64 / total as f64) / 1000.0 + } else { + 0.0 + }; + + EndpointStats { + total_requests: total, + successful_requests: self.successful_requests.load(Ordering::Relaxed), + failed_requests: self.failed_requests.load(Ordering::Relaxed), + active_connections: self.active_connections.load(Ordering::Relaxed), + avg_response_time_ms: avg_time, + is_healthy: *self.healthy.read(), + } + } +} + +/// Load balancer for distributing requests across endpoints. +pub struct LoadBalancer { + endpoints: Vec>, + policy: LoadBalancingPolicy, + next_index: AtomicUsize, +} + +impl LoadBalancer { + /// Creates a new load balancer. + pub fn new(endpoints: Vec>, policy: LoadBalancingPolicy) -> Self { + Self { + endpoints, + policy, + next_index: AtomicUsize::new(0), + } + } + + /// Creates a load balancer from endpoint addresses. + pub fn from_addresses(addresses: Vec, policy: LoadBalancingPolicy) -> Self { + let endpoints = addresses + .into_iter() + .map(|addr| Arc::new(Endpoint::new(addr))) + .collect(); + Self::new(endpoints, policy) + } + + /// Selects the next endpoint based on the load balancing policy. + pub async fn select(&self) -> Option> { + let healthy_endpoints: Vec<_> = self + .endpoints + .iter() + .filter(|e| e.is_healthy()) + .cloned() + .collect(); + + if healthy_endpoints.is_empty() { + warn!("no healthy endpoints available"); + return None; + } + + let endpoint = match self.policy { + LoadBalancingPolicy::RoundRobin => self.round_robin(&healthy_endpoints), + LoadBalancingPolicy::LeastConnections => self.least_connections(&healthy_endpoints), + LoadBalancingPolicy::Random => self.random(&healthy_endpoints), + LoadBalancingPolicy::ConsistentHash => { + // Fallback to round-robin for now + self.round_robin(&healthy_endpoints) + } + }; + + // Check circuit breaker + if let Some(ref ep) = endpoint { + if !ep.allow_request().await { + warn!(endpoint = %ep.address, "circuit breaker is open"); + // Try to find another endpoint + for e in &healthy_endpoints { + if e.address != ep.address && e.allow_request().await { + return Some(e.clone()); + } + } + return None; + } + } + + endpoint + } + + /// Round-robin selection. + fn round_robin(&self, endpoints: &[Arc]) -> Option> { + if endpoints.is_empty() { + return None; + } + let idx = self.next_index.fetch_add(1, Ordering::Relaxed) % endpoints.len(); + Some(endpoints[idx].clone()) + } + + /// Least connections selection. + fn least_connections(&self, endpoints: &[Arc]) -> Option> { + endpoints + .iter() + .min_by_key(|e| e.active_connections()) + .cloned() + } + + /// Random selection. + fn random(&self, endpoints: &[Arc]) -> Option> { + if endpoints.is_empty() { + return None; + } + let idx = rand::thread_rng().gen_range(0..endpoints.len()); + Some(endpoints[idx].clone()) + } + + /// Returns all endpoints. + pub fn endpoints(&self) -> &[Arc] { + &self.endpoints + } + + /// Returns the number of healthy endpoints. + pub fn healthy_count(&self) -> usize { + self.endpoints.iter().filter(|e| e.is_healthy()).count() + } + + /// Returns the total number of endpoints. + pub fn total_count(&self) -> usize { + self.endpoints.len() + } +} + +/// Connection pool statistics. +#[derive(Debug, Clone)] +pub struct PoolStats { + /// Total number of connections created. + pub connections_created: u64, + /// Total number of connections closed. + pub connections_closed: u64, + /// Current number of idle connections. + pub idle_connections: usize, + /// Current number of active connections. + pub active_connections: usize, +} + +/// Transport layer managing connection pools for upstream clusters. pub struct Transport { - // Connection pooling and load balancing will go here + /// Pool configuration. + config: PoolConfig, + /// Load balancers for each cluster. + clusters: DashMap>, + /// Connection statistics. + stats: Arc, +} + +/// Transport statistics. +struct TransportStats { + connections_created: AtomicU64, + connections_closed: AtomicU64, +} + +impl Transport { + /// Creates a new transport with the given configuration. + pub fn new(config: PoolConfig) -> Self { + Self { + config, + clusters: DashMap::new(), + stats: Arc::new(TransportStats { + connections_created: AtomicU64::new(0), + connections_closed: AtomicU64::new(0), + }), + } + } + + /// Creates a transport with default configuration. + pub fn with_defaults() -> Self { + Self::new(PoolConfig::default()) + } + + /// Adds a cluster with the given endpoints. + pub fn add_cluster( + &self, + name: impl Into, + endpoints: Vec, + policy: LoadBalancingPolicy, + ) { + let lb = LoadBalancer::from_addresses(endpoints, policy); + self.clusters.insert(name.into(), Arc::new(lb)); + } + + /// Gets a load balancer for a cluster. + pub fn get_cluster(&self, name: &str) -> Option> { + self.clusters.get(name).map(|r| r.clone()) + } + + /// Selects an endpoint from a cluster. + pub async fn select_endpoint(&self, cluster: &str) -> Option> { + if let Some(lb) = self.get_cluster(cluster) { + lb.select().await + } else { + warn!(cluster = %cluster, "cluster not found"); + None + } + } + + /// Returns the pool configuration. + pub fn config(&self) -> &PoolConfig { + &self.config + } + + /// Returns pool statistics. + pub fn stats(&self) -> PoolStats { + let mut active = 0; + + for cluster in self.clusters.iter() { + for endpoint in cluster.endpoints() { + active += endpoint.active_connections(); + } + } + + PoolStats { + connections_created: self.stats.connections_created.load(Ordering::Relaxed), + connections_closed: self.stats.connections_closed.load(Ordering::Relaxed), + idle_connections: 0, // TODO: implement idle connection tracking + active_connections: active, + } + } + + /// Records a connection being created. + pub fn record_connection_created(&self) { + self.stats + .connections_created + .fetch_add(1, Ordering::Relaxed); + } + + /// Records a connection being closed. + pub fn record_connection_closed(&self) { + self.stats + .connections_closed + .fetch_add(1, Ordering::Relaxed); + } +} + +impl Default for Transport { + fn default() -> Self { + Self::with_defaults() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_endpoint_basic() { + let endpoint = Endpoint::new("http://localhost:8080"); + assert!(endpoint.is_healthy()); + assert_eq!(endpoint.active_connections(), 0); + assert_eq!(endpoint.weight(), 100); + } + + #[test] + fn test_endpoint_with_weight() { + let endpoint = Endpoint::with_weight("http://localhost:8080", 50); + assert_eq!(endpoint.weight(), 50); + } + + #[test] + fn test_endpoint_connections() { + let endpoint = Endpoint::new("http://localhost:8080"); + endpoint.acquire_connection(); + endpoint.acquire_connection(); + assert_eq!(endpoint.active_connections(), 2); + + endpoint.release_connection(); + assert_eq!(endpoint.active_connections(), 1); + } + + #[tokio::test] + async fn test_endpoint_stats() { + let endpoint = Endpoint::new("http://localhost:8080"); + endpoint.record_success(Duration::from_millis(100)).await; + endpoint.record_success(Duration::from_millis(200)).await; + endpoint.record_failure().await; + + let stats = endpoint.stats(); + assert_eq!(stats.total_requests, 3); + assert_eq!(stats.successful_requests, 2); + assert_eq!(stats.failed_requests, 1); + assert!(stats.avg_response_time_ms > 0.0); + } + + #[tokio::test] + async fn test_load_balancer_round_robin() { + let lb = LoadBalancer::from_addresses( + vec![ + "http://host1:8080".to_string(), + "http://host2:8080".to_string(), + "http://host3:8080".to_string(), + ], + LoadBalancingPolicy::RoundRobin, + ); + + let ep1 = lb.select().await.unwrap(); + let ep2 = lb.select().await.unwrap(); + let ep3 = lb.select().await.unwrap(); + let ep4 = lb.select().await.unwrap(); + + // Should cycle through all endpoints + assert_ne!(ep1.address, ep2.address); + assert_ne!(ep2.address, ep3.address); + assert_eq!(ep1.address, ep4.address); // Back to first + } + + #[tokio::test] + async fn test_load_balancer_least_connections() { + let endpoints = vec![ + Arc::new(Endpoint::new("http://host1:8080")), + Arc::new(Endpoint::new("http://host2:8080")), + ]; + + // Add connections to host1 + endpoints[0].acquire_connection(); + endpoints[0].acquire_connection(); + + let lb = LoadBalancer::new(endpoints, LoadBalancingPolicy::LeastConnections); + + // Should select host2 (fewer connections) + let selected = lb.select().await.unwrap(); + assert_eq!(selected.address, "http://host2:8080"); + } + + #[tokio::test] + async fn test_load_balancer_unhealthy_skip() { + let endpoints = vec![ + Arc::new(Endpoint::new("http://host1:8080")), + Arc::new(Endpoint::new("http://host2:8080")), + ]; + + // Mark host1 as unhealthy + endpoints[0].set_healthy(false); + + let lb = LoadBalancer::new(endpoints, LoadBalancingPolicy::RoundRobin); + + // Should always select host2 + for _ in 0..5 { + let selected = lb.select().await.unwrap(); + assert_eq!(selected.address, "http://host2:8080"); + } + } + + #[test] + fn test_transport_add_cluster() { + let transport = Transport::with_defaults(); + transport.add_cluster( + "api", + vec![ + "http://api1:8080".to_string(), + "http://api2:8080".to_string(), + ], + LoadBalancingPolicy::RoundRobin, + ); + + let cluster = transport.get_cluster("api"); + assert!(cluster.is_some()); + assert_eq!(cluster.unwrap().total_count(), 2); + } + + #[tokio::test] + async fn test_transport_select_endpoint() { + let transport = Transport::with_defaults(); + transport.add_cluster( + "api", + vec!["http://api1:8080".to_string()], + LoadBalancingPolicy::RoundRobin, + ); + + let endpoint = transport.select_endpoint("api").await; + assert!(endpoint.is_some()); + assert_eq!(endpoint.unwrap().address, "http://api1:8080"); + + let missing = transport.select_endpoint("nonexistent").await; + assert!(missing.is_none()); + } } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 6f7a9b9..a01ad37 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -8,6 +8,7 @@ use hyper_util::rt::TokioExecutor; use hyper_util::rt::TokioIo; use std::convert::Infallible; use std::sync::Arc; +use std::time::Duration; use tokio::net::TcpListener; use tokio::sync::broadcast; @@ -18,6 +19,15 @@ async fn mock_upstream_handler(_req: Request) -> Result) -> Result, Infallible> { + // Simulate a slow upstream that takes longer than the timeout + tokio::time::sleep(Duration::from_secs(10)).await; + Ok(Response::builder() + .status(StatusCode::OK) + .body("slow response".to_string()) + .unwrap()) +} + async fn start_mock_upstream() -> String { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -40,14 +50,38 @@ async fn start_mock_upstream() -> String { format!("http://127.0.0.1:{}", addr.port()) } +async fn start_slow_upstream() -> String { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + loop { + let (stream, _) = match listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + + tokio::spawn(async move { + let io = TokioIo::new(stream); + let service = service_fn(slow_upstream_handler); + let _ = http1::Builder::new().serve_connection(io, service).await; + }); + } + }); + + format!("http://127.0.0.1:{}", addr.port()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_proxy_basic_request() { let upstream_addr = start_mock_upstream().await; let upstream_addrs = Arc::new(vec![upstream_addr]); + let timeout = Duration::from_secs(30); - let listener = rust_servicemesh::listener::Listener::bind("127.0.0.1:0", upstream_addrs) - .await - .unwrap(); + let listener = + rust_servicemesh::listener::Listener::bind("127.0.0.1:0", upstream_addrs, timeout) + .await + .unwrap(); let proxy_addr = listener.local_addr(); let (shutdown_tx, shutdown_rx) = broadcast::channel(1); @@ -76,10 +110,12 @@ async fn test_proxy_round_robin() { let upstream1 = start_mock_upstream().await; let upstream2 = start_mock_upstream().await; let upstream_addrs = Arc::new(vec![upstream1, upstream2]); + let timeout = Duration::from_secs(30); - let listener = rust_servicemesh::listener::Listener::bind("127.0.0.1:0", upstream_addrs) - .await - .unwrap(); + let listener = + rust_servicemesh::listener::Listener::bind("127.0.0.1:0", upstream_addrs, timeout) + .await + .unwrap(); let proxy_addr = listener.local_addr(); let (shutdown_tx, shutdown_rx) = broadcast::channel(1); @@ -104,3 +140,53 @@ async fn test_proxy_round_robin() { let _ = shutdown_tx.send(()); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_proxy_timeout_enforcement() { + // Start a slow upstream that takes 10 seconds to respond + let slow_upstream = start_slow_upstream().await; + let upstream_addrs = Arc::new(vec![slow_upstream]); + + // Set a short timeout (1 second) + let timeout = Duration::from_secs(1); + + let listener = + rust_servicemesh::listener::Listener::bind("127.0.0.1:0", upstream_addrs, timeout) + .await + .unwrap(); + + let proxy_addr = listener.local_addr(); + let (shutdown_tx, shutdown_rx) = broadcast::channel(1); + + tokio::spawn(async move { + let _ = listener.serve(shutdown_rx).await; + }); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let client: Client<_, Empty> = Client::builder(TokioExecutor::new()).build_http(); + let uri = format!("http://{}/test", proxy_addr); + + let start = std::time::Instant::now(); + let req = Request::builder() + .uri(uri) + .body(Empty::::new()) + .unwrap(); + let response = client.request(req).await.unwrap(); + let elapsed = start.elapsed(); + + // Should get 504 Gateway Timeout + assert_eq!(response.status(), StatusCode::GATEWAY_TIMEOUT); + + // Should timeout in approximately 1 second, not 10 + assert!( + elapsed < Duration::from_secs(2), + "Request should timeout quickly" + ); + assert!( + elapsed >= Duration::from_secs(1), + "Request should wait for timeout" + ); + + let _ = shutdown_tx.send(()); +}