diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml
index d68b1b7..2bb143d 100644
--- a/.github/workflows/workflow.yml
+++ b/.github/workflows/workflow.yml
@@ -3,7 +3,43 @@ on:
   push:
     branches:
       - main
+  workflow_dispatch:
+    inputs:
+      bump:
+        type: choice
+        options:
+          - patch
+          - minor
+          - major
+        default: patch
 
 jobs:
+  bump:
+    name: bump-version
+    if: ${{ github.event_name == 'workflow_dispatch' }}
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Install uv
+        uses: astral-sh/setup-uv@v4
+        with:
+          enable-cache: true
+          cache-dependency-glob: "uv.lock"
+      - name: Bump version
+        run: |
+          uvx hatch version ${{ inputs.bump }}
+      - name: Commit and tag
+        run: |
+          git config user.name "rhymiz"
+          git config user.email "lemuelboyce@gmail.com"
+          VERSION="$(uvx hatch version)"
+          git add src/mas/__version__.py
+          git commit -m "chore(release): v$VERSION"
+          git tag "v$VERSION"
+          git push --follow-tags
   publish:
     name: python
     runs-on: ubuntu-latest
diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md
index b4171d2..82b8e9b 100644
--- a/ARCHITECTURE.md
+++ b/ARCHITECTURE.md
@@ -1008,17 +1008,15 @@ Additional keys used by gateway:
 | `message_nonces:{id}` | String | Nonce replay protection |
 | `priority_queue:{target}:{priority}` | Sorted Set | Priority queues |
 
-### When to Use the Gateway
+### Gateway Use Cases
 
-**Use Gateway**:
+**Typical use cases**:
 - Regulated industry (finance, healthcare, government)
 - Handling sensitive data (PII, PHI, PCI)
 - Compliance requirements (SOC2, HIPAA, GDPR)
 - Multi-tenant with strict isolation
 - Zero-trust security architecture required
 
-P2P is not supported.
-
 ### Gateway Performance
 
 **Throughput** (single gateway instance):
@@ -1098,10 +1096,10 @@ per_minute: int = 100  # Messages per minute per agent
 per_hour: int = 1000  # Messages per hour per agent
 ```
 
-**FeaturesSettings** (secure-by-default):
+**FeaturesSettings** (secure-by-default, queues opt-in):
 ```python
 dlp: bool = True  # Data Loss Prevention scanning
-priority_queue: bool = True  # Message priority routing
+priority_queue: bool = False  # Message priority routing
 rbac: bool = True  # Role-Based Access Control
 message_signing: bool = True  # HMAC message signatures
 circuit_breaker: bool = True  # Circuit breaker for reliability
@@ -1175,7 +1173,7 @@ rate_limit:
 
 features:
   dlp: true
-  priority_queue: true
+  priority_queue: false
   rbac: true
   message_signing: true
   circuit_breaker: true
diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md
index 89e4e5a..c366cfa 100644
--- a/DOCUMENTATION.md
+++ b/DOCUMENTATION.md
@@ -49,8 +49,8 @@ A multi-agent system consists of multiple autonomous software agents that:
 - **Distributed AI Systems**: Multiple AI agents collaborating on complex tasks
 - **Microservices Coordination**: Services discovering and messaging each other
 - **Workflow Orchestration**: Agents coordinating multi-step processes
-- **Healthcare Systems**: HIPAA-compliant agent communication (gateway mode)
-- **Financial Services**: SOC2/PCI-compliant agent interactions (gateway mode)
+- **Healthcare Systems**: HIPAA-compliant agent communication through the gateway
+- **Financial Services**: SOC2/PCI-compliant agent interactions through the gateway
 - **Educational Platforms**: Tutoring systems with multiple specialized agents
 
 ---
@@ -649,11 +649,10 @@ async def main():
     gateway = GatewayService()  # Uses default redis://localhost:6379
     await gateway.start()
 
-    # 2. Create agent with gateway mode enabled
+    # 2. Create agent (gateway routing is standard)
     agent = Agent(
         "my_agent",
         capabilities=["worker"],
-        use_gateway=True  # Enable gateway mode
     )
 
     # 3. Connect agent to gateway
@@ -723,7 +722,7 @@ Agents authenticate using tokens:
 
 ```python
 # Tokens are automatically generated on registration
-agent = Agent("my_agent", use_gateway=True)
+agent = Agent("my_agent")
 await agent.start()  # Token stored in agent._token
 
 # Agent token is automatically included in messages
@@ -1226,10 +1225,9 @@ def process_file(self, path: str):
 
 Route sensitive operations through gateway:
 
 ```python
-# Create agents with gateway mode for sensitive data
+# Create agents for sensitive data
 payment_agent = Agent(
     "payment_processor",
-    use_gateway=True,  # Enable security features
     capabilities=["payment"]
 )
 ```
@@ -1283,15 +1281,49 @@ brew services start redis
 docker run -d -p 6379:6379 redis:latest
 ```
 
-Start your agents:
+Start MAS (recommended):
+
+The definitive way to run MAS is the built-in runner. It starts the MAS service,
+then brings up all configured agents automatically.
+
+### Agent Runner (Config-Driven)
+
+Define agent instances in `agents.yaml` (required) and start them with the
+built-in runner. MAS routes all messages through the gateway
+by default.
+
+`agents.yaml`:
+
+```yaml
+agents:
+  - agent_id: worker_agent
+    class_path: my_app.agents.worker:WorkerAgent
+    instances: 3
+    init_kwargs:
+      capabilities: ["worker"]
+      redis_url: redis://localhost:6379
+      batch_size: 10
+start_service: true
+service_redis_url: redis://localhost:6379
+start_gateway: true
+gateway_config_file: gateway.yaml
+```
+
+`start_gateway` must remain `true`.
+
+Run MAS (auto-loads `agents.yaml` if present):
 
 ```bash
-# Single agent
-uv run python my_agent.py
+uv run python -m mas
+```
+
+The runner searches upward from the current working directory to find the
+nearest `agents.yaml` file.
 
-# Multiple agents
-uv run python agent_a.py &
-uv run python agent_b.py &
+Override the config file:
+
+```bash
+MAS_RUNNER_CONFIG_FILE=./agents.yaml uv run python -m mas
 ```
 
 ### Production Deployment
@@ -1391,7 +1423,6 @@ import os
 
 agent = Agent(
     agent_id=os.getenv("AGENT_ID", "default_agent"),
     redis_url=os.getenv("REDIS_URL", "redis://localhost:6379"),
-    use_gateway=os.getenv("USE_GATEWAY", "false").lower() == "true"
 )
 ```
 
@@ -1745,8 +1776,8 @@ class Agent:
     capabilities: list[str] | None = None,
     redis_url: str = "redis://localhost:6379",
     state_model: type[BaseModel] | None = None,
-    use_gateway: bool = False,
-    gateway_url: str | None = None
+    reclaim_idle_ms: int = 30000,
+    reclaim_batch_size: int = 50
 )
 ```
 
@@ -1755,14 +1786,16 @@
 - `capabilities`: List of capability tags for discovery
 - `redis_url`: Redis connection URL
 - `state_model`: Optional Pydantic model for typed state
-- `use_gateway`: Enable gateway mode (default: False)
-- `gateway_url`: Gateway service URL (if different from redis_url)
+- `reclaim_idle_ms`: Idle time in milliseconds before reclaiming pending messages
+- `reclaim_batch_size`: Max pending messages reclaimed per cycle
+
+Set `reclaim_idle_ms=0` to disable pending-message reclamation.
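+
+For example, a minimal construction sketch (the capability tag and tuning
+values below are illustrative, not defaults):
+
+```python
+agent = Agent(
+    "worker_agent",
+    capabilities=["worker"],
+    redis_url="redis://localhost:6379",
+    reclaim_idle_ms=60_000,   # reclaim messages left pending for 60 seconds
+    reclaim_batch_size=25,    # reclaim at most 25 messages per cycle
+)
+```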
**Properties:** - `id: str` - Agent identifier - `capabilities: list[str]` - Agent capabilities - `state: dict | BaseModel` - Current agent state -- `token: str | None` - Authentication token (gateway mode) +- `token: str | None` - Authentication token **Methods:** @@ -1870,7 +1903,7 @@ await agent.reset_state() ``` #### `set_gateway(gateway: GatewayService) -> None` -Set gateway instance for message routing (required for gateway mode). +Set gateway instance for message routing (optional, managed externally). ```python gateway = GatewayService() # Uses default redis://localhost:6379 diff --git a/GATEWAY.md b/GATEWAY.md index d7b5eed..bf76a34 100644 --- a/GATEWAY.md +++ b/GATEWAY.md @@ -3,14 +3,12 @@ ## Table of Contents - [Overview](#overview) - [Design Rationale](#design-rationale) -- [Architecture Comparison](#architecture-comparison) - [Core Components](#core-components) - [Message Flow](#message-flow) - [Security Model](#security-model) - [Audit & Compliance](#audit--compliance) - [Performance Characteristics](#performance-characteristics) - [Deployment Architecture](#deployment-architecture) -- [Migration Strategy](#migration-strategy) - [Implementation Roadmap](#implementation-roadmap) ## Overview @@ -31,18 +29,6 @@ This architectural approach provides: - **Operational control**: Circuit breakers, rate limiting, priority queues - **Compliance ready**: SOC2, HIPAA, GDPR, PCI-DSS compatible -### Trade-offs vs Pure P2P - -| Aspect | Pure P2P | Gateway Pattern | -|--------|----------|-----------------| -| **Latency** | <5ms (P50) | 10-20ms (P50) | -| **Throughput** | 10,000+ msg/s | 5,000 msg/s (single), 20,000+ (clustered) | -| **Audit Trail** | Optional, async | Complete, guaranteed | -| **Security** | Client-side | Server-side enforcement | -| **Reliability** | At-most-once | At-least-once | -| **Compliance** | Limited | Full support | -| **Operational Control** | Distributed | Centralized | - ## Design Rationale ### Why Enterprises Need Gateway Pattern @@ -56,9 +42,7 @@ Regulations require: - **PCI-DSS**: Credit card data must be detected and blocked - **FINRA**: Financial communications must be retained 7+ years -**Pure P2P limitation**: Cannot guarantee all messages are audited if agents bypass logging. - -**Gateway solution**: Single enforcement point ensures 100% audit coverage. +Single enforcement point ensures 100% audit coverage. **2. Zero-Trust Security** @@ -68,9 +52,7 @@ Enterprise security mandates: - Principle of least privilege - Defense in depth -**Pure P2P limitation**: Compromised agent can send arbitrary messages. - -**Gateway solution**: Centralized authentication, authorization, and validation. +Centralized authentication, authorization, and validation. **3. Operational Requirements** @@ -81,126 +63,7 @@ Enterprise operations need: - Traffic shaping and priority queues - Gradual rollouts and canary deployments -**Pure P2P limitation**: Distributed control plane, harder to monitor and react. - -**Gateway solution**: Single control plane with operational levers. 
- -### When to Use Gateway Pattern - -**Use Gateway if**: -- ✅ Regulated industry (finance, healthcare, government) -- ✅ Handling sensitive data (PII, PHI, PCI) -- ✅ SOC2/ISO27001/HIPAA compliance required -- ✅ Multi-tenant with strict isolation -- ✅ Need complete audit trail for legal/regulatory -- ✅ Security team requires zero-trust architecture - -**Use Pure P2P if**: -- ✅ Internal tools, trusted environment -- ✅ Performance is critical (high-frequency trading, gaming) -- ✅ No regulatory requirements -- ✅ Startup/rapid iteration phase - -### Hybrid Approach - -For organizations transitioning or with mixed requirements: -- **P2P for internal agents** (trusted, high-performance) -- **Gateway for external agents** (untrusted, audited) -- **Gateway for sensitive operations** (payment, PHI access) - -## Architecture Comparison - -### Pure P2P Architecture (Current) - -``` -┌─────────┐ ┌─────────┐ -│ Agent A │────────────────────────────────────│ Agent B │ -└─────────┘ Redis Pub/Sub └─────────┘ - │ (direct channel) │ - │ │ - └───────────────────┐ ┌───────────────────┘ - ↓ ↓ - ┌────────────────┐ - │ MAS Service │ - │ (optional) │ - │ - Registry │ - │ - Discovery │ - │ - Health │ - └────────────────┘ -``` - -**Characteristics**: -- Direct agent-to-agent communication -- No message inspection or validation -- Optional async audit logging -- High throughput, low latency -- Client-side security enforcement - -### Gateway Architecture (Enterprise) - -``` -┌─────────┐ ┌─────────┐ -│ Agent A │ │ Agent B │ -└────┬────┘ └────▲────┘ - │ │ - │ 1. Send message │ 5. Consume - │ (with token) │ (with ACK) - ↓ │ -┌─────────────────────────────────────────────────────────┐ -│ Gateway Service │ -│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ -│ │ Auth/Authz │→ │ DLP Scanner │→ │ Rate Limiter │ │ -│ └──────────────┘ └──────────────┘ └──────────────┘ │ -│ ↓ │ -│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ -│ │ Audit Logger │ │ Circuit │ │ Priority │ │ -│ │ (Streams) │ │ Breaker │ │ Queue │ │ -│ └──────────────┘ └──────────────┘ └──────────────┘ │ -└──────────────────────────┬──────────────────────────────┘ - │ 2. Validate - │ 3. Audit log - │ 4. Publish - ↓ - ┌──────────────┐ - │ Redis Streams│ - │ agent.stream:│ - │ {target_id}│ - └──────────────┘ -``` - -**Characteristics**: -- Centralized validation and control -- Complete message inspection and audit -- Server-side security enforcement -- Reliable delivery (at-least-once) -- Operational control levers - -### Hybrid Architecture (Recommended for Migration) - -``` - ┌──────────────────┐ - │ Gateway Service │ - │ (with feature │ - │ flags) │ - └────────┬─────────┘ - │ - ┌──────────────┼──────────────┐ - │ │ │ - External/ Internal/ Sensitive - Untrusted Trusted Operations - │ │ │ - ↓ ↓ ↓ - ┌────────────┐ ┌────────────┐ ┌────────────┐ - │ Gateway │ │ Pure P2P │ │ Gateway │ - │ (full) │ │ (fast) │ │ (audit) │ - └────────────┘ └────────────┘ └────────────┘ -``` - -**Route by**: -- Agent trust level (internal vs external) -- Message sensitivity (public vs PHI/PCI) -- Operation type (read vs write) -- Compliance requirements (audit vs no-audit) +Single control plane with operational levers. ## Core Components @@ -1252,102 +1115,6 @@ spec: - Capacity (CPU, memory, storage) - Business KPIs (message volume, active agents) -## Migration Strategy - -### Phase 1: Preparation (Month 1-2) - -**Objectives**: -- Build gateway service -- Implement core features (auth, audit) -- Deploy in shadow mode (observe, don't block) - -**Tasks**: -1. 
Develop gateway service (auth, authz, audit modules) -2. Set up infrastructure (K8s, Redis, monitoring) -3. Deploy gateway in "observe" mode -4. Agents send to both P2P and gateway -5. Compare behavior, tune configurations -6. Validate audit log completeness - -**Success Criteria**: -- Gateway handles 100% of message volume in shadow mode -- <5% error rate in gateway -- Audit logs match P2P messages (99%+ coverage) - -### Phase 2: Soft Launch (Month 3-4) - -**Objectives**: -- Route non-critical traffic through gateway -- Enable enforcement for test agents -- Validate security features - -**Tasks**: -1. Select pilot agents (internal, low-risk) -2. Route pilot traffic through gateway (enforce mode) -3. Enable DLP, rate limiting, circuit breakers -4. Monitor for issues (latency, errors) -5. Collect feedback from pilot users -6. Tune configurations based on feedback - -**Success Criteria**: -- Pilot agents operate normally through gateway -- P95 latency <30ms -- No false-positive DLP violations -- Zero incidents from pilot agents - -### Phase 3: Gradual Rollout (Month 5-6) - -**Objectives**: -- Migrate all agents to gateway -- Deprecate direct P2P -- Achieve full enforcement - -**Rollout Strategy**: -- Week 1: 10% of agents -- Week 2: 25% of agents -- Week 3: 50% of agents -- Week 4: 75% of agents -- Week 5: 90% of agents -- Week 6: 100% of agents - -**Feature Flags**: -- Gateway routing (per-agent toggle) -- Enforcement mode (observe vs enforce) -- DLP scanning (enabled/disabled) -- Rate limiting (thresholds per agent) - -**Rollback Plan**: -- If error rate > 1%: Pause rollout -- If P99 latency > 100ms: Rollback 50% -- If critical incident: Full rollback to P2P - -**Success Criteria**: -- 100% of agents using gateway -- P2P channels deprecated -- Audit coverage 100% -- SLA met (99.9% uptime) - -### Phase 4: Optimization (Month 7+) - -**Objectives**: -- Optimize performance -- Add advanced features -- Achieve compliance certifications - -**Tasks**: -1. Performance tuning (caching, batching) -2. Add RBAC support (beyond ACL) -3. Implement message signing -4. Add anomaly detection -5. Obtain SOC2 audit -6. Document for compliance - -**Success Criteria**: -- P95 latency <20ms -- Throughput 10,000+ msg/s (single gateway) -- SOC2 Type II certified -- HIPAA/GDPR compliance documented - ## Implementation Roadmap ### MVP (Minimum Viable Product) - 3 Months @@ -1424,11 +1191,11 @@ The Gateway Pattern transforms the MAS Framework into an enterprise-ready platfo 4. **Control**: Rate limiting, priority queues, monitoring 5. 
**Scalability**: Horizontal scaling, proven at scale -**Trade-offs**: -- 2-4x latency increase (5ms → 10-20ms) -- 50-70% throughput reduction (single gateway) -- Increased infrastructure cost (3-4x) -- Additional operational complexity +**Operational Considerations**: +- Added latency from validation and audit processing +- Throughput depends on gateway scaling and Redis capacity +- Increased infrastructure footprint for security controls +- Additional operational complexity (monitoring, policies, keys) **When to Adopt**: - Regulated industries requiring compliance @@ -1437,12 +1204,5 @@ The Gateway Pattern transforms the MAS Framework into an enterprise-ready platfo - Organizations prioritizing security over raw performance **Recommended Approach**: -Start with Pure P2P for rapid iteration, migrate to Gateway when: -- Product-market fit achieved -- Enterprise customers require compliance -- Security incidents motivate investment -- Traffic volume justifies infrastructure cost - -The Gateway Pattern is not a replacement for P2P—it's an evolution for enterprise requirements. - - +Adopt the gateway architecture from the start to ensure consistent security, +auditability, and operational control across all environments. diff --git a/README.md b/README.md index d7d58cb..3f447e6 100644 --- a/README.md +++ b/README.md @@ -143,18 +143,15 @@ if __name__ == "__main__": Messages are routed through a centralized gateway for security and compliance: -Messages routed through centralized gateway for security and compliance: - ```python import asyncio from mas import Agent async def main(): - # Enable gateway mode - agent = Agent("my_agent", use_gateway=True) + agent = Agent("my_agent") await agent.start() - # Messages now routed through gateway + # Messages are routed through the gateway # Gateway provides: auth, authz, rate limiting, DLP, audit await agent.send("target_agent", "test.message", {"data": "hello"}) @@ -176,8 +173,6 @@ if __name__ == "__main__": See [GATEWAY.md](GATEWAY.md) for complete gateway documentation. -Gateway mode is the default and recommended for production deployments. - ## Features ### Auto-Persisted State @@ -484,20 +479,35 @@ Performance benchmarks are planned for future releases. - **[Architecture Guide](ARCHITECTURE.md)** - Architecture, design decisions, and implementation details - **[Gateway Guide](GATEWAY.md)** - Enterprise gateway pattern with security, audit, and compliance features -- **[API Reference](#messaging-modes)** - Feature documentation and usage examples +- **[API Reference](#messaging)** - Feature documentation and usage examples + +## Getting Started (Recommended) + +The definitive way to run MAS is the built-in runner. It starts the MAS service +and gateway, then brings up all configured agents automatically. + +1) Define agents in `agents.yaml` (required) +2) Run MAS + +```bash +uv run python -m mas +``` + +The runner searches upward from the current working directory to find +the nearest `agents.yaml`. + +The runner starts the MAS service and gateway before agents and stops them last. 
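+
+As a sketch, the module that a `class_path` entry in `agents.yaml` points at
+might look like this (the module path, class name, and `batch_size` keyword
+are illustrative; the runner is assumed to pass `agent_id` plus the entry's
+`init_kwargs` to the constructor):
+
+```python
+# my_app/agents/worker.py (hypothetical module)
+from mas import Agent
+
+
+class WorkerAgent(Agent):
+    def __init__(self, agent_id: str, batch_size: int = 10, **kwargs):
+        # capabilities/redis_url from init_kwargs pass through to Agent
+        super().__init__(agent_id, **kwargs)
+        self.batch_size = batch_size
+
+    async def on_start(self) -> None:
+        # user-overridable startup hook; runs after transport is ready
+        ...
+```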
### Quick Architecture Overview **Core Components:** -- **MAS Service** - Agent registry and health monitor (optional) +- **MAS Service** - Agent registry and health monitor - **Agent** - Base class for implementing agents -- **Gateway Service** - Optional security/audit layer for enterprise deployments +- **Gateway Service** - Central security and audit layer - **Registry** - Agent discovery by capabilities - **State Manager** - State persistence to Redis **Message Flow:** - -Gateway Mode: ``` Agent A → Gateway Service → Redis Streams → Agent B (validation) (reliable delivery) @@ -507,7 +517,7 @@ Agent A → Gateway Service → Redis Streams → Agent B - `agent:{id}` - Agent metadata - `agent:{id}:heartbeat` - Health monitoring (60s TTL) - `agent.state:{id}` - Persisted agent state -- `agent.stream:{id}` - Message stream (gateway mode) +- `agent.stream:{id}` - Message stream - `mas.system` - System events (pub/sub) For detailed architecture information, see [ARCHITECTURE.md](ARCHITECTURE.md). @@ -527,9 +537,9 @@ For detailed architecture information, see [ARCHITECTURE.md](ARCHITECTURE.md). - ✅ Heartbeat monitoring ### In Development -- [ ] Priority queue for gateway mode +- [ ] Priority queue for delivery - [ ] Enhanced metrics and observability -- [ ] Performance benchmarks for both modes +- [ ] Performance benchmarks - [ ] Prometheus metrics integration - [ ] Management dashboard diff --git a/agents.yaml.example b/agents.yaml.example new file mode 100644 index 0000000..80ad786 --- /dev/null +++ b/agents.yaml.example @@ -0,0 +1,13 @@ +agents: + - agent_id: worker_agent + class_path: my_app.agents.worker:WorkerAgent + instances: 3 + init_kwargs: + capabilities: + - worker + redis_url: redis://localhost:6379 + batch_size: 10 +start_service: true +service_redis_url: redis://localhost:6379 +start_gateway: true +gateway_config_file: gateway.yaml diff --git a/src/mas/__init__.py b/src/mas/__init__.py index aae51cb..09998a8 100644 --- a/src/mas/__init__.py +++ b/src/mas/__init__.py @@ -6,6 +6,7 @@ from .state import StateManager, StateType from .protocol import EnvelopeMessage as Message, MessageType, MessageMeta from .__version__ import __version__ +from .runner import AgentRunner, RunnerSettings, load_runner_settings __all__ = [ "Agent", @@ -18,5 +19,8 @@ "Message", "MessageType", "MessageMeta", + "AgentRunner", + "RunnerSettings", + "load_runner_settings", "__version__", ] diff --git a/src/mas/__main__.py b/src/mas/__main__.py index e69de29..679772a 100644 --- a/src/mas/__main__.py +++ b/src/mas/__main__.py @@ -0,0 +1,6 @@ +"""MAS entrypoint.""" + +from .runner import run + +if __name__ == "__main__": + run() diff --git a/src/mas/agent.py b/src/mas/agent.py index 8cc2c09..8f2a145 100644 --- a/src/mas/agent.py +++ b/src/mas/agent.py @@ -7,7 +7,6 @@ import hmac import json import logging -import os import time import uuid from dataclasses import dataclass @@ -51,10 +50,18 @@ class Agent(Generic[StateType]): Key features: - Self-contained (only needs Redis URL) - Gateway-mediated messaging (central routing and policy enforcement) - - Auto-persisted state to Redis + - Auto-persisted state to Redis (shared across instances) - Simple discovery by capabilities - - Automatic heartbeat monitoring + - Automatic heartbeat monitoring (per-instance) - Strongly-typed state via generics + - Multi-instance support for horizontal scaling + + Multi-instance behavior: + - Each Agent instance gets a unique instance_id + - Multiple instances with the same agent_id share the workload + - Messages are load-balanced 
across instances via Redis consumer groups + - Request-response replies are routed to the originating instance + - State is shared across all instances of the same agent_id Usage with typed state and decorator-based handlers: class MyState(BaseModel): @@ -69,6 +76,13 @@ async def handle_increment(self, message: AgentMessage, payload: None): # self.state is strongly typed as MyState self.state.counter += 1 await self.update_state({"counter": self.state.counter}) + + Horizontal scaling: + # Run multiple instances of the same agent for parallel processing + # Each instance automatically joins the consumer group + python my_agent.py & # Instance 1 + python my_agent.py & # Instance 2 + python my_agent.py & # Instance 3 """ def __init__( @@ -77,26 +91,32 @@ def __init__( capabilities: list[str] | None = None, redis_url: str = "redis://localhost:6379", state_model: type[StateType] | None = None, - use_gateway: bool = False, - gateway_url: Optional[str] = None, + reclaim_idle_ms: int = 30_000, + reclaim_batch_size: int = 50, ): """ Initialize agent. Args: - agent_id: Unique agent identifier + agent_id: Logical agent identifier (shared across instances for scaling) capabilities: List of agent capabilities for discovery redis_url: Redis connection URL state_model: Optional Pydantic model for typed state. If provided, self.state will be strongly typed. - use_gateway: Whether to route messages through gateway - gateway_url: Gateway service URL (if different from redis_url) + reclaim_idle_ms: Idle time (ms) before reclaiming pending messages. + reclaim_batch_size: Max pending messages reclaimed per cycle. + + Note: + Each Agent instance automatically generates a unique instance_id. + Multiple instances with the same agent_id will share workload via + Redis consumer groups. 
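+
+            Example (illustrative IDs): two instances sharing one logical agent
+
+                a1 = Agent("worker_agent", capabilities=["worker"])
+                a2 = Agent("worker_agent", capabilities=["worker"])
+                # a1.instance_id != a2.instance_id; both consume from the
+                # shared stream agent.stream:worker_agent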
""" self.id = agent_id + self.instance_id = uuid.uuid4().hex[:8] # Unique per instance self.capabilities = capabilities or [] self.redis_url = redis_url - self.use_gateway = use_gateway - self.gateway_url = gateway_url or redis_url + self._reclaim_idle_ms = reclaim_idle_ms + self._reclaim_batch_size = reclaim_batch_size # Internal state self._redis: Optional[AsyncRedisProtocol] = None @@ -112,7 +132,7 @@ def __init__( self._state_manager: Optional[StateManager[StateType]] = None self._state_model: type[StateType] | None = state_model - # Gateway client (if use_gateway=True) + # Gateway client (optional, managed externally) self._gateway: Optional["GatewayService"] = None # Request-response tracking @@ -139,15 +159,17 @@ def token(self) -> Optional[str]: return self._token async def start(self) -> None: - """Start the agent.""" + """Start the agent instance.""" redis_client = create_redis_client(url=self.redis_url, decode_responses=True) self._redis = redis_client self._registry = AgentRegistry(redis_client) - # Register agent + # Register agent instance (idempotent - first instance creates, others join) self._token = await self._registry.register( - self.id, self.capabilities, metadata=self.get_metadata() + self.id, self.instance_id, self.capabilities, metadata=self.get_metadata() ) + # Seed heartbeat immediately to avoid inactive status on fast restarts + await self._registry.update_heartbeat(self.id, self.instance_id) # Initialize state manager self._state_manager = StateManager( @@ -156,6 +178,7 @@ async def start(self) -> None: await self._state_manager.load() # Ensure delivery stream consumer group exists and start consumer loop + # Shared stream for load-balanced message delivery delivery_stream = f"agent.stream:{self.id}" try: await redis_client.xgroup_create( @@ -165,25 +188,56 @@ async def start(self) -> None: if "BUSYGROUP" not in str(e): raise + # Instance-specific stream for replies to this instance's requests + # This ensures request-response works correctly with multi-instance agents + # Uses the same consumer group name as shared stream for xreadgroup compatibility + reply_stream = f"agent.stream:{self.id}:{self.instance_id}" + try: + await redis_client.xgroup_create( + reply_stream, "agents", id="$", mkstream=True + ) + except Exception as e: + if "BUSYGROUP" not in str(e): + raise + self._running = True # Start background tasks self._tasks.append(asyncio.create_task(self._stream_loop())) self._tasks.append(asyncio.create_task(self._heartbeat_loop())) - # Publish registration event + # Publish instance join event + instance_count = await self._registry.get_instance_count(self.id) + if instance_count == 1: + # First instance - publish REGISTER event + await redis_client.publish( + "mas.system", + json.dumps( + { + "type": "REGISTER", + "agent_id": self.id, + "capabilities": self.capabilities, + } + ), + ) + + # Always publish INSTANCE_JOIN for observability await redis_client.publish( "mas.system", json.dumps( { - "type": "REGISTER", + "type": "INSTANCE_JOIN", "agent_id": self.id, - "capabilities": self.capabilities, + "instance_id": self.instance_id, + "instance_count": instance_count, } ), ) - logger.info("Agent started", extra={"agent_id": self.id}) + logger.info( + "Agent instance started", + extra={"agent_id": self.id, "instance_id": self.instance_id}, + ) # Signal that transport can begin (registration + subscriptions established) self._transport_ready.set() @@ -192,35 +246,50 @@ async def start(self) -> None: await self.on_start() async def stop(self) -> None: - 
"""Stop the agent.""" + """Stop the agent instance.""" self._running = False # Call user hook await self.on_stop() - # Publish deregistration event - if self._redis: - await self._redis.publish( - "mas.system", - json.dumps( - { - "type": "DEREGISTER", - "agent_id": self.id, - } - ), - ) - # Cancel tasks for task in self._tasks: task.cancel() await asyncio.gather(*self._tasks, return_exceptions=True) - # Cleanup + # Deregister instance and check if last if self._registry: - await self._registry.deregister(self.id) + # Get instance count before deregistering + instance_count_before = await self._registry.get_instance_count(self.id) + await self._registry.deregister(self.id, self.instance_id) + is_last_instance = instance_count_before <= 1 + + # Publish events + if self._redis: + # Always publish INSTANCE_LEAVE for observability + await self._redis.publish( + "mas.system", + json.dumps( + { + "type": "INSTANCE_LEAVE", + "agent_id": self.id, + "instance_id": self.instance_id, + } + ), + ) - # No pubsub to close in streams mode + if is_last_instance: + # Last instance - publish DEREGISTER event + await self._redis.publish( + "mas.system", + json.dumps( + { + "type": "DEREGISTER", + "agent_id": self.id, + } + ), + ) # Note: Don't stop gateway - it's shared across agents # Gateway lifecycle is managed externally @@ -228,7 +297,10 @@ async def stop(self) -> None: if self._redis: await self._redis.aclose() - logger.info("Agent stopped", extra={"agent_id": self.id}) + logger.info( + "Agent instance stopped", + extra={"agent_id": self.id, "instance_id": self.instance_id}, + ) def set_gateway(self, gateway: "GatewayService") -> None: """ @@ -271,25 +343,36 @@ async def _send_envelope(self, message: AgentMessage) -> None: raise RuntimeError("Agent not started") if not self._token: raise RuntimeError("No token available for gateway authentication") - signing_key = os.getenv("SIGNING_KEY") - ts = str(int(time.time())) - nonce = str(uuid.uuid4()) + signing_key = await self._get_signing_key() + if not signing_key: + raise RuntimeError( + "No signing key available for agent; cannot sign message" + ) + ts = time.time() + nonce = uuid.uuid4().hex envelope_json = message.model_dump_json() fields: dict[str, str] = { "envelope": envelope_json, "agent_id": self.id, "token": self._token, + "timestamp": str(ts), + "nonce": nonce, + } + + signature_data = { + "envelope": envelope_json, "timestamp": ts, "nonce": nonce, } - if signing_key: - mac = hmac.new( - signing_key.encode(), - f"{envelope_json}.{ts}.{nonce}".encode(), - hashlib.sha256, - ).hexdigest() - fields["signature"] = mac - fields["alg"] = "HMAC-SHA256" + canonical_str = self._canonicalize(signature_data) + key_bytes = bytes.fromhex(signing_key) + mac = hmac.new( + key_bytes, + canonical_str.encode("utf-8"), + hashlib.sha256, + ).hexdigest() + fields["signature"] = mac + fields["alg"] = "HMAC-SHA256" await self._redis.xadd("mas.gateway.ingress", fields) logger.debug( "Message enqueued to gateway ingress", @@ -362,7 +445,10 @@ async def handle_diagnosis(self, msg: AgentMessage, payload: Mapping[str, Any]): message_type=message_type, data=payload, meta=MessageMeta( - correlation_id=correlation_id, expects_reply=True, is_reply=False + correlation_id=correlation_id, + expects_reply=True, + is_reply=False, + sender_instance_id=self.instance_id, ), ) await self._send_envelope(message) @@ -438,6 +524,14 @@ async def discover( return await self._registry.discover(capabilities) + def _canonicalize(self, data: Mapping[str, Any]) -> str: + return json.dumps(data, 
sort_keys=True, separators=(",", ":")) + + async def _get_signing_key(self) -> Optional[str]: + if not self._redis: + return None + return await self._redis.get(f"agent:{self.id}:signing_key") + async def wait_transport_ready(self, timeout: float | None = None) -> None: """ Wait until the framework signals that transport can begin. @@ -470,59 +564,166 @@ async def reset_state(self) -> None: await self._state_manager.reset() async def _stream_loop(self) -> None: - """Consume incoming messages from the agent's delivery stream.""" + """ + Consume incoming messages from both delivery streams. + + Listens on two streams: + 1. Shared stream (agent.stream:{id}) - load-balanced across all instances + 2. Instance stream (agent.stream:{id}:{instance_id}) - replies to this instance + """ if not self._redis: return - stream = f"agent.stream:{self.id}" - group = "agents" - consumer = f"{self.id}-1" + + # Shared stream for load-balanced messages + shared_stream = f"agent.stream:{self.id}" + shared_group = "agents" + shared_consumer = f"{self.id}-{self.instance_id}" + + # Instance-specific stream for replies + # Uses same group name "agents" for xreadgroup compatibility + instance_stream = f"agent.stream:{self.id}:{self.instance_id}" + try: + claim_start_ids = { + shared_stream: "0-0", + instance_stream: "0-0", + } + last_reclaim = 0.0 + reclaim_interval = max(1.0, self._reclaim_idle_ms / 1000.0) + while self._running: + now = time.time() + if now - last_reclaim >= reclaim_interval: + for stream_name in (shared_stream, instance_stream): + claim_start_ids[stream_name] = await self._reclaim_pending( + stream_name, + shared_group, + shared_consumer, + claim_start_ids[stream_name], + ) + last_reclaim = now + + # Read from both streams simultaneously using same consumer group items = await self._redis.xreadgroup( - group, - consumer, - streams={stream: ">"}, + shared_group, + shared_consumer, + streams={shared_stream: ">", instance_stream: ">"}, count=50, block=1000, ) if not items: continue - for _, messages in items: + + for stream_name, messages in items: for entry_id, fields in messages: - try: - data_json = fields.get("envelope", "") - if not data_json: - await self._redis.xack(stream, group, entry_id) - continue - msg = AgentMessage.model_validate_json(data_json) - msg.attach_agent(self) - - # Replies resolve pending requests - if msg.meta.is_reply: - correlation_id = msg.meta.correlation_id - if ( - correlation_id - and correlation_id in self._pending_requests - ): - future = self._pending_requests.pop(correlation_id) - if not future.done(): - future.set_result(msg) - await self._redis.xack(stream, group, entry_id) - continue - - asyncio.create_task( - self._handle_message_with_error_handling(msg) - ) - await self._redis.xack(stream, group, entry_id) - except Exception as e: - logger.error( - "Failed to process stream message", - exc_info=e, - extra={"agent_id": self.id}, - ) + await self._process_stream_entry( + stream_name, + shared_group, + entry_id, + fields, + ) except asyncio.CancelledError: pass + async def _handle_message_and_ack( + self, + msg: AgentMessage, + stream_name: str, + group: str, + entry_id: str, + ) -> None: + try: + await self._handle_message_with_error_handling(msg) + finally: + if self._redis: + try: + await self._redis.xack(stream_name, group, entry_id) + except Exception: + pass + + async def _process_stream_entry( + self, + stream_name: str, + group: str, + entry_id: str, + fields: Mapping[str, str], + ) -> None: + if not self._redis: + return + try: + data_json = 
fields.get("envelope", "") + if not data_json: + await self._redis.xack(stream_name, group, entry_id) + return + + msg = AgentMessage.model_validate_json(data_json) + msg.attach_agent(self) + + # Replies resolve pending requests + if msg.meta.is_reply: + correlation_id = msg.meta.correlation_id + if correlation_id and correlation_id in self._pending_requests: + future = self._pending_requests.pop(correlation_id) + if not future.done(): + future.set_result(msg) + await self._redis.xack(stream_name, group, entry_id) + return + + asyncio.create_task( + self._handle_message_and_ack( + msg, + stream_name, + group, + entry_id, + ) + ) + except Exception as e: + logger.error( + "Failed to process stream message", + exc_info=e, + extra={ + "agent_id": self.id, + "instance_id": self.instance_id, + "stream": stream_name, + }, + ) + + async def _reclaim_pending( + self, + stream_name: str, + group: str, + consumer: str, + start_id: str, + ) -> str: + if not self._redis or self._reclaim_idle_ms <= 0: + return start_id + + try: + next_start_id, messages, _deleted_ids = await self._redis.xautoclaim( + stream_name, + group, + consumer, + self._reclaim_idle_ms, + start_id, + count=self._reclaim_batch_size, + ) + except Exception as e: + logger.error( + "Failed to reclaim pending messages", + exc_info=e, + extra={ + "agent_id": self.id, + "instance_id": self.instance_id, + "stream": stream_name, + }, + ) + return start_id + + for entry_id, fields in messages: + await self._process_stream_entry(stream_name, group, entry_id, fields) + + return next_start_id + async def _handle_message_with_error_handling(self, msg: AgentMessage) -> None: """ Handle a message with error handling. @@ -545,16 +746,20 @@ async def _handle_message_with_error_handling(self, msg: AgentMessage) -> None: ) async def _heartbeat_loop(self) -> None: - """Send periodic heartbeats.""" + """Send periodic heartbeats for this instance.""" try: while self._running: if self._registry: - await self._registry.update_heartbeat(self.id) + await self._registry.update_heartbeat(self.id, self.instance_id) await asyncio.sleep(30) # Heartbeat every 30 seconds except asyncio.CancelledError: pass except Exception as e: - logger.error("Heartbeat failed", exc_info=e, extra={"agent_id": self.id}) + logger.error( + "Heartbeat failed", + exc_info=e, + extra={"agent_id": self.id, "instance_id": self.instance_id}, + ) # User-overridable hooks @dataclass(frozen=True, slots=True) @@ -652,6 +857,10 @@ async def send_reply_envelope( ) -> None: """ Send a correlated reply to the original message. + + If the original message has a sender_instance_id, the reply is routed + directly to that instance to ensure request-response works correctly + with multi-instance agents. 
""" if not original.meta.correlation_id: raise RuntimeError("Original message missing correlation_id") @@ -666,6 +875,8 @@ async def send_reply_envelope( correlation_id=original.meta.correlation_id, expects_reply=False, is_reply=True, + # Preserve the original sender's instance ID for routing + sender_instance_id=original.meta.sender_instance_id, ), ) await self._send_envelope(reply) diff --git a/src/mas/gateway/audit.py b/src/mas/gateway/audit.py index 5d4bdd0..1eb6a00 100644 --- a/src/mas/gateway/audit.py +++ b/src/mas/gateway/audit.py @@ -7,9 +7,11 @@ import logging import time from typing import Any, Optional -from ..redis_types import AsyncRedisProtocol + from pydantic import BaseModel, Field +from ..redis_types import AsyncRedisProtocol + logger = logging.getLogger(__name__) AuditRecord = dict[str, Any] @@ -70,6 +72,9 @@ async def log_message( """ Log message to audit stream. + Uses pipeline to batch all Redis operations into a single round-trip, + reducing calls from 5 to 1. + Args: message_id: Unique message identifier sender_id: Sender agent ID @@ -86,7 +91,7 @@ async def log_message( payload_str = json.dumps(payload, sort_keys=True) payload_hash = hashlib.sha256(payload_str.encode()).hexdigest() - # Get previous hash for chain + # Get previous hash for chain (need this before pipeline) previous_hash = await self.redis.get("audit:last_hash") # Create audit entry @@ -112,23 +117,23 @@ async def log_message( # Remove None values (Redis doesn't accept them) entry_dict = {k: v for k, v in entry_dict.items() if v is not None} - # Write to main audit stream - # Redis expects encodable field values; coerce all to strings for consistency - fields_main: dict[str, str] = {k: str(v) for k, v in entry_dict.items()} - main_stream_id = await self.redis.xadd("audit:messages", fields_main) + # Coerce all to strings for Redis + fields: dict[str, str] = {k: str(v) for k, v in entry_dict.items()} - # Index by sender + # Batch all writes using pipeline (reduces 4 calls to 1) sender_stream = f"audit:by_sender:{sender_id}" - fields_sender: dict[str, str] = {k: str(v) for k, v in entry_dict.items()} - await self.redis.xadd(sender_stream, fields_sender) - - # Index by target target_stream = f"audit:by_target:{target_id}" - fields_target: dict[str, str] = {k: str(v) for k, v in entry_dict.items()} - await self.redis.xadd(target_stream, fields_target) - # Update hash chain - await self.redis.set("audit:last_hash", entry_hash) + pipe = self.redis.pipeline() + pipe.xadd("audit:messages", fields) + pipe.xadd(sender_stream, fields) + pipe.xadd(target_stream, fields) + pipe.set("audit:last_hash", entry_hash) + + results = await pipe.execute() + + # First result is the main stream ID + main_stream_id = results[0] logger.debug( "Audit entry logged", diff --git a/src/mas/gateway/circuit_breaker.py b/src/mas/gateway/circuit_breaker.py index cac35cd..0186078 100644 --- a/src/mas/gateway/circuit_breaker.py +++ b/src/mas/gateway/circuit_breaker.py @@ -3,12 +3,11 @@ import logging import time from enum import Enum -from typing import Any, Mapping, Optional +from typing import Any, Mapping, Optional, Tuple from pydantic import BaseModel from ..redis_types import AsyncRedisProtocol - from .metrics import MetricsCollector logger = logging.getLogger(__name__) @@ -202,6 +201,219 @@ async def record_success(self, target_id: str) -> CircuitStatus: allowed=state != CircuitState.OPEN, ) + async def check_and_record_success( + self, target_id: str + ) -> Tuple[CircuitStatus, CircuitStatus]: + """ + Check circuit and record 
success in one operation. + + This method combines check_circuit and record_success to avoid + the double-fetch pattern where both methods read the same data. + Uses a single hgetall call instead of two. + + Args: + target_id: Target agent ID + + Returns: + Tuple of (check_status, record_status) + """ + circuit_key = f"circuit:{target_id}" + circuit_data = self._normalize_hash(await self.redis.hgetall(circuit_key)) + + if not circuit_data: + # No circuit data, default to CLOSED - no need to record + status = CircuitStatus( + state=CircuitState.CLOSED, + failure_count=0, + success_count=0, + allowed=True, + ) + return status, status + + state = CircuitState(circuit_data.get("state", CircuitState.CLOSED.value)) + failure_count = int(circuit_data.get("failure_count", 0)) + success_count = int(circuit_data.get("success_count", 0)) + last_failure_time = ( + float(circuit_data["last_failure_time"]) + if "last_failure_time" in circuit_data + else None + ) + opened_at = ( + float(circuit_data["opened_at"]) + if "opened_at" in circuit_data and circuit_data["opened_at"] + else None + ) + + current_time = time.time() + + # State transitions for check + if state == CircuitState.OPEN: + if opened_at and (current_time - opened_at) >= self.config.timeout_seconds: + state = CircuitState.HALF_OPEN + success_count = 0 + MetricsCollector.record_circuit_breaker_trip(target_id, "HALF_OPEN") + + allowed = state == CircuitState.CLOSED or state == CircuitState.HALF_OPEN + + check_status = CircuitStatus( + state=state, + failure_count=failure_count, + success_count=success_count, + last_failure_time=last_failure_time, + opened_at=opened_at, + allowed=allowed, + ) + + # Now record success if allowed + if not allowed: + return check_status, check_status + + if state == CircuitState.HALF_OPEN: + success_count += 1 + if success_count >= self.config.success_threshold: + state = CircuitState.CLOSED + failure_count = 0 + success_count = 0 + MetricsCollector.record_circuit_breaker_trip(target_id, "CLOSED") + await self._update_state(target_id, state, failure_count, success_count) + elif state == CircuitState.CLOSED and failure_count > 0: + failure_count = 0 + await self._update_state(target_id, state, failure_count, success_count) + + record_status = CircuitStatus( + state=state, + failure_count=failure_count, + success_count=success_count, + allowed=state != CircuitState.OPEN, + ) + + return check_status, record_status + + async def check_and_record_failure( + self, target_id: str, reason: str = "unknown" + ) -> Tuple[CircuitStatus, CircuitStatus]: + """ + Check circuit and record failure in one operation. + + This method combines check_circuit and record_failure to avoid + the double-fetch pattern where both methods read the same data. + Uses a single hgetall call instead of two. 
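+
+        Example (illustrative; `manager` is an instance of this class):
+
+            check, record = await manager.check_and_record_failure(
+                "target_agent", reason="timeout"
+            )
+            if not check.allowed:
+                ...  # circuit was already open when this failure arrived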
+ + Args: + target_id: Target agent ID + reason: Failure reason + + Returns: + Tuple of (check_status, record_status) + """ + circuit_key = f"circuit:{target_id}" + circuit_data = self._normalize_hash(await self.redis.hgetall(circuit_key)) + + current_time = time.time() + opened_at: Optional[float] = None + + if not circuit_data: + state = CircuitState.CLOSED + failure_count = 1 + success_count = 0 + last_failure_time = None + else: + state = CircuitState(circuit_data.get("state", CircuitState.CLOSED.value)) + failure_count = int(circuit_data.get("failure_count", 0)) + success_count = int(circuit_data.get("success_count", 0)) + last_failure_time = ( + float(circuit_data["last_failure_time"]) + if "last_failure_time" in circuit_data + else None + ) + opened_at = ( + float(circuit_data["opened_at"]) + if "opened_at" in circuit_data and circuit_data["opened_at"] + else None + ) + + # State transition for check (OPEN -> HALF_OPEN after timeout) + if state == CircuitState.OPEN: + if ( + opened_at + and (current_time - opened_at) >= self.config.timeout_seconds + ): + state = CircuitState.HALF_OPEN + success_count = 0 + MetricsCollector.record_circuit_breaker_trip(target_id, "HALF_OPEN") + + allowed = state == CircuitState.CLOSED or state == CircuitState.HALF_OPEN + + check_status = CircuitStatus( + state=state, + failure_count=failure_count, + success_count=success_count, + last_failure_time=last_failure_time, + opened_at=opened_at, + allowed=allowed, + ) + + # Now record failure + if circuit_data: + # Check if failures are within window + if last_failure_time and ( + current_time - last_failure_time > self.config.window_seconds + ): + failure_count = 1 + else: + failure_count += 1 + # else: failure_count already set to 1 above + + # Check if we should open circuit + if ( + state == CircuitState.CLOSED + and failure_count >= self.config.failure_threshold + ): + state = CircuitState.OPEN + opened_at = current_time + MetricsCollector.record_circuit_breaker_trip(target_id, "OPEN") + logger.warning( + f"Circuit breaker OPEN for {target_id} after {failure_count} failures", + extra={ + "target_id": target_id, + "state": state, + "failure_count": failure_count, + "reason": reason, + }, + ) + elif state == CircuitState.HALF_OPEN: + state = CircuitState.OPEN + opened_at = current_time + MetricsCollector.record_circuit_breaker_trip(target_id, "OPEN") + logger.warning( + f"Circuit breaker back to OPEN for {target_id} (half-open test failed)", + extra={"target_id": target_id, "state": state, "reason": reason}, + ) + + # Update circuit state + await self.redis.hset( + circuit_key, + mapping={ + "state": state.value, + "failure_count": str(failure_count), + "success_count": str(success_count), + "last_failure_time": str(current_time), + "opened_at": str(opened_at) if opened_at else "", + }, + ) + await self.redis.expire(circuit_key, int(self.config.timeout_seconds * 2)) + + record_status = CircuitStatus( + state=state, + failure_count=failure_count, + success_count=success_count, + last_failure_time=current_time, + opened_at=opened_at, + allowed=state != CircuitState.OPEN, + ) + + return check_status, record_status + async def record_failure( self, target_id: str, reason: str = "unknown" ) -> CircuitStatus: diff --git a/src/mas/gateway/config.py b/src/mas/gateway/config.py index 61af454..485635f 100644 --- a/src/mas/gateway/config.py +++ b/src/mas/gateway/config.py @@ -99,9 +99,9 @@ class FeaturesSettings(BaseSettings): """ Feature flags configuration. 
- Production-ready defaults (all security features enabled): + Production-ready defaults (security features enabled, queues opt-in): - DLP: True (data loss prevention) - - Priority Queue: True (message prioritization) + - Priority Queue: False (message prioritization is opt-in) - RBAC: True (role-based access control) - Message Signing: True (integrity verification) - Circuit Breaker: True (reliability) @@ -113,7 +113,7 @@ class FeaturesSettings(BaseSettings): """ dlp: bool = Field(default=True, description="Enable DLP scanning") - priority_queue: bool = Field(default=True, description="Enable priority queues") + priority_queue: bool = Field(default=False, description="Enable priority queues") rbac: bool = Field(default=True, description="Enable RBAC authorization (Phase 2)") message_signing: bool = Field( default=True, description="Enable message signing (Phase 2)" diff --git a/src/mas/gateway/gateway.py b/src/mas/gateway/gateway.py index f262f00..17a121c 100644 --- a/src/mas/gateway/gateway.py +++ b/src/mas/gateway/gateway.py @@ -3,6 +3,7 @@ import asyncio import logging import time +import uuid from typing import Any, Optional from pydantic import BaseModel @@ -192,6 +193,7 @@ async def handle_message( signature: Optional[str] = None, timestamp: Optional[float] = None, nonce: Optional[str] = None, + envelope_json: Optional[str] = None, ) -> GatewayResult: """ Handle message through gateway validation pipeline. @@ -315,6 +317,7 @@ async def handle_message( signature=signature, timestamp=timestamp, nonce=nonce, + envelope_json=envelope_json, ) if not sig_result.valid: @@ -583,7 +586,20 @@ async def _route_message(self, message: AgentMessage) -> None: ) else: # Stream-based delivery (at-least-once) - target_stream = f"{self.settings.agent_stream_prefix}{message.target_id}" + # For replies with sender_instance_id, route directly to the specific instance + # to ensure request-response works correctly with multi-instance agents + if message.meta.is_reply and message.meta.sender_instance_id: + # Route to instance-specific stream for reply delivery + target_stream = ( + f"{self.settings.agent_stream_prefix}" + f"{message.target_id}:{message.meta.sender_instance_id}" + ) + else: + # Regular message - route to shared stream (load balanced across instances) + target_stream = ( + f"{self.settings.agent_stream_prefix}{message.target_id}" + ) + await self._redis.xadd( target_stream, { @@ -596,6 +612,8 @@ async def _route_message(self, message: AgentMessage) -> None: extra={ "message_id": message.message_id, "target_stream": target_stream, + "is_instance_specific": message.meta.is_reply + and bool(message.meta.sender_instance_id), }, ) @@ -728,7 +746,7 @@ async def _start_ingress_consumer(self) -> None: async def _loop() -> None: assert self._redis is not None - consumer = "gw-1" + consumer = f"gw-{uuid.uuid4().hex[:8]}" while self._running: try: items = await self._redis.xreadgroup( @@ -759,6 +777,7 @@ async def _loop() -> None: signature=signature, timestamp=timestamp, nonce=nonce, + envelope_json=envelope_json or None, ) if not result.success: # Write to DLQ with reason @@ -780,6 +799,4 @@ async def _loop() -> None: logger.error("Ingress consumer loop error", exc_info=e) await asyncio.sleep(1.0) - import asyncio # local import to avoid unused in type-checking contexts - self._ingress_task = asyncio.create_task(_loop()) diff --git a/src/mas/gateway/message_signing.py b/src/mas/gateway/message_signing.py index 99c3635..3703c0c 100644 --- a/src/mas/gateway/message_signing.py +++ 
b/src/mas/gateway/message_signing.py @@ -120,6 +120,7 @@ async def sign_message( payload: dict[str, Any], timestamp: Optional[float] = None, nonce: Optional[str] = None, + envelope_json: Optional[str] = None, ) -> dict[str, Any]: """ Sign a message with HMAC. @@ -148,14 +149,21 @@ async def sign_message( raise ValueError(f"No signing key found for agent {agent_id}") # Create signature payload - # Include all fields that should be protected from tampering - signature_data = { - "message_id": message_id, - "sender_id": agent_id, - "timestamp": timestamp, - "nonce": nonce, - "payload": payload, - } + # Prefer signing the full envelope to prevent tampering with routing/meta fields. + if envelope_json is not None: + signature_data = { + "envelope": envelope_json, + "timestamp": timestamp, + "nonce": nonce, + } + else: + signature_data = { + "message_id": message_id, + "sender_id": agent_id, + "timestamp": timestamp, + "nonce": nonce, + "payload": payload, + } # Convert to canonical string representation canonical_str = self._canonicalize(signature_data) @@ -185,6 +193,7 @@ async def verify_signature( signature: str, timestamp: float, nonce: str, + envelope_json: Optional[str] = None, ) -> SignatureResult: """ Verify message signature. @@ -227,13 +236,20 @@ async def verify_signature( ) # Reconstruct signature payload - signature_data = { - "message_id": message_id, - "sender_id": agent_id, - "timestamp": timestamp, - "nonce": nonce, - "payload": payload, - } + if envelope_json is not None: + signature_data = { + "envelope": envelope_json, + "timestamp": timestamp, + "nonce": nonce, + } + else: + signature_data = { + "message_id": message_id, + "sender_id": agent_id, + "timestamp": timestamp, + "nonce": nonce, + "payload": payload, + } # Create canonical representation canonical_str = self._canonicalize(signature_data) diff --git a/src/mas/gateway/priority_queue.py b/src/mas/gateway/priority_queue.py index 09e471b..4a67afb 100644 --- a/src/mas/gateway/priority_queue.py +++ b/src/mas/gateway/priority_queue.py @@ -13,11 +13,11 @@ - Supports message TTL and dead letter queue for expired messages """ +import logging import time from enum import IntEnum from typing import Any, Optional -import logging from pydantic import BaseModel from ..redis_types import AsyncRedisProtocol @@ -166,24 +166,18 @@ async def enqueue( ttl_seconds=ttl, ) - # Store message metadata + # Store message metadata and add to queue using pipeline (3 calls -> 1) metadata_key = self._get_metadata_key(message_id) - await self.redis.setex( - metadata_key, - max(1, int(ttl)), # Ensure at least 1 second - queued_msg.model_dump_json(), - ) - - # Add to priority queue (sorted set by timestamp) queue_key = self._get_queue_key(priority, target_id) score = enqueued_at # Messages with same priority ordered by arrival time - await self.redis.zadd(queue_key, {message_id: score}) - - # Set TTL on the queue itself to clean up empty queues - await self.redis.expire(queue_key, int(ttl * 2)) + pipe = self.redis.pipeline() + pipe.setex(metadata_key, max(1, int(ttl)), queued_msg.model_dump_json()) + pipe.zadd(queue_key, {message_id: score}) + pipe.expire(queue_key, int(ttl * 2)) + await pipe.execute() - # Calculate queue position and estimated delay + # Calculate queue position and estimated delay (uses separate pipeline) queue_position = await self._estimate_queue_position(target_id, priority, score) estimated_delay = await self._estimate_delay(target_id, queue_position) @@ -213,12 +207,16 @@ async def dequeue( """ Dequeue messages for target 
using fair weighted round-robin. + Uses pipeline batching to check all priority queues in a single + round-trip, then dequeues from non-empty queues based on weighted + selection. + Algorithm: - 1. Use weighted round-robin to select priority level - 2. Dequeue oldest message from selected priority - 3. Check message TTL, skip if expired - 4. Apply fairness boost if message waited too long - 5. Return valid messages + 1. Batch check all priority queue sizes (single pipeline call) + 2. Use weighted round-robin to select from non-empty priorities + 3. Dequeue oldest message from selected priority + 4. Check message TTL, skip if expired + 5. Apply fairness boost if message waited too long Args: target_id: Target agent ID @@ -230,9 +228,18 @@ async def dequeue( messages: list[QueuedMessage] = [] current_time = time.time() + # Batch check all priority queue sizes in a single pipeline call + non_empty_priorities = await self._get_non_empty_priorities(target_id) + + if not non_empty_priorities: + return messages + for _ in range(max_messages): - # Select priority level using weighted round-robin - priority = self._select_priority_fair() + if not non_empty_priorities: + break + + # Select priority level using weighted round-robin from non-empty queues + priority = self._select_priority_fair_from(non_empty_priorities) # Try to dequeue from selected priority message = await self._dequeue_from_priority( @@ -242,20 +249,24 @@ async def dequeue( if message: messages.append(message) else: - # Try other priorities if selected priority is empty + # Queue became empty, remove from candidates and try next + non_empty_priorities.discard(priority) + if not non_empty_priorities: + break + + # Try highest priority non-empty queue for fallback_priority in sorted( - MessagePriority, key=lambda p: p.value, reverse=True + non_empty_priorities, key=lambda p: p.value, reverse=True ): - if fallback_priority == priority: - continue message = await self._dequeue_from_priority( target_id, fallback_priority, current_time ) if message: messages.append(message) break + else: + non_empty_priorities.discard(fallback_priority) - # If no messages found in any queue, stop if not message: break @@ -271,6 +282,73 @@ async def dequeue( return messages + async def _get_non_empty_priorities(self, target_id: str) -> set[MessagePriority]: + """ + Batch check all priority queues and return non-empty ones. + + Uses pipeline to check all queues in a single round-trip. + + Args: + target_id: Target agent ID + + Returns: + Set of priorities that have messages + """ + # Use pipeline to check all queue sizes in one round-trip + pipe = self.redis.pipeline() + priorities_order = list(MessagePriority) + + for priority in priorities_order: + queue_key = self._get_queue_key(priority, target_id) + pipe.zcard(queue_key) + + results = await pipe.execute() + + # Build set of non-empty priorities + non_empty: set[MessagePriority] = set() + for priority, count in zip(priorities_order, results, strict=True): + if count > 0: + non_empty.add(priority) + + return non_empty + + def _select_priority_fair_from( + self, candidates: set[MessagePriority] + ) -> MessagePriority: + """ + Select priority from candidates using weighted round-robin. 
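+
+        For example, with candidates {HIGH, NORMAL} and dequeue weights of
+        3 and 1 (illustrative values), total_weight is 4: counter positions
+        0-2 select HIGH and position 3 selects NORMAL, a 3:1 dequeue ratio.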
+ + Args: + candidates: Set of non-empty priorities to choose from + + Returns: + Selected priority + """ + if not candidates: + return MessagePriority.NORMAL + + if len(candidates) == 1: + return next(iter(candidates)) + + self._dequeue_counter += 1 + + # Calculate cumulative weights for candidates only + candidate_weights = [ + (p, self.config.dequeue_weights[p]) + for p in sorted(candidates, key=lambda p: p.value, reverse=True) + ] + total_weight = sum(w for _, w in candidate_weights) + position = self._dequeue_counter % total_weight + + cumulative = 0 + for priority, weight in candidate_weights: + cumulative += weight + if position < cumulative: + return priority + + # Fallback to highest priority candidate + return max(candidates, key=lambda p: p.value) + async def _dequeue_from_priority( self, target_id: str, @@ -373,23 +451,29 @@ async def _estimate_queue_position( """ Estimate position in queue. + Uses pipeline to batch all queue size checks into a single round-trip. + Considers: - Messages in higher priority queues (processed first) - Messages in same priority queue with lower score """ - position = 0 + # Use pipeline to batch all queue checks + pipe = self.redis.pipeline() + higher_priorities = [p for p in MessagePriority if p.value > priority.value] - # Count messages in higher priority queues - for p in MessagePriority: - if p.value > priority.value: - queue_key = self._get_queue_key(p, target_id) - count = await self.redis.zcard(queue_key) - position += count + # Queue zcard calls for higher priority queues + for p in higher_priorities: + queue_key = self._get_queue_key(p, target_id) + pipe.zcard(queue_key) - # Count messages in same priority queue with lower score + # Queue zcount for same priority queue queue_key = self._get_queue_key(priority, target_id) - count = await self.redis.zcount(queue_key, "-inf", score) - position += count + pipe.zcount(queue_key, "-inf", score) + + results = await pipe.execute() + + # Sum up position + position = sum(results) return max(1, position) # At least position 1 diff --git a/src/mas/gateway/rate_limit.py b/src/mas/gateway/rate_limit.py index 2ae655e..8561f68 100644 --- a/src/mas/gateway/rate_limit.py +++ b/src/mas/gateway/rate_limit.py @@ -3,12 +3,69 @@ import logging import time from typing import Optional -from ..redis_types import AsyncRedisProtocol + from pydantic import BaseModel +from ..redis_types import AsyncRedisProtocol + logger = logging.getLogger(__name__) +# Lua script for atomic rate limit check +# This reduces 10 Redis calls to 1 atomic operation +_RATE_LIMIT_SCRIPT = """ +-- KEYS[1] = minute window key (ratelimit:{agent_id}:minute) +-- KEYS[2] = hour window key (ratelimit:{agent_id}:hour) +-- KEYS[3] = limits key (ratelimit:{agent_id}:limits) +-- ARGV[1] = message_id +-- ARGV[2] = current timestamp (float) +-- ARGV[3] = default per_minute limit +-- ARGV[4] = default per_hour limit + +local minute_key = KEYS[1] +local hour_key = KEYS[2] +local limits_key = KEYS[3] +local message_id = ARGV[1] +local now = tonumber(ARGV[2]) +local default_per_minute = tonumber(ARGV[3]) +local default_per_hour = tonumber(ARGV[4]) + +-- Get custom limits or use defaults +local per_minute = tonumber(redis.call('HGET', limits_key, 'per_minute')) or default_per_minute +local per_hour = tonumber(redis.call('HGET', limits_key, 'per_hour')) or default_per_hour + +-- Check minute window +local minute_start = now - 60 +redis.call('ZREMRANGEBYSCORE', minute_key, '-inf', minute_start) +local minute_count = redis.call('ZCARD', minute_key) + +if 
minute_count >= per_minute then + -- Rate limited by minute + return {0, per_minute, 0, now + 60, 'minute'} +end + +-- Check hour window +local hour_start = now - 3600 +redis.call('ZREMRANGEBYSCORE', hour_key, '-inf', hour_start) +local hour_count = redis.call('ZCARD', hour_key) + +if hour_count >= per_hour then + -- Rate limited by hour + return {0, per_hour, 0, now + 3600, 'hour'} +end + +-- Allowed - add to both windows +redis.call('ZADD', minute_key, now, message_id) +redis.call('EXPIRE', minute_key, 60) +redis.call('ZADD', hour_key, now, message_id) +redis.call('EXPIRE', hour_key, 3600) + +-- Return success with minute window info (more restrictive) +local remaining = per_minute - minute_count - 1 +return {1, per_minute, remaining, now + 60, 'minute'} +""" + + class RateLimitResult(BaseModel): """Rate limit check result.""" @@ -60,6 +117,76 @@ async def check_rate_limit(self, agent_id: str, message_id: str) -> RateLimitRes - 60 seconds (per-minute limit) - 3600 seconds (per-hour limit) + This method uses a Lua script to perform all operations atomically + in a single Redis round-trip, reducing calls from ~10 to 1. + + Args: + agent_id: Agent identifier + message_id: Message identifier + + Returns: + RateLimitResult with limit status + """ + now = time.time() + + minute_key = f"ratelimit:{agent_id}:minute" + hour_key = f"ratelimit:{agent_id}:hour" + limits_key = f"ratelimit:{agent_id}:limits" + + # Execute Lua script atomically + result = await self.redis.eval( + _RATE_LIMIT_SCRIPT, + 3, # number of keys + minute_key, + hour_key, + limits_key, + message_id, + str(now), + str(self.default_per_minute), + str(self.default_per_hour), + ) + + # Parse result: [allowed, limit, remaining, reset_time, window] + allowed = bool(result[0]) + limit = int(result[1]) + remaining = int(result[2]) + reset_time = float(result[3]) + window = result[4] if len(result) > 4 else "minute" + + if not allowed: + logger.warning( + f"Rate limit exceeded (per-{window})", + extra={ + "agent_id": agent_id, + "limit": limit, + "remaining": remaining, + }, + ) + else: + logger.debug( + "Rate limit check passed", + extra={ + "agent_id": agent_id, + "remaining": remaining, + }, + ) + + return RateLimitResult( + allowed=allowed, + limit=limit, + remaining=remaining, + reset_time=reset_time, + ) + + async def check_rate_limit_legacy( + self, agent_id: str, message_id: str + ) -> RateLimitResult: + """ + Legacy rate limit check without Lua script. + + Kept for compatibility with Redis servers that don't support Lua. + Uses multiple Redis calls (~10 per check). 
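+
+        Note: unlike the Lua path, these calls are not atomic, so concurrent
+        checks for the same agent can briefly overshoot the configured limits.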
+ Args: agent_id: Agent identifier message_id: Message identifier diff --git a/src/mas/protocol.py b/src/mas/protocol.py index 3a29338..eeb740d 100644 --- a/src/mas/protocol.py +++ b/src/mas/protocol.py @@ -4,9 +4,8 @@ from __future__ import annotations - import time -from typing import Any, Optional, TypeAlias, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional, TypeAlias from pydantic import BaseModel, Field, PrivateAttr @@ -26,6 +25,9 @@ class MessageMeta(BaseModel): correlation_id: Optional[str] = None expects_reply: bool = False is_reply: bool = False + # Instance ID of the sender for routing replies back to the correct instance + # in multi-instance agent deployments + sender_instance_id: Optional[str] = None class EnvelopeMessage(BaseModel): diff --git a/src/mas/redis_types.py b/src/mas/redis_types.py index 90e7a61..01ed2c9 100644 --- a/src/mas/redis_types.py +++ b/src/mas/redis_types.py @@ -14,6 +14,7 @@ Literal, MutableMapping, Protocol, + Self, Set, overload, ) @@ -26,6 +27,36 @@ async def aclose(self) -> None: ... def listen(self) -> AsyncIterator[MutableMapping[str, Any]]: ... +class PipelineProtocol(Protocol): + """ + Protocol for Redis pipeline (batched commands). + + Pipelines allow multiple commands to be sent in a single round-trip, + significantly reducing latency for bulk operations. + """ + + # Pipeline returns self for chaining + def hgetall(self, key: str) -> Self: ... + def hset(self, key: str, *, mapping: dict[str, str]) -> Self: ... + def hget(self, key: str, field: str) -> Self: ... + def delete(self, *keys: str) -> Self: ... + def exists(self, key: str) -> Self: ... + def get(self, key: str) -> Self: ... + def set(self, key: str, value: str) -> Self: ... + def xadd(self, name: str, fields: dict[str, str]) -> Self: ... + def zadd(self, key: str, mapping: dict[str, float]) -> Self: ... + def zcard(self, key: str) -> Self: ... + def expire(self, key: str, seconds: int) -> Self: ... + def ttl(self, key: str) -> Self: ... + def setex(self, key: str, seconds: int, value: str) -> Self: ... + def zcount(self, key: str, min: float | str, max: float | str) -> Self: ... + def incr(self, key: str) -> Self: ... + def decr(self, key: str) -> Self: ... + + # Execute all queued commands and return results + async def execute(self) -> list[Any]: ... + + class AsyncRedisProtocol(Protocol): # Connection def aclose(self) -> Awaitable[None]: ... @@ -44,6 +75,8 @@ def scan_iter(self, *, match: str) -> AsyncIterator[str]: ... def get(self, key: str) -> Awaitable[str | None]: ... def set(self, key: str, value: str) -> Awaitable[bool | str]: ... def setex(self, key: str, seconds: int, value: str) -> Awaitable[bool | str]: ... + def incr(self, key: str) -> Awaitable[int]: ... + def decr(self, key: str) -> Awaitable[int]: ... def publish(self, channel: str, message: str) -> Awaitable[int]: ... def pubsub(self) -> PubSubProtocol: ... @@ -99,4 +132,25 @@ def xreadgroup( count: int | None = ..., block: int | None = ..., ) -> Awaitable[list[tuple[str, list[tuple[str, dict[str, str]]]]] | None]: ... + def xautoclaim( + self, + name: str, + groupname: str, + consumername: str, + min_idle_time: int, + start_id: str, + *, + count: int | None = ..., + ) -> Awaitable[tuple[str, list[tuple[str, dict[str, str]]], list[str]]]: ... def xack(self, name: str, groupname: str, *ids: str) -> Awaitable[int]: ... + + # Scripting + def eval( + self, + script: str, + numkeys: int, + *keys_and_args: str, + ) -> Awaitable[Any]: ... + + # Pipeline + def pipeline(self) -> PipelineProtocol: ... 
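For reference, a minimal sketch of the round-trip batching pattern that
`PipelineProtocol` codifies, written against `redis.asyncio` (the client the
test suite already uses). The `heartbeat_ttls` helper and the key names are
illustrative assumptions, not part of this patch:

```python
import asyncio

from redis.asyncio import Redis


async def heartbeat_ttls(redis: Redis, keys: list[str]) -> list[int]:
    """Fetch TTLs for many keys in a single round-trip (illustrative only)."""
    pipe = redis.pipeline()
    for key in keys:
        pipe.ttl(key)  # commands queue locally; nothing is sent yet
    return await pipe.execute()  # one round-trip flushes the whole batch


async def main() -> None:
    redis = Redis.from_url("redis://localhost:6379", decode_responses=True)
    ttls = await heartbeat_ttls(
        redis, ["agent:worker:heartbeat:aaaa0001", "agent:worker:heartbeat:aaaa0002"]
    )
    print(ttls)  # e.g. [53, -2]; a TTL of -2 means the key does not exist
    await redis.aclose()


if __name__ == "__main__":
    asyncio.run(main())
```

The registry and service changes below apply the same pattern: queue commands
locally, then pay one network round-trip at `execute()`.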
diff --git a/src/mas/registry.py b/src/mas/registry.py index 47e4b57..32fd3ba 100644 --- a/src/mas/registry.py +++ b/src/mas/registry.py @@ -1,11 +1,11 @@ -"""Redis-based agent registry.""" +"""Redis-based agent registry with multi-instance support.""" from __future__ import annotations import json import secrets import time -from typing import Any, List, Optional, TypedDict +from typing import Any, Optional, TypedDict from .redis_types import AsyncRedisProtocol @@ -26,7 +26,15 @@ class AgentRecord(_AgentRecordRequired, total=False): class AgentRegistry: - """Manages agent registration in Redis.""" + """ + Manages agent registration in Redis with multi-instance support. + + Multi-instance features: + - Idempotent registration: first instance registers, subsequent instances join + - Instance counting: tracks active instance count per agent + - Per-instance heartbeats: each instance maintains its own heartbeat + - Graceful deregistration: only removes agent when last instance leaves + """ def __init__(self, redis: AsyncRedisProtocol): """ @@ -40,20 +48,48 @@ def __init__(self, redis: AsyncRedisProtocol): async def register( self, agent_id: str, + instance_id: str, capabilities: list[str], metadata: Optional[dict[str, Any]] = None, ) -> str: """ - Register an agent. + Register an agent instance. + + This operation is idempotent for the agent. The first instance to register + creates the agent entry and generates a token. Subsequent instances + retrieve the existing token and increment the instance count. Args: - agent_id: Unique agent identifier + agent_id: Logical agent identifier (shared across instances) + instance_id: Unique instance identifier capabilities: List of agent capabilities metadata: Optional agent metadata Returns: - Authentication token for the agent + Authentication token for the agent (shared across all instances) """ + agent_key = f"agent:{agent_id}" + instance_count_key = f"agent:{agent_id}:instance_count" + + # Check if agent already exists + existing_data = await self.redis.hgetall(agent_key) + + if existing_data and existing_data.get("token"): + # Agent exists - increment instance count and return existing token + await self.redis.incr(instance_count_key) + # Reactivate if previously inactive + if existing_data.get("status") == "INACTIVE": + await self.redis.hset( + agent_key, + mapping={ + "status": "ACTIVE", + "registered_at": str(time.time()), + }, + ) + await self._ensure_signing_key(agent_id) + return existing_data["token"] + + # First instance - create new registration token = self._generate_token() agent_data: dict[str, str] = { @@ -65,23 +101,62 @@ async def register( "registered_at": str(time.time()), } - await self.redis.hset(f"agent:{agent_id}", mapping=agent_data) + signing_key = self._generate_signing_key() + signing_key_field = f"agent:{agent_id}:signing_key" + + # Use pipeline for atomic registration + pipe = self.redis.pipeline() + pipe.hset(agent_key, mapping=agent_data) + pipe.incr(instance_count_key) # Set to 1 + pipe.set(signing_key_field, signing_key) + await pipe.execute() + return token - async def deregister(self, agent_id: str, keep_state: bool = True) -> None: + async def deregister( + self, + agent_id: str, + instance_id: str, + keep_state: bool = True, + ) -> None: """ - Deregister an agent. + Deregister an agent instance. + + Decrements the instance count. Only removes the agent entry when + the last instance deregisters. 
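+
+        Note: DECR initializes a missing counter to 0 before decrementing, so
+        even a lost or stale instance count resolves to new_count <= 0 and
+        takes the last-instance cleanup path instead of leaking the entry.
+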
Args: - agent_id: Agent identifier to deregister + agent_id: Logical agent identifier + instance_id: Instance identifier being deregistered keep_state: If True, preserves agent state in Redis (default: True) """ - await self.redis.delete(f"agent:{agent_id}") - await self.redis.delete(f"agent:{agent_id}:heartbeat") + instance_count_key = f"agent:{agent_id}:instance_count" + heartbeat_key = f"agent:{agent_id}:heartbeat:{instance_id}" + + # Decrement instance count + new_count = await self.redis.decr(instance_count_key) + + # Always delete this instance's heartbeat + await self.redis.delete(heartbeat_key) + + if new_count <= 0: + # Last instance - clean up agent registration + pipe = self.redis.pipeline() + pipe.delete(f"agent:{agent_id}") + pipe.delete(instance_count_key) - # Only delete state if explicitly requested - if not keep_state: - await self.redis.delete(f"agent.state:{agent_id}") + # Clean up any remaining heartbeat keys for this agent + # (in case of unclean shutdowns) + async for key in self.redis.scan_iter( + match=f"agent:{agent_id}:heartbeat:*" + ): + pipe.delete(key) + + # Only delete state if explicitly requested + if not keep_state: + pipe.delete(f"agent.state:{agent_id}") + + await pipe.execute() async def get_agent(self, agent_id: str) -> AgentRecord | None: """ @@ -105,12 +180,33 @@ async def get_agent(self, agent_id: str) -> AgentRecord | None: registered_at=float(data["registered_at"]), ) + async def get_instance_count(self, agent_id: str) -> int: + """ + Get the number of active instances for an agent. + + Args: + agent_id: Agent identifier + + Returns: + Number of active instances (0 if agent not registered) + """ + count = await self.redis.get(f"agent:{agent_id}:instance_count") + if count is None: + return 0 + return int(count) + async def discover( self, capabilities: list[str] | None = None ) -> list[AgentRecord]: """ Discover agents by capabilities. + Uses pipeline batching to fetch all agent data in a single round-trip, + eliminating the N+1 query pattern. + + Note: Returns logical agents, not instances. Callers don't need to know + about individual instances. + Args: capabilities: Optional list of required capabilities. If None, returns all active agents. @@ -118,15 +214,29 @@ async def discover( Returns: List of agent data dicts """ - agents: List[AgentRecord] = [] + # Phase 1: Collect all matching keys + keys: list[str] = [] pattern = "agent:*" async for key in self.redis.scan_iter(match=pattern): - # Skip non-agent keys (like agent:id:heartbeat) + # Only include agent hashes, not heartbeat or instance_count keys if not key.startswith("agent:") or key.count(":") != 1: continue + keys.append(key) + + if not keys: + return [] + + # Phase 2: Batch fetch all agent data using pipeline + pipe = self.redis.pipeline() + for key in keys: + pipe.hgetall(key) - agent_data = await self.redis.hgetall(key) + results = await pipe.execute() + + # Phase 3: Process results and filter + agents: list[AgentRecord] = [] + for agent_data in results: if not agent_data or agent_data.get("status") != "ACTIVE": continue @@ -146,16 +256,89 @@ async def discover( return agents - async def update_heartbeat(self, agent_id: str, ttl: int = 60) -> None: + async def update_heartbeat( + self, + agent_id: str, + instance_id: str, + ttl: int = 60, + ) -> None: """ - Update agent heartbeat. + Update heartbeat for a specific agent instance. + + Each instance maintains its own heartbeat key. The agent is considered + healthy if at least one instance has a valid heartbeat. 
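+
+        Instances should refresh well inside the TTL window (for example at
+        ttl/2 intervals) so a live instance is not misread as expired.
+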
Args: - agent_id: Agent identifier + agent_id: Logical agent identifier + instance_id: Instance identifier ttl: Time-to-live in seconds (default: 60) """ - await self.redis.setex(f"agent:{agent_id}:heartbeat", ttl, str(time.time())) + heartbeat_key = f"agent:{agent_id}:heartbeat:{instance_id}" + await self.redis.setex(heartbeat_key, ttl, str(time.time())) + + async def get_instance_heartbeats(self, agent_id: str) -> dict[str, float | None]: + """ + Get heartbeat TTLs for all instances of an agent. + + Args: + agent_id: Agent identifier + + Returns: + Dict mapping instance_id to TTL (None if expired/missing) + """ + heartbeats: dict[str, float | None] = {} + pattern = f"agent:{agent_id}:heartbeat:*" + + # Collect all heartbeat keys + keys: list[str] = [] + async for key in self.redis.scan_iter(match=pattern): + keys.append(key) + + if not keys: + return heartbeats + + # Batch fetch TTLs + pipe = self.redis.pipeline() + for key in keys: + pipe.ttl(key) + + ttls = await pipe.execute() + + # Extract instance IDs and map to TTLs + prefix_len = len(f"agent:{agent_id}:heartbeat:") + for key, ttl in zip(keys, ttls, strict=True): + instance_id = key[prefix_len:] + # TTL of -2 means key doesn't exist, -1 means no expiry + heartbeats[instance_id] = ttl if ttl > 0 else None + + return heartbeats + + async def has_healthy_instance(self, agent_id: str) -> bool: + """ + Check if an agent has at least one healthy instance. + + An instance is healthy if its heartbeat key exists and has TTL > 0. + + Args: + agent_id: Agent identifier + + Returns: + True if at least one instance has a valid heartbeat + """ + heartbeats = await self.get_instance_heartbeats(agent_id) + return any(ttl is not None and ttl > 0 for ttl in heartbeats.values()) + + async def _ensure_signing_key(self, agent_id: str) -> None: + key_field = f"agent:{agent_id}:signing_key" + existing = await self.redis.get(key_field) + if existing: + return + await self.redis.set(key_field, self._generate_signing_key()) def _generate_token(self) -> str: """Generate authentication token.""" return secrets.token_urlsafe(32) + + def _generate_signing_key(self) -> str: + """Generate per-agent signing key (hex encoded).""" + return secrets.token_bytes(32).hex() diff --git a/src/mas/runner.py b/src/mas/runner.py new file mode 100644 index 0000000..a66fe1c --- /dev/null +++ b/src/mas/runner.py @@ -0,0 +1,391 @@ +"""Config-driven agent runner.""" + +from __future__ import annotations + +import asyncio +import importlib +import logging +import os +import signal +import sys +from pathlib import Path +from typing import Annotated, Any, Literal, Optional, TypedDict, Union, cast + +import yaml +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings, SettingsConfigDict + +from .agent import Agent +from .gateway import GatewayService +from .gateway.config import GatewaySettings, load_settings as load_gateway_settings +from .service import MASService + +logger = logging.getLogger(__name__) + + +class AgentSpec(BaseModel): + """Configuration for a single agent definition.""" + + agent_id: str = Field(..., min_length=1, description="Agent ID to register") + class_path: str = Field( + ..., + description="Import path for the agent class (module:ClassName)", + ) + instances: int = Field(default=1, ge=1, description="Number of instances to run") + init_kwargs: dict[str, Any] = Field( + default_factory=dict, description="Kwargs forwarded to agent constructor" + ) + + +class AllowBidirectionalSpec(BaseModel): + """Bidirectional permission for two 
agents.""" + + type: Literal["allow_bidirectional"] + agents: list[str] = Field(min_length=2, max_length=2) + + +class AllowNetworkSpec(BaseModel): + """Full mesh or chained network permissions.""" + + type: Literal["allow_network"] + agents: list[str] = Field(min_length=2) + bidirectional: bool = True + + +class AllowBroadcastSpec(BaseModel): + """One-way broadcast permissions.""" + + type: Literal["allow_broadcast"] + sender: str + receivers: list[str] = Field(min_length=1) + + +class AllowWildcardSpec(BaseModel): + """Wildcard permission for a single agent.""" + + type: Literal["allow_wildcard"] + agent_id: str + + +class AllowSpec(BaseModel): + """One-way permissions from a sender to targets.""" + + type: Literal["allow"] + sender: str + targets: list[str] = Field(min_length=1) + + +PermissionSpec = Annotated[ + Union[ + AllowBidirectionalSpec, + AllowNetworkSpec, + AllowBroadcastSpec, + AllowWildcardSpec, + AllowSpec, + ], + Field(discriminator="type"), +] + + +class _RunnerSettingsInit(TypedDict, total=False): + config_file: Optional[str] + start_service: bool + service_redis_url: str + start_gateway: bool + gateway_config_file: Optional[str] + permissions: list[PermissionSpec] + agents: list[AgentSpec] + + +class RunnerSettings(BaseSettings): + """ + Runner configuration. + + Configuration sources: + 1) Explicit parameters + 2) Environment variables (MAS_RUNNER_*) + 3) agents.yaml (auto-loaded if present) + 4) Defaults + """ + + config_file: Optional[str] = Field( + default=None, description="Path to YAML config file" + ) + start_service: bool = Field( + default=True, description="Start MAS service alongside agents" + ) + service_redis_url: str = Field( + default="redis://localhost:6379", + description="Redis URL for MAS service", + ) + start_gateway: bool = Field( + default=True, description="Start gateway service for all agents" + ) + gateway_config_file: Optional[str] = Field( + default=None, description="Path to gateway YAML config file" + ) + permissions: list[PermissionSpec] = Field( + default_factory=list, description="Authorization rules to apply" + ) + agents: list[AgentSpec] = Field( + default_factory=list, description="Agent definitions to run" + ) + + model_config = SettingsConfigDict( + env_prefix="MAS_RUNNER_", + env_nested_delimiter="__", + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + extra="ignore", + ) + + def __init__(self, **data: Any) -> None: + config_file = ( + data.get("config_file") + or os.getenv("MAS_RUNNER_CONFIG_FILE") + or self._default_config_file() + ) + + if config_file is None and "agents" not in data: + raise FileNotFoundError( + "agents.yaml not found. Create agents.yaml in the project root " + "or pass a config_file." 
+ ) + + if config_file: + yaml_data = self._load_yaml(config_file) + merged_data: dict[str, Any] = { + **yaml_data, + **data, + "config_file": config_file, + } + super().__init__(**cast(_RunnerSettingsInit, merged_data)) + else: + super().__init__(**cast(_RunnerSettingsInit, data)) + + @staticmethod + def _default_config_file() -> Optional[str]: + start = Path.cwd() + for current in [start, *start.parents]: + candidate = current / "agents.yaml" + if candidate.exists(): + return str(candidate) + return None + + @staticmethod + def _load_yaml(file_path: str) -> dict[str, Any]: + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {file_path}") + + with path.open("r") as f: + data = yaml.safe_load(f) + + if data is None: + return {} + + return data + + +class AgentRunner: + """Start and supervise agent instances from RunnerSettings.""" + + def __init__(self, settings: RunnerSettings) -> None: + self._settings = settings + self._agents: list[Agent[Any]] = [] + self._service: MASService | None = None + self._gateway: GatewayService | None = None + self._shutdown_event = asyncio.Event() + self._ensure_import_base() + + def _ensure_import_base(self) -> None: + if not self._settings.config_file: + return + base = str(Path(self._settings.config_file).resolve().parent) + if base not in sys.path: + sys.path.insert(0, base) + + async def run(self) -> None: + """Start agents and wait for shutdown.""" + if not self._settings.agents: + raise RuntimeError("No agents configured. Provide agents.yaml or settings.") + if not self._settings.start_gateway: + raise RuntimeError( + "Gateway service is required. start_gateway must be true." + ) + + self._setup_signals() + try: + await self._start_service() + await self._start_gateway() + await self._apply_permissions() + await self._start_agents() + logger.info( + "Runner started", + extra={"agent_definitions": len(self._settings.agents)}, + ) + await self._shutdown_event.wait() + finally: + await self._stop_agents() + await self._stop_gateway() + await self._stop_service() + + def request_shutdown(self) -> None: + """Signal the runner to shutdown.""" + if not self._shutdown_event.is_set(): + logger.info("Shutdown requested") + self._shutdown_event.set() + + async def _start_agents(self) -> None: + for spec in self._settings.agents: + agent_cls = self._load_agent_class(spec.class_path) + reserved_keys = {"agent_id"} + conflicting = reserved_keys.intersection(spec.init_kwargs.keys()) + if conflicting: + raise ValueError( + "init_kwargs contains reserved keys: " + + ", ".join(sorted(conflicting)) + ) + if "use_gateway" in spec.init_kwargs: + raise ValueError( + "use_gateway is not supported. MAS always routes via the gateway." 
+ ) + for _ in range(spec.instances): + agent = agent_cls(spec.agent_id, **spec.init_kwargs) + self._agents.append(agent) + + for agent in self._agents: + await agent.start() + + async def _start_service(self) -> None: + if not self._settings.start_service: + return + + service = MASService(redis_url=self._settings.service_redis_url) + await service.start() + self._service = service + + async def _start_gateway(self) -> None: + if not self._settings.start_gateway: + return + + settings = self._load_gateway_settings() + gateway = GatewayService(settings=settings) + await gateway.start() + self._gateway = gateway + + async def _stop_agents(self) -> None: + if not self._agents: + return + + await asyncio.gather( + *(agent.stop() for agent in self._agents), + return_exceptions=True, + ) + self._agents.clear() + + async def _stop_gateway(self) -> None: + if self._gateway is None: + return + + await self._gateway.stop() + self._gateway = None + + async def _apply_permissions(self) -> None: + if not self._gateway or not self._settings.permissions: + return + + auth = self._gateway.auth_manager() + pending_apply = False + + for spec in self._settings.permissions: + if isinstance(spec, AllowBidirectionalSpec): + await auth.allow_bidirectional(spec.agents[0], spec.agents[1]) + elif isinstance(spec, AllowNetworkSpec): + await auth.allow_network(spec.agents, bidirectional=spec.bidirectional) + elif isinstance(spec, AllowBroadcastSpec): + await auth.allow_broadcast(spec.sender, spec.receivers) + elif isinstance(spec, AllowWildcardSpec): + await auth.allow_wildcard(spec.agent_id) + elif isinstance(spec, AllowSpec): + auth.allow(spec.sender, spec.targets) + pending_apply = True + + if pending_apply: + await auth.apply() + + async def _stop_service(self) -> None: + if self._service is None: + return + + await self._service.stop() + self._service = None + + def _setup_signals(self) -> None: + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, self.request_shutdown) + except NotImplementedError: + signal.signal(sig, lambda *_: self.request_shutdown()) + + @staticmethod + def _load_agent_class(class_path: str) -> type[Agent[Any]]: + if ":" not in class_path: + raise ValueError( + f"Invalid class_path '{class_path}'. Use module:ClassName format." 
+ ) + + module_name, class_name = class_path.split(":", 1) + module = importlib.import_module(module_name) + target = getattr(module, class_name) + + if not isinstance(target, type): + raise TypeError(f"{class_path} does not reference a class") + + if not issubclass(target, Agent): + raise TypeError(f"{class_path} is not a mas.Agent subclass") + + return cast(type[Agent[Any]], target) + + def _load_gateway_settings(self) -> GatewaySettings: + if self._settings.gateway_config_file: + return load_gateway_settings(config_file=self._settings.gateway_config_file) + + if self._settings.config_file: + base = Path(self._settings.config_file).resolve().parent + candidate = base / "gateway.yaml" + if candidate.exists(): + return load_gateway_settings(config_file=str(candidate)) + + return load_gateway_settings() + + +def load_runner_settings( + config_file: Optional[str] = None, **overrides: Any +) -> RunnerSettings: + """Load runner settings with optional overrides.""" + if config_file: + overrides["config_file"] = config_file + return RunnerSettings(**overrides) + + +async def main(config_file: Optional[str] = None) -> None: + """Run agents defined by RunnerSettings.""" + settings = load_runner_settings(config_file=config_file) + runner = AgentRunner(settings) + try: + await runner.run() + except RuntimeError as exc: + logger.error(str(exc)) + raise SystemExit(1) from exc + + +def run(config_file: Optional[str] = None) -> None: + """Sync entrypoint for the runner.""" + asyncio.run(main(config_file=config_file)) + + +if __name__ == "__main__": + run() diff --git a/src/mas/service.py b/src/mas/service.py index 250ba7c..80959d3 100644 --- a/src/mas/service.py +++ b/src/mas/service.py @@ -19,7 +19,12 @@ class MASService: Agents communicate peer-to-peer. This service only handles: - Agent registration - Agent discovery - - Health monitoring + - Health monitoring (with multi-instance support) + + Multi-instance health monitoring: + - Each agent instance maintains its own heartbeat + - An agent is considered healthy if at least one instance has a valid heartbeat + - An agent is marked INACTIVE only when ALL instances are unhealthy Usage: service = MASService(redis_url="redis://localhost:6379") @@ -73,7 +78,7 @@ async def stop(self) -> None: logger.info("MAS Service stopped") async def _handle_system_messages(self) -> None: - """Listen for system messages (register, deregister).""" + """Listen for system messages (register, deregister, instance events).""" if not self._redis: return @@ -132,65 +137,159 @@ async def _handle_message(self, msg: dict[str, Any]) -> None: ) case "DEREGISTER": logger.info("Agent deregistered", extra={"agent_id": msg["agent_id"]}) + case "INSTANCE_JOIN": + logger.info( + "Agent instance joined", + extra={ + "agent_id": msg["agent_id"], + "instance_id": msg.get("instance_id"), + "instance_count": msg.get("instance_count"), + }, + ) + case "INSTANCE_LEAVE": + logger.info( + "Agent instance left", + extra={ + "agent_id": msg["agent_id"], + "instance_id": msg.get("instance_id"), + }, + ) case _: logger.warning("Unknown message type", extra={"type": msg.get("type")}) async def _monitor_health(self) -> None: - """Monitor agent health via heartbeats.""" + """Monitor agent health via per-instance heartbeats. + + Multi-instance health monitoring: + 1. Scan for all agent keys + 2. For each agent, check all instance heartbeats + 3. Agent is healthy if ANY instance has a valid heartbeat + 4. 
Agent is marked INACTIVE only when ALL instances are unhealthy + + Uses pipeline batching to reduce Redis round-trips. + """ while self._running: try: if not self._redis: await asyncio.sleep(30) continue - # Find stale agents by existing heartbeat keys (expiring soon or invalid TTL) - async for key in self._redis.scan_iter(match="agent:*:heartbeat"): - ttl = await self._redis.ttl(key) - if ttl <= 0: # -2 (missing) or -1 (no expiry) or invalid - agent_id = key.split(":")[1] - logger.warning( - "Agent heartbeat expired", extra={"agent_id": agent_id} - ) - # Mark as inactive if still present - agent_key = f"agent:{agent_id}" - exists = await self._redis.exists(agent_key) - if exists: - status = await self._redis.hget(agent_key, "status") - if status != "INACTIVE": - await self._redis.hset( - agent_key, mapping={"status": "INACTIVE"} - ) + # Phase 1: Collect all agent keys (single scan) + agent_keys: list[str] = [] + async for key in self._redis.scan_iter(match="agent:*"): + # Only include agent hashes, not heartbeat or instance_count keys + if key.count(":") == 1: + agent_keys.append(key) - # Also detect agents with missing heartbeat keys entirely (with grace period) - async for agent_key in self._redis.scan_iter(match="agent:*"): - # Skip non-agent hashes like heartbeat keys themselves - if agent_key.count(":") != 1: - continue + if not agent_keys: + await asyncio.sleep(30) + continue + + # Phase 2: Get agent statuses and registration times + pipe = self._redis.pipeline() + for agent_key in agent_keys: + pipe.hget(agent_key, "status") + pipe.hget(agent_key, "registered_at") + + agent_info_results = await pipe.execute() + + # Phase 3: For each agent, collect all instance heartbeat keys + current_time = time.time() + agents_to_check: list[tuple[str, str, float | None]] = [] - hb_key = f"{agent_key}:heartbeat" - if await self._redis.exists(hb_key): + for i, agent_key in enumerate(agent_keys): + status = agent_info_results[i * 2] + reg_at_raw = agent_info_results[i * 2 + 1] + + # Skip if already inactive + if status == "INACTIVE": continue - reg_at_raw = await self._redis.hget(agent_key, "registered_at") - reg_at: Optional[float] = None + reg_at: float | None = None if isinstance(reg_at_raw, str): try: reg_at = float(reg_at_raw) except ValueError: - reg_at = None + pass - if reg_at is None: - continue + agent_id = agent_key.split(":")[1] + agents_to_check.append((agent_key, agent_id, reg_at)) - if (time.time() - reg_at) <= float(self.heartbeat_timeout): + if not agents_to_check: + await asyncio.sleep(30) + continue + + # Phase 4: Collect all heartbeat keys for agents to check + agent_heartbeat_keys: dict[str, list[str]] = {} + for agent_key, agent_id, _ in agents_to_check: + heartbeat_keys: list[str] = [] + async for hb_key in self._redis.scan_iter( + match=f"agent:{agent_id}:heartbeat:*" + ): + heartbeat_keys.append(hb_key) + agent_heartbeat_keys[agent_id] = heartbeat_keys + + # Phase 5: Batch fetch all heartbeat TTLs + all_hb_keys: list[str] = [] + key_to_agent: dict[str, str] = {} + for agent_id, hb_keys in agent_heartbeat_keys.items(): + for hb_key in hb_keys: + all_hb_keys.append(hb_key) + key_to_agent[hb_key] = agent_id + + agent_has_healthy_instance: dict[str, bool] = { + agent_id: False for _, agent_id, _ in agents_to_check + } + + if all_hb_keys: + ttl_pipe = self._redis.pipeline() + for hb_key in all_hb_keys: + ttl_pipe.ttl(hb_key) + + ttls = await ttl_pipe.execute() + + # Determine which agents have at least one healthy instance + for hb_key, ttl in zip(all_hb_keys, ttls, 
strict=True): + agent_id = key_to_agent[hb_key] + # TTL > 0 means the heartbeat is valid + if ttl is not None and ttl > 0: + agent_has_healthy_instance[agent_id] = True + + # Phase 6: Determine which agents should be deactivated + agents_to_deactivate: list[str] = [] + + for agent_key, agent_id, reg_at in agents_to_check: + has_heartbeat_keys = len(agent_heartbeat_keys.get(agent_id, [])) > 0 + is_healthy = agent_has_healthy_instance.get(agent_id, False) + + if is_healthy: + # At least one instance is healthy continue - status = await self._redis.hget(agent_key, "status") - if status != "INACTIVE": - await self._redis.hset( - agent_key, - mapping={"status": "INACTIVE"}, + if not has_heartbeat_keys: + # No heartbeat keys exist - check grace period for new agents + if reg_at is not None: + if (current_time - reg_at) > float(self.heartbeat_timeout): + logger.warning( + "Agent has no healthy instances (grace period expired)", + extra={"agent_id": agent_id}, + ) + agents_to_deactivate.append(agent_key) + # If no reg_at, skip (shouldn't happen but be safe) + else: + # Has heartbeat keys but none are healthy (all expired) + logger.warning( + "All agent instances unhealthy", + extra={"agent_id": agent_id}, ) + agents_to_deactivate.append(agent_key) + + # Phase 7: Batch update stale agents using pipeline + if agents_to_deactivate: + update_pipe = self._redis.pipeline() + for agent_key in agents_to_deactivate: + update_pipe.hset(agent_key, mapping={"status": "INACTIVE"}) + await update_pipe.execute() await asyncio.sleep(30) # Check every 30 seconds except Exception as e: diff --git a/tests/conftest.py b/tests/conftest.py index 02a00cd..6f94f84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import pytest from redis.asyncio import Redis + from mas import MASService # Use anyio for async test support @@ -28,20 +29,50 @@ async def redis(): @pytest.fixture(autouse=True) async def cleanup_agent_keys(): """ - Auto-use fixture to clean up Redis agent keys before each test. + Auto-use fixture to clean up Redis agent keys and streams before each test. + + This ensures tests don't interfere with each other by cleaning up: + - agent registration keys (agent:*) + - agent state keys (agent.state:*) + - agent delivery streams (agent.stream:*) + - gateway streams (mas.gateway.*) - This ensures tests don't interfere with each other by cleaning up - agent registration keys and state keys. This runs before the redis - fixture cleanup, so it's safe for tests that use the redis fixture. + This runs before the redis fixture cleanup, so it's safe for tests + that use the redis fixture. """ redis = Redis.from_url("redis://localhost:6379", decode_responses=True) - # Delete all test agent keys + + # Collect all keys/streams to delete keys_to_delete = [] + + # Agent registration and heartbeat keys async for key in redis.scan_iter("agent:*"): keys_to_delete.append(key) + + # Agent state keys async for key in redis.scan_iter("agent.state:*"): keys_to_delete.append(key) + # Agent delivery streams (including instance-specific streams) + async for key in redis.scan_iter("agent.stream:*"): + keys_to_delete.append(key) + + # Gateway streams (ingress, dlq, etc.) 
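+    # Deleting a stream key also removes its consumer groups and pending entries.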
+    async for key in redis.scan_iter("mas.gateway.*"):
+        keys_to_delete.append(key)
+
+    # Rate limit keys
+    async for key in redis.scan_iter("ratelimit:*"):
+        keys_to_delete.append(key)
+
+    # ACL keys
+    async for key in redis.scan_iter("acl:*"):
+        keys_to_delete.append(key)
+
+    # Audit keys
+    async for key in redis.scan_iter("audit:*"):
+        keys_to_delete.append(key)
+
     if keys_to_delete:
         await redis.delete(*keys_to_delete)
diff --git a/tests/test_audit_query.py b/tests/test_audit_query.py
index a53d8f7..2590e67 100644
--- a/tests/test_audit_query.py
+++ b/tests/test_audit_query.py
@@ -1,9 +1,11 @@
 """Tests for Audit Query API."""

-import json
 import asyncio
-import pytest
+import json
 import time
+
+import pytest
+
 from mas.gateway.audit import AuditModule

 # Use anyio for async test support
@@ -359,7 +361,8 @@ class TestAuditComplianceReport:

     async def test_export_compliance_report_csv(self, audit_module):
         """Test generating compliance report in CSV format."""
-        start_time = time.time()
+        # Use buffer to avoid timing edge cases where messages fall outside range
+        start_time = time.time() - 0.1

         # Log some messages
         for i in range(3):
@@ -367,7 +370,7 @@
             f"msg-{i}", "agent-a", "agent-b", "ALLOWED", 10.0, {}
         )

-        end_time = time.time()
+        end_time = time.time() + 0.1

         # Generate report
         report = await audit_module.export_compliance_report(
@@ -381,7 +384,8 @@

     async def test_export_compliance_report_json(self, audit_module):
         """Test generating compliance report in JSON format."""
-        start_time = time.time()
+        # Use buffer to avoid timing edge cases where messages fall outside range
+        start_time = time.time() - 0.1

         # Log some messages
         for i in range(3):
@@ -389,7 +393,7 @@
             f"msg-{i}", "agent-a", "agent-b", "ALLOWED", 10.0, {}
         )

-        end_time = time.time()
+        end_time = time.time() + 0.1

         # Generate report
         report = await audit_module.export_compliance_report(
diff --git a/tests/test_config.py b/tests/test_config.py
index c9468eb..3e4e1ae 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -67,9 +67,9 @@ def test_default_features(self):
         """Test default feature flags (production-ready defaults)."""
         settings = FeaturesSettings()

-        # All features enabled by default for production security
+        # Security features enabled by default; priority queue is opt-in
         assert settings.dlp is True
-        assert settings.priority_queue is True
+        assert settings.priority_queue is False
         assert settings.rbac is True
         assert settings.message_signing is True
         assert settings.circuit_breaker is True
@@ -185,7 +185,7 @@ def test_default_gateway_settings(self):
         assert settings.redis.url == "redis://localhost:6379"
         assert settings.rate_limit.per_minute == 100
         assert settings.features.dlp is True
-        assert settings.features.priority_queue is True
+        assert settings.features.priority_queue is False
         assert settings.features.rbac is True
         assert settings.features.message_signing is True
         assert settings.features.circuit_breaker is True
diff --git a/tests/test_message_signing.py b/tests/test_message_signing.py
index 3eb1acc..ee622fd 100644
--- a/tests/test_message_signing.py
+++ b/tests/test_message_signing.py
@@ -302,7 +302,7 @@ async def test_get_nonce_status(self, signing_module):
 class TestCanonicalRepresentation:
     """Test canonical data representation."""

-    def test_canonicalize_deterministic(self, signing_module):
+    async def
test_canonicalize_deterministic(self, signing_module): """Test that canonicalization is deterministic.""" data1 = {"b": 2, "a": 1, "c": 3} data2 = {"c": 3, "a": 1, "b": 2} @@ -313,7 +313,7 @@ def test_canonicalize_deterministic(self, signing_module): # Same data in different order should produce same canonical form assert canon1 == canon2 - def test_canonicalize_nested(self, signing_module): + async def test_canonicalize_nested(self, signing_module): """Test canonicalization of nested data.""" data = { "outer": {"b": 2, "a": 1}, diff --git a/tests/test_multi_instance.py b/tests/test_multi_instance.py new file mode 100644 index 0000000..09ed8c6 --- /dev/null +++ b/tests/test_multi_instance.py @@ -0,0 +1,574 @@ +"""Tests for multi-instance agent support.""" + +import asyncio +from typing import override + +import pytest +from pydantic import BaseModel + +from mas import Agent, AgentMessage +from mas.gateway import GatewayService +from mas.gateway.config import FeaturesSettings, GatewaySettings +from mas.registry import AgentRegistry + +pytestmark = pytest.mark.asyncio + + +class CounterState(BaseModel): + """State model for counter agent.""" + + count: int = 0 + + +class CollectorAgent(Agent): + """Test agent that collects received messages.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.messages: list[AgentMessage] = [] + self.message_event = asyncio.Event() + + @override + async def on_message(self, message: AgentMessage) -> None: + """Store received messages.""" + self.messages.append(message) + self.message_event.set() + + +class ResponderAgent(Agent): + """Test agent that responds with its instance_id.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.handled_count = 0 + + @override + async def on_message(self, message: AgentMessage) -> None: + """Respond with instance info.""" + self.handled_count += 1 + if message.meta.expects_reply and message.meta.correlation_id: + await message.reply( + "response", + { + "instance_id": self.instance_id, + "handled_count": self.handled_count, + }, + ) + + +class TestMultiInstanceBasics: + """Test basic multi-instance functionality.""" + + async def test_each_instance_has_unique_id(self): + """Test that each agent instance gets a unique instance_id.""" + agent1 = Agent("shared_agent", capabilities=["test"]) + agent2 = Agent("shared_agent", capabilities=["test"]) + + # Instance IDs should be unique even for same agent_id + assert agent1.instance_id != agent2.instance_id + assert len(agent1.instance_id) == 8 + assert len(agent2.instance_id) == 8 + + async def test_instance_id_is_stable(self): + """Test that instance_id doesn't change during agent lifecycle.""" + agent = Agent("test_agent", capabilities=["test"]) + initial_id = agent.instance_id + + await agent.start() + assert agent.instance_id == initial_id + + await agent.stop() + assert agent.instance_id == initial_id + + +class TestInstanceRegistration: + """Test instance registration and tracking.""" + + async def test_first_instance_registers_agent(self, redis): + """Test that first instance creates agent registration.""" + agent = Agent("new_agent", capabilities=["test"]) + await agent.start() + + try: + # Verify agent is registered + agent_data = await redis.hgetall("agent:new_agent") + assert agent_data["id"] == "new_agent" + assert agent_data["status"] == "ACTIVE" + assert "token" in agent_data + + # Verify instance count + count = await redis.get("agent:new_agent:instance_count") + assert count == "1" + finally: + 
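+            # Stop the instance so its registration and heartbeat are removed.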
await agent.stop() + + async def test_second_instance_joins_existing(self, redis): + """Test that second instance joins without re-registering.""" + agent1 = Agent("shared_agent", capabilities=["test"]) + await agent1.start() + + original_token = agent1.token + original_reg_time = await redis.hget("agent:shared_agent", "registered_at") + + try: + # Start second instance + agent2 = Agent("shared_agent", capabilities=["test"]) + await agent2.start() + + try: + # Both should have the same token + assert agent2.token == original_token + + # Registration time should be unchanged + current_reg_time = await redis.hget( + "agent:shared_agent", "registered_at" + ) + assert current_reg_time == original_reg_time + + # Instance count should be 2 + count = await redis.get("agent:shared_agent:instance_count") + assert count == "2" + finally: + await agent2.stop() + finally: + await agent1.stop() + + async def test_instance_count_decrements_on_stop(self, redis): + """Test that instance count decrements when instance stops.""" + agent1 = Agent("shared_agent", capabilities=["test"]) + agent2 = Agent("shared_agent", capabilities=["test"]) + + await agent1.start() + await agent2.start() + + # Both running + count = await redis.get("agent:shared_agent:instance_count") + assert count == "2" + + # Stop one + await agent2.stop() + count = await redis.get("agent:shared_agent:instance_count") + assert count == "1" + + # Agent still registered + agent_data = await redis.hgetall("agent:shared_agent") + assert agent_data["status"] == "ACTIVE" + + # Stop last one + await agent1.stop() + + # Agent should be deregistered + agent_data = await redis.hgetall("agent:shared_agent") + assert agent_data == {} + + +class TestInstanceHeartbeats: + """Test per-instance heartbeat functionality.""" + + async def test_each_instance_has_own_heartbeat(self, redis): + """Test that each instance maintains its own heartbeat.""" + agent1 = Agent("shared_agent", capabilities=["test"]) + agent2 = Agent("shared_agent", capabilities=["test"]) + + await agent1.start() + await agent2.start() + + try: + # Wait for heartbeats to be set + await asyncio.sleep(0.1) + + # Each instance should have its own heartbeat key + hb1_key = f"agent:shared_agent:heartbeat:{agent1.instance_id}" + hb2_key = f"agent:shared_agent:heartbeat:{agent2.instance_id}" + + hb1_exists = await redis.exists(hb1_key) + hb2_exists = await redis.exists(hb2_key) + + assert hb1_exists == 1 + assert hb2_exists == 1 + finally: + await agent1.stop() + await agent2.stop() + + async def test_heartbeat_cleanup_on_stop(self, redis): + """Test that heartbeat is cleaned up when instance stops.""" + agent = Agent("test_agent", capabilities=["test"]) + await agent.start() + + hb_key = f"agent:test_agent:heartbeat:{agent.instance_id}" + + # Heartbeat should exist + await asyncio.sleep(0.1) + assert await redis.exists(hb_key) == 1 + + # Stop agent + await agent.stop() + + # Heartbeat should be cleaned up + assert await redis.exists(hb_key) == 0 + + +class TestInstanceHealthChecks: + """Test health checking with multiple instances.""" + + async def test_registry_has_healthy_instance(self, redis): + """Test has_healthy_instance check.""" + registry = AgentRegistry(redis) + + # Register and set up heartbeat manually + agent = Agent("health_test", capabilities=["test"]) + await agent.start() + + try: + # Should be healthy + is_healthy = await registry.has_healthy_instance("health_test") + assert is_healthy is True + finally: + await agent.stop() + + async def 
test_registry_instance_heartbeats(self, redis): + """Test get_instance_heartbeats returns all instances.""" + agent1 = Agent("multi_health", capabilities=["test"]) + agent2 = Agent("multi_health", capabilities=["test"]) + + await agent1.start() + await agent2.start() + + try: + await asyncio.sleep(0.1) + + registry = AgentRegistry(redis) + heartbeats = await registry.get_instance_heartbeats("multi_health") + + # Should have two heartbeats + assert len(heartbeats) == 2 + assert agent1.instance_id in heartbeats + assert agent2.instance_id in heartbeats + + # Both should have positive TTL + assert all(ttl is not None and ttl > 0 for ttl in heartbeats.values()) + finally: + await agent1.stop() + await agent2.stop() + + +class TestLoadBalancing: + """Test message load balancing across instances.""" + + async def test_messages_distributed_across_instances(self, mas_service): + """Test that messages are load-balanced across instances.""" + settings = GatewaySettings( + features=FeaturesSettings( + dlp=False, + priority_queue=False, + rbac=False, + message_signing=False, + circuit_breaker=False, + ) + ) + gateway = GatewayService(settings=settings) + await gateway.start() + + # Create sender + sender = Agent("sender", capabilities=["send"]) + + # Create two instances of receiver + receiver1 = CollectorAgent("receiver", capabilities=["receive"]) + receiver2 = CollectorAgent("receiver", capabilities=["receive"]) + + await sender.start() + await receiver1.start() + await receiver2.start() + + try: + await gateway.auth_manager().allow_bidirectional("sender", "receiver") + + # Send multiple messages + num_messages = 20 + for i in range(num_messages): + await sender.send("receiver", "test.message", {"index": i}) + + # Wait for delivery + await asyncio.sleep(1.0) + + total_received = len(receiver1.messages) + len(receiver2.messages) + assert total_received == num_messages + + # Both instances should have received some messages (load balanced) + # Note: With only 2 instances and 20 messages, distribution may vary + # but both should have received at least 1 message + assert len(receiver1.messages) > 0, "Instance 1 received no messages" + assert len(receiver2.messages) > 0, "Instance 2 received no messages" + finally: + await sender.stop() + await receiver1.stop() + await receiver2.stop() + await gateway.stop() + + +class TestRequestResponseRouting: + """Test that request-response routes replies to correct instance.""" + + async def test_reply_routes_to_requesting_instance(self, mas_service): + """Test that replies go to the specific instance that made the request.""" + settings = GatewaySettings( + features=FeaturesSettings( + dlp=False, + priority_queue=False, + rbac=False, + message_signing=False, + circuit_breaker=False, + ) + ) + gateway = GatewayService(settings=settings) + await gateway.start() + + # Two instances of requester + requester1 = Agent("requester", capabilities=["request"]) + requester2 = Agent("requester", capabilities=["request"]) + + # One responder + responder = ResponderAgent("responder", capabilities=["respond"]) + + await requester1.start() + await requester2.start() + await responder.start() + + try: + await gateway.auth_manager().allow_bidirectional("requester", "responder") + + # Both instances make requests + response1 = await requester1.request( + "responder", + "test.request", + {"from_instance": requester1.instance_id}, + timeout=5.0, + ) + + response2 = await requester2.request( + "responder", + "test.request", + {"from_instance": requester2.instance_id}, + timeout=5.0, + 
) + + # Each should have received a response + assert response1 is not None + assert response2 is not None + + # Responses should be from the same responder + assert response1.data["instance_id"] == responder.instance_id + assert response2.data["instance_id"] == responder.instance_id + + # Responder should have handled both requests + assert responder.handled_count == 2 + finally: + await requester1.stop() + await requester2.stop() + await responder.stop() + await gateway.stop() + + async def test_concurrent_requests_from_multiple_instances(self, mas_service): + """Test concurrent requests from multiple instances route correctly.""" + settings = GatewaySettings( + features=FeaturesSettings( + dlp=False, + priority_queue=False, + rbac=False, + message_signing=False, + circuit_breaker=False, + ) + ) + gateway = GatewayService(settings=settings) + await gateway.start() + + # Multiple requester instances + requesters = [ + Agent(f"requester_{i}", capabilities=["request"]) for i in range(3) + ] + + # Single responder + responder = ResponderAgent("responder", capabilities=["respond"]) + + for r in requesters: + await r.start() + await responder.start() + + try: + for r in requesters: + await gateway.auth_manager().allow_bidirectional(r.id, "responder") + + # All instances make concurrent requests + async def make_request(requester): + return await requester.request( + "responder", + "test.request", + {"from": requester.id}, + timeout=5.0, + ) + + responses = await asyncio.gather(*[make_request(r) for r in requesters]) + + # All should have received responses + assert len(responses) == 3 + assert all(r is not None for r in responses) + finally: + for r in requesters: + await r.stop() + await responder.stop() + await gateway.stop() + + +class TestSharedState: + """Test that state is shared across instances.""" + + async def test_state_shared_across_instances(self, redis): + """Test that multiple instances share the same state.""" + agent1 = Agent("stateful", capabilities=["test"], state_model=CounterState) + agent2 = Agent("stateful", capabilities=["test"], state_model=CounterState) + + await agent1.start() + await agent2.start() + + try: + # Update state from instance 1 + await agent1.update_state({"count": 42}) + + # Reload state in instance 2 + if agent2._state_manager: + await agent2._state_manager.load() + + # Instance 2 should see the updated state + assert agent2.state.count == 42 + finally: + await agent1.stop() + await agent2.stop() + + +class TestDiscovery: + """Test that discovery returns logical agents, not instances.""" + + async def test_discovery_returns_single_agent(self, mas_service): + """Test that discovery doesn't duplicate multi-instance agents.""" + # Create multiple instances of same agent + agent1 = Agent("multi_instance", capabilities=["special"]) + agent2 = Agent("multi_instance", capabilities=["special"]) + discoverer = Agent("discoverer", capabilities=["discover"]) + + await agent1.start() + await agent2.start() + await discoverer.start() + + try: + # Wait for registration + await asyncio.sleep(0.1) + + # Discover should return single agent + agents = await discoverer.discover(capabilities=["special"]) + assert len(agents) == 1 + assert agents[0]["id"] == "multi_instance" + finally: + await agent1.stop() + await agent2.stop() + await discoverer.stop() + + +class TestSystemEvents: + """Test system events for instance lifecycle.""" + + async def test_register_event_only_on_first_instance(self, redis): + """Test that REGISTER event fires only for first instance.""" + 
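+        # Subscribe before starting the agents so no lifecycle event is missed.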
+        pubsub = redis.pubsub()
+        await pubsub.subscribe("mas.system")
+
+        events = []
+
+        async def collect_events():
+            async for msg in pubsub.listen():
+                if msg["type"] == "message":
+                    import json
+
+                    events.append(json.loads(msg["data"]))
+                    if len(events) >= 3:  # REGISTER + two INSTANCE_JOIN events
+                        break
+
+        collector_task = asyncio.create_task(collect_events())
+
+        agent1 = Agent("event_test", capabilities=["test"])
+        agent2 = Agent("event_test", capabilities=["test"])
+
+        try:
+            await agent1.start()
+            await asyncio.sleep(0.1)
+            await agent2.start()
+            await asyncio.sleep(0.1)
+
+            # Wait for events
+            await asyncio.wait_for(collector_task, timeout=2.0)
+
+            # Should have: REGISTER (once), INSTANCE_JOIN (twice)
+            register_events = [e for e in events if e.get("type") == "REGISTER"]
+            join_events = [e for e in events if e.get("type") == "INSTANCE_JOIN"]
+
+            assert len(register_events) == 1
+            assert len(join_events) == 2
+        except asyncio.TimeoutError:
+            collector_task.cancel()
+        finally:
+            await agent1.stop()
+            await agent2.stop()
+            await pubsub.unsubscribe()
+            await pubsub.aclose()
+
+    async def test_deregister_event_only_on_last_instance(self, redis):
+        """Test that DEREGISTER event fires only when last instance leaves."""
+        agent1 = Agent("event_test", capabilities=["test"])
+        agent2 = Agent("event_test", capabilities=["test"])
+
+        await agent1.start()
+        await agent2.start()
+
+        pubsub = redis.pubsub()
+        await pubsub.subscribe("mas.system")
+
+        events = []
+
+        async def collect_events():
+            async for msg in pubsub.listen():
+                if msg["type"] == "message":
+                    import json
+
+                    events.append(json.loads(msg["data"]))
+                    if any(e.get("type") == "DEREGISTER" for e in events):
+                        break
+
+        collector_task = asyncio.create_task(collect_events())
+
+        try:
+            # Stop first instance
+            await agent1.stop()
+            await asyncio.sleep(0.1)
+
+            # Should have INSTANCE_LEAVE but not DEREGISTER
+            leave_events = [e for e in events if e.get("type") == "INSTANCE_LEAVE"]
+            deregister_events = [e for e in events if e.get("type") == "DEREGISTER"]
+
+            assert len(leave_events) >= 1
+            assert len(deregister_events) == 0
+
+            # Stop second (last) instance
+            await agent2.stop()
+            await asyncio.sleep(0.1)
+
+            # Wait for DEREGISTER
+            await asyncio.wait_for(collector_task, timeout=2.0)
+
+            # Now should have DEREGISTER
+            deregister_events = [e for e in events if e.get("type") == "DEREGISTER"]
+            assert len(deregister_events) == 1
+        except asyncio.TimeoutError:
+            collector_task.cancel()
+            await agent2.stop()
+        finally:
+            await pubsub.unsubscribe()
+            await pubsub.aclose()
diff --git a/tests/test_performance.py b/tests/test_performance.py
new file mode 100644
index 0000000..cf8fdfa
--- /dev/null
+++ b/tests/test_performance.py
@@ -0,0 +1,673 @@
+"""Performance benchmark tests.
+
+These tests measure throughput and latency to validate optimization impact.
+Run with verbose output to see benchmark results:
+
+    uv run pytest tests/test_performance.py -v -s
+
+Note: These tests require a running Redis instance on localhost:6379.
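+
+Results vary with hardware, network, and Redis configuration, which is why the
+assertions below use deliberately loose lower bounds.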
+""" + +import asyncio +import json +import statistics +import time +from dataclasses import dataclass +from typing import Any, Callable, Coroutine + +import pytest + +pytestmark = pytest.mark.asyncio + + +@dataclass +class BenchmarkResult: + """Results from a benchmark run.""" + + name: str + iterations: int + total_time_seconds: float + throughput_per_second: float + avg_latency_ms: float + min_latency_ms: float + max_latency_ms: float + p50_latency_ms: float + p95_latency_ms: float + p99_latency_ms: float + + def __str__(self) -> str: + return ( + f"\n{'=' * 60}\n" + f"Benchmark: {self.name}\n" + f"{'=' * 60}\n" + f" Iterations: {self.iterations}\n" + f" Total time: {self.total_time_seconds:.3f}s\n" + f" Throughput: {self.throughput_per_second:.2f} ops/sec\n" + f" Avg latency: {self.avg_latency_ms:.3f}ms\n" + f" Min latency: {self.min_latency_ms:.3f}ms\n" + f" Max latency: {self.max_latency_ms:.3f}ms\n" + f" P50 latency: {self.p50_latency_ms:.3f}ms\n" + f" P95 latency: {self.p95_latency_ms:.3f}ms\n" + f" P99 latency: {self.p99_latency_ms:.3f}ms\n" + f"{'=' * 60}" + ) + + +async def run_benchmark( + name: str, + func: Callable[[], Coroutine[Any, Any, Any]], + iterations: int = 100, + warmup: int = 10, +) -> BenchmarkResult: + """ + Run a benchmark and collect timing statistics. + + Args: + name: Name of the benchmark + func: Async function to benchmark (no arguments) + iterations: Number of iterations to run + warmup: Number of warmup iterations (not counted) + + Returns: + BenchmarkResult with timing statistics + """ + # Warmup phase + for _ in range(warmup): + await func() + + # Benchmark phase + latencies: list[float] = [] + start_total = time.perf_counter() + + for _ in range(iterations): + start = time.perf_counter() + await func() + end = time.perf_counter() + latencies.append((end - start) * 1000) # Convert to ms + + end_total = time.perf_counter() + total_time = end_total - start_total + + # Calculate statistics + latencies.sort() + p50_idx = int(len(latencies) * 0.50) + p95_idx = int(len(latencies) * 0.95) + p99_idx = int(len(latencies) * 0.99) + + return BenchmarkResult( + name=name, + iterations=iterations, + total_time_seconds=total_time, + throughput_per_second=iterations / total_time, + avg_latency_ms=statistics.mean(latencies), + min_latency_ms=min(latencies), + max_latency_ms=max(latencies), + p50_latency_ms=latencies[p50_idx], + p95_latency_ms=latencies[p95_idx], + p99_latency_ms=latencies[p99_idx], + ) + + +# ----------------------------------------------------------------------------- +# Rate Limit Benchmarks +# ----------------------------------------------------------------------------- + + +class TestRateLimitPerformance: + """Benchmark rate limiting performance.""" + + async def test_rate_limit_throughput(self, redis): + """Measure rate limit checks per second.""" + from mas.gateway.rate_limit import RateLimitModule + + # High limits to avoid blocking during benchmark + module = RateLimitModule( + redis, default_per_minute=100000, default_per_hour=1000000 + ) + + msg_counter = 0 + + async def rate_limit_check(): + nonlocal msg_counter + msg_counter += 1 + await module.check_rate_limit("bench_agent", f"msg_{msg_counter}") + + result = await run_benchmark( + name="Rate Limit Check", + func=rate_limit_check, + iterations=200, + warmup=20, + ) + + print(result) + + # Baseline expectations (adjust after optimization) + # BEFORE optimization: ~100-500 ops/sec (limited by 10 Redis calls) + # AFTER optimization: ~1000-5000 ops/sec (single Lua call) + assert 
result.throughput_per_second > 50, ( + f"Throughput too low: {result.throughput_per_second:.2f} ops/sec" + ) + assert result.p95_latency_ms < 100, ( + f"P95 latency too high: {result.p95_latency_ms:.3f}ms" + ) + + async def test_rate_limit_under_load(self, redis): + """Test rate limit performance under concurrent load.""" + from mas.gateway.rate_limit import RateLimitModule + + module = RateLimitModule( + redis, default_per_minute=100000, default_per_hour=1000000 + ) + + async def check_for_agent(agent_id: str, count: int): + for i in range(count): + await module.check_rate_limit(agent_id, f"msg_{i}") + + # Simulate 5 agents making concurrent requests + num_agents = 5 + requests_per_agent = 50 + + start = time.perf_counter() + await asyncio.gather( + *[ + check_for_agent(f"agent_{i}", requests_per_agent) + for i in range(num_agents) + ] + ) + elapsed = time.perf_counter() - start + + total_requests = num_agents * requests_per_agent + throughput = total_requests / elapsed + + print("\nConcurrent Rate Limit Benchmark:") + print(f" Agents: {num_agents}") + print(f" Requests per agent: {requests_per_agent}") + print(f" Total requests: {total_requests}") + print(f" Total time: {elapsed:.3f}s") + print(f" Throughput: {throughput:.2f} ops/sec") + + assert throughput > 100, f"Concurrent throughput too low: {throughput:.2f}" + + +# ----------------------------------------------------------------------------- +# Discovery Benchmarks +# ----------------------------------------------------------------------------- + + +class TestDiscoveryPerformance: + """Benchmark agent discovery performance.""" + + @pytest.fixture + async def setup_many_agents(self, redis): + """Setup many agents for discovery benchmarking.""" + num_agents = 100 + agent_ids = [f"perf_agent_{i}" for i in range(num_agents)] + + for i, agent_id in enumerate(agent_ids): + # Distribute capabilities across agents + caps = ["common"] + if i % 2 == 0: + caps.append("even") + if i % 3 == 0: + caps.append("divisible_by_3") + if i % 10 == 0: + caps.append("divisible_by_10") + + await redis.hset( + f"agent:{agent_id}", + mapping={ + "id": agent_id, + "capabilities": json.dumps(caps), + "metadata": json.dumps({"index": i}), + "status": "ACTIVE", + "registered_at": str(time.time()), + }, + ) + + yield num_agents + + # Cleanup + for agent_id in agent_ids: + await redis.delete(f"agent:{agent_id}") + + async def test_discovery_throughput(self, redis, setup_many_agents): + """Measure discovery operations per second.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + num_agents = setup_many_agents + + async def discover_all(): + return await registry.discover() + + result = await run_benchmark( + name=f"Discovery (all {num_agents} agents)", + func=discover_all, + iterations=50, + warmup=5, + ) + + print(result) + + # Baseline expectations + # BEFORE optimization: ~10-50 ops/sec (N+1 queries) + # AFTER optimization: ~100-500 ops/sec (pipeline batching) + assert result.throughput_per_second > 5, ( + f"Throughput too low: {result.throughput_per_second:.2f} ops/sec" + ) + + async def test_discovery_with_filter(self, redis, setup_many_agents): + """Measure filtered discovery performance.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + + async def discover_filtered(): + return await registry.discover(capabilities=["divisible_by_10"]) + + result = await run_benchmark( + name="Discovery (filtered by capability)", + func=discover_filtered, + iterations=50, + warmup=5, + ) + + print(result) + + # 
Filtered discovery should have similar performance characteristics + # since filtering happens after fetching + assert result.throughput_per_second > 5 + + async def test_discovery_scaling(self, redis): + """Test how discovery scales with number of agents.""" + from mas.registry import AgentRegistry + + results: list[tuple[int, float]] = [] + + for num_agents in [10, 25, 50, 100]: + # Setup agents + agent_ids = [f"scale_agent_{i}" for i in range(num_agents)] + for agent_id in agent_ids: + await redis.hset( + f"agent:{agent_id}", + mapping={ + "id": agent_id, + "capabilities": json.dumps(["scale_test"]), + "metadata": "{}", + "status": "ACTIVE", + "registered_at": "123", + }, + ) + + registry = AgentRegistry(redis) + + # Measure time + start = time.perf_counter() + for _ in range(10): + await registry.discover() + elapsed = time.perf_counter() - start + + avg_time_ms = (elapsed / 10) * 1000 + results.append((num_agents, avg_time_ms)) + + # Cleanup + for agent_id in agent_ids: + await redis.delete(f"agent:{agent_id}") + + print("\nDiscovery Scaling:") + print(f" {'Agents':<10} {'Avg Time (ms)':<15} {'Time/Agent (ms)':<15}") + print(f" {'-' * 40}") + for num_agents, avg_time in results: + per_agent = avg_time / num_agents + print(f" {num_agents:<10} {avg_time:<15.3f} {per_agent:<15.3f}") + + # After optimization, time/agent should be much lower due to batching + # BEFORE: time scales linearly with N (O(N) queries) + # AFTER: time is nearly constant (O(1) batched queries) + + +# ----------------------------------------------------------------------------- +# Audit Benchmarks +# ----------------------------------------------------------------------------- + + +class TestAuditPerformance: + """Benchmark audit logging performance.""" + + async def test_audit_log_throughput(self, redis): + """Measure audit log writes per second.""" + from mas.gateway.audit import AuditModule + + audit = AuditModule(redis) + msg_counter = 0 + + async def log_message(): + nonlocal msg_counter + msg_counter += 1 + await audit.log_message( + message_id=f"msg_{msg_counter}", + sender_id="agent_a", + target_id="agent_b", + decision="ALLOWED", + latency_ms=10.0, + payload={"index": msg_counter}, + ) + + result = await run_benchmark( + name="Audit Log Write", + func=log_message, + iterations=100, + warmup=10, + ) + + print(result) + + # Baseline expectations + # BEFORE optimization: ~100-300 ops/sec (5 sequential calls) + # AFTER optimization: ~500-1500 ops/sec (pipelined) + assert result.throughput_per_second > 50, ( + f"Throughput too low: {result.throughput_per_second:.2f} ops/sec" + ) + + async def test_audit_query_throughput(self, redis): + """Measure audit log query performance.""" + from mas.gateway.audit import AuditModule + + audit = AuditModule(redis) + + # Pre-populate audit log + for i in range(50): + await audit.log_message( + message_id=f"query_msg_{i}", + sender_id="query_agent", + target_id="target_agent", + decision="ALLOWED", + latency_ms=10.0, + payload={"index": i}, + ) + + async def query_by_sender(): + return await audit.query_by_sender("query_agent", count=20) + + result = await run_benchmark( + name="Audit Query by Sender", + func=query_by_sender, + iterations=50, + warmup=5, + ) + + print(result) + + assert result.throughput_per_second > 20 + + +# ----------------------------------------------------------------------------- +# Priority Queue Benchmarks +# ----------------------------------------------------------------------------- + + +class TestPriorityQueuePerformance: + """Benchmark priority 
queue performance.""" + + async def test_enqueue_throughput(self, redis): + """Measure enqueue operations per second.""" + from mas.gateway.priority_queue import MessagePriority, PriorityQueueModule + + pq = PriorityQueueModule(redis) + msg_counter = 0 + + async def enqueue_message(): + nonlocal msg_counter + msg_counter += 1 + await pq.enqueue( + message_id=f"msg_{msg_counter}", + sender_id="sender", + target_id="pq_bench_target", + payload={"index": msg_counter}, + priority=MessagePriority.NORMAL, + ) + + result = await run_benchmark( + name="Priority Queue Enqueue", + func=enqueue_message, + iterations=100, + warmup=10, + ) + + print(result) + + # Cleanup + await pq.clear_queue("pq_bench_target") + + assert result.throughput_per_second > 50 + + async def test_dequeue_throughput(self, redis): + """Measure dequeue operations per second.""" + from mas.gateway.priority_queue import MessagePriority, PriorityQueueModule + + pq = PriorityQueueModule(redis) + + # Pre-populate queue + for i in range(200): + priority = list(MessagePriority)[i % 5] + await pq.enqueue( + message_id=f"dequeue_msg_{i}", + sender_id="sender", + target_id="dequeue_bench_target", + payload={"index": i}, + priority=priority, + ) + + async def dequeue_message(): + return await pq.dequeue("dequeue_bench_target", max_messages=1) + + result = await run_benchmark( + name="Priority Queue Dequeue", + func=dequeue_message, + iterations=100, + warmup=10, + ) + + print(result) + + # Cleanup + await pq.clear_queue("dequeue_bench_target") + + # Baseline expectations + # BEFORE optimization: ~20-100 ops/sec (fallback iteration) + # AFTER optimization: ~200-500 ops/sec (batched checks) + assert result.throughput_per_second > 20 + + +# ----------------------------------------------------------------------------- +# Circuit Breaker Benchmarks +# ----------------------------------------------------------------------------- + + +class TestCircuitBreakerPerformance: + """Benchmark circuit breaker performance.""" + + async def test_circuit_check_throughput(self, redis): + """Measure circuit check operations per second.""" + from mas.gateway.circuit_breaker import CircuitBreakerModule + + cb = CircuitBreakerModule(redis) + + async def check_circuit(): + return await cb.check_circuit("cb_bench_target") + + result = await run_benchmark( + name="Circuit Breaker Check", + func=check_circuit, + iterations=200, + warmup=20, + ) + + print(result) + + assert result.throughput_per_second > 200 + + async def test_circuit_check_and_record(self, redis): + """Measure check + record cycle (common pattern).""" + from mas.gateway.circuit_breaker import CircuitBreakerModule + + cb = CircuitBreakerModule(redis) + counter = 0 + + async def check_and_record(): + nonlocal counter + counter += 1 + target = f"cb_cycle_target_{counter % 10}" + status = await cb.check_circuit(target) + if status.allowed: + await cb.record_success(target) + + result = await run_benchmark( + name="Circuit Breaker Check+Record Cycle", + func=check_and_record, + iterations=200, + warmup=20, + ) + + print(result) + + # This pattern currently does 2 hgetall calls + # Could be optimized to reuse state + assert result.throughput_per_second > 100 + + +# ----------------------------------------------------------------------------- +# End-to-End Gateway Benchmarks +# ----------------------------------------------------------------------------- + + +class TestGatewayPerformance: + """Benchmark complete gateway message handling.""" + + async def test_gateway_message_throughput(self, redis): + 
"""Measure end-to-end gateway message processing.""" + from mas.gateway import GatewayService + from mas.gateway.config import FeaturesSettings, GatewaySettings + from mas.protocol import EnvelopeMessage + + # Configure gateway with minimal features for baseline + settings = GatewaySettings( + features=FeaturesSettings( + dlp=False, + priority_queue=False, + rbac=False, + message_signing=False, + circuit_breaker=False, + ), + ) + + gateway = GatewayService(settings=settings) + await gateway.start() + + try: + # Setup sender and target + sender_id = "gw_bench_sender" + target_id = "gw_bench_target" + token = "bench_token" + + await redis.hset( + f"agent:{sender_id}", + mapping={ + "token": token, + "status": "ACTIVE", + "token_expires": str(time.time() + 3600), + }, + ) + await redis.hset(f"agent:{target_id}", mapping={"status": "ACTIVE"}) + await gateway.authz.set_permissions(sender_id, allowed_targets=[target_id]) + + msg_counter = 0 + + async def handle_message(): + nonlocal msg_counter + msg_counter += 1 + message = EnvelopeMessage( + sender_id=sender_id, + target_id=target_id, + message_type="benchmark.message", + data={"index": msg_counter}, + ) + return await gateway.handle_message(message, token) + + result = await run_benchmark( + name="Gateway Message Handling (minimal features)", + func=handle_message, + iterations=100, + warmup=10, + ) + + print(result) + + # Gateway should handle messages reasonably fast + assert result.throughput_per_second > 20 + + finally: + await gateway.stop() + + async def test_gateway_with_all_features(self, redis): + """Measure gateway performance with all features enabled.""" + from mas.gateway import GatewayService + from mas.gateway.config import FeaturesSettings, GatewaySettings + from mas.protocol import EnvelopeMessage + + settings = GatewaySettings( + features=FeaturesSettings( + dlp=True, + priority_queue=False, # Keep false to avoid queue complexity + rbac=True, + message_signing=False, # Requires key setup + circuit_breaker=True, + ), + ) + + gateway = GatewayService(settings=settings) + await gateway.start() + + try: + sender_id = "gw_full_sender" + target_id = "gw_full_target" + token = "full_token" + + await redis.hset( + f"agent:{sender_id}", + mapping={ + "token": token, + "status": "ACTIVE", + "token_expires": str(time.time() + 3600), + }, + ) + await redis.hset(f"agent:{target_id}", mapping={"status": "ACTIVE"}) + await gateway.authz.set_permissions(sender_id, allowed_targets=[target_id]) + + msg_counter = 0 + + async def handle_message(): + nonlocal msg_counter + msg_counter += 1 + message = EnvelopeMessage( + sender_id=sender_id, + target_id=target_id, + message_type="benchmark.message", + data={"text": f"Message {msg_counter}", "count": msg_counter}, + ) + return await gateway.handle_message(message, token) + + result = await run_benchmark( + name="Gateway Message Handling (DLP + RBAC + Circuit Breaker)", + func=handle_message, + iterations=100, + warmup=10, + ) + + print(result) + + # With all features, expect lower but still reasonable throughput + assert result.throughput_per_second > 10 + + finally: + await gateway.stop() diff --git a/tests/test_rbac.py b/tests/test_rbac.py index d6e1dae..8e1e160 100644 --- a/tests/test_rbac.py +++ b/tests/test_rbac.py @@ -112,22 +112,22 @@ async def test_unassign_role(self, authz_rbac, redis): class TestPermissionMatching: """Test permission pattern matching.""" - def test_exact_match(self, authz_rbac): + async def test_exact_match(self, authz_rbac): """Test exact permission match.""" assert 
authz_rbac._matches_permission("send:agent-1", "send:agent-1") - def test_wildcard_all(self, authz_rbac): + async def test_wildcard_all(self, authz_rbac): """Test * wildcard matches everything.""" assert authz_rbac._matches_permission("send:agent-1", "*") assert authz_rbac._matches_permission("read:anything", "*") - def test_wildcard_action(self, authz_rbac): + async def test_wildcard_action(self, authz_rbac): """Test action:* wildcard.""" assert authz_rbac._matches_permission("send:agent-1", "send:*") assert authz_rbac._matches_permission("send:agent-2", "send:*") assert not authz_rbac._matches_permission("read:agent-1", "send:*") - def test_wildcard_pattern(self, authz_rbac): + async def test_wildcard_pattern(self, authz_rbac): """Test pattern wildcards like agent.*""" # agent.* means "agent." followed by anything assert authz_rbac._matches_permission("send:agent.1", "send:agent.*") @@ -136,7 +136,7 @@ def test_wildcard_pattern(self, authz_rbac): assert not authz_rbac._matches_permission("send:user.1", "send:agent.*") assert not authz_rbac._matches_permission("send:agent-1", "send:agent.*") - def test_no_match(self, authz_rbac): + async def test_no_match(self, authz_rbac): """Test permission doesn't match.""" assert not authz_rbac._matches_permission("send:agent-1", "read:agent-1") assert not authz_rbac._matches_permission("send:agent-1", "send:agent-2") diff --git a/tests/test_redis_efficiency.py b/tests/test_redis_efficiency.py new file mode 100644 index 0000000..b95655c --- /dev/null +++ b/tests/test_redis_efficiency.py @@ -0,0 +1,1006 @@ +"""Tests for Redis operation efficiency. + +These tests verify that optimizations reduce Redis round-trips. +Run baseline tests before optimization to document current behavior, +then verify improvements after optimization. + +Usage: + # Run efficiency tests with verbose output + uv run pytest tests/test_redis_efficiency.py -v -s +""" + +from __future__ import annotations + +import json +from collections import defaultdict +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Literal, + Self, + Set, + overload, +) + +import pytest + +from mas.redis_types import AsyncRedisProtocol, PubSubProtocol + +if TYPE_CHECKING: + from mas.redis_types import PipelineProtocol + +pytestmark = pytest.mark.asyncio + + +class PipelineCallCounter: + """ + Pipeline wrapper that counts operations and delegates to real pipeline. + + Tracks pipeline operations as a single "pipeline_execute" call when + execute() is called, which accurately reflects the single round-trip. 
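+
+    The names of the queued operations are also recorded per batch on the
+    parent counter's pipeline_ops, so tests can assert on batch composition.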
+ """ + + def __init__( + self, real_pipeline: "PipelineProtocol", counter: "RedisCallCounter" + ) -> None: + self._pipeline = real_pipeline + self._counter = counter + self._queued_ops: list[str] = [] + + def hgetall(self, key: str) -> Self: + self._queued_ops.append("hgetall") + self._pipeline.hgetall(key) + return self + + def hset(self, key: str, *, mapping: dict[str, str]) -> Self: + self._queued_ops.append("hset") + self._pipeline.hset(key, mapping=mapping) + return self + + def hget(self, key: str, field: str) -> Self: + self._queued_ops.append("hget") + self._pipeline.hget(key, field) + return self + + def delete(self, *keys: str) -> Self: + self._queued_ops.append("delete") + self._pipeline.delete(*keys) + return self + + def exists(self, key: str) -> Self: + self._queued_ops.append("exists") + self._pipeline.exists(key) + return self + + def get(self, key: str) -> Self: + self._queued_ops.append("get") + self._pipeline.get(key) + return self + + def set(self, key: str, value: str) -> Self: + self._queued_ops.append("set") + self._pipeline.set(key, value) + return self + + def xadd(self, name: str, fields: dict[str, str]) -> Self: + self._queued_ops.append("xadd") + self._pipeline.xadd(name, fields) + return self + + def zadd(self, key: str, mapping: dict[str, float]) -> Self: + self._queued_ops.append("zadd") + self._pipeline.zadd(key, mapping) + return self + + def zcard(self, key: str) -> Self: + self._queued_ops.append("zcard") + self._pipeline.zcard(key) + return self + + def zcount(self, key: str, min: float | str, max: float | str) -> Self: + self._queued_ops.append("zcount") + self._pipeline.zcount(key, min, max) + return self + + def expire(self, key: str, seconds: int) -> Self: + self._queued_ops.append("expire") + self._pipeline.expire(key, seconds) + return self + + def ttl(self, key: str) -> Self: + self._queued_ops.append("ttl") + self._pipeline.ttl(key) + return self + + def setex(self, key: str, seconds: int, value: str) -> Self: + self._queued_ops.append("setex") + self._pipeline.setex(key, seconds, value) + return self + + async def execute(self) -> list[Any]: + # Track as single pipeline call with info about batched ops + self._counter._track("pipeline_execute") + self._counter.pipeline_ops.append(self._queued_ops.copy()) + return await self._pipeline.execute() + + +class RedisCallCounter: + """ + Wrapper that counts Redis operations. + + Proxies all calls to the underlying Redis client while tracking + the number of calls per method. Used to verify that optimizations + reduce the number of Redis round-trips. 
+ + Usage: + counter = RedisCallCounter(real_redis) + # Use counter as redis client + await counter.hget("key", "field") + # Check counts + assert counter.call_counts["hget"] == 1 + assert counter.total_calls == 1 + """ + + def __init__(self, redis: AsyncRedisProtocol): + self._redis = redis + self.call_counts: dict[str, int] = defaultdict(int) + self.total_calls = 0 + self.pipeline_ops: list[list[str]] = [] # Track ops in each pipeline + + def reset_counts(self) -> None: + """Reset all counters to zero.""" + self.call_counts.clear() + self.total_calls = 0 + self.pipeline_ops.clear() + + def _track(self, method: str) -> None: + """Track a method call.""" + self.call_counts[method] += 1 + self.total_calls += 1 + + def get_summary(self) -> str: + """Get a human-readable summary of call counts.""" + lines = [f"Total Redis calls: {self.total_calls}"] + for method, count in sorted(self.call_counts.items()): + lines.append(f" {method}: {count}") + if self.pipeline_ops: + lines.append(f" Pipeline batches: {len(self.pipeline_ops)}") + for i, ops in enumerate(self.pipeline_ops): + lines.append( + f" Batch {i + 1}: {len(ops)} ops ({', '.join(set(ops))})" + ) + return "\n".join(lines) + + # ------------------------------------------------------------------------- + # Pipeline support + # ------------------------------------------------------------------------- + + def pipeline(self) -> PipelineCallCounter: + """Create a pipeline with call counting.""" + return PipelineCallCounter(self._redis.pipeline(), self) + + # ------------------------------------------------------------------------- + # Connection methods + # ------------------------------------------------------------------------- + + def aclose(self) -> Awaitable[None]: + self._track("aclose") + return self._redis.aclose() + + # ------------------------------------------------------------------------- + # Scripting methods + # ------------------------------------------------------------------------- + + async def eval( + self, + script: str, + numkeys: int, + *keys_and_args: str, + ) -> Any: + self._track("eval") + return await self._redis.eval(script, numkeys, *keys_and_args) + + # ------------------------------------------------------------------------- + # Key methods + # ------------------------------------------------------------------------- + + async def exists(self, key: str) -> int: + self._track("exists") + return await self._redis.exists(key) + + async def delete(self, *keys: str) -> int: + self._track("delete") + return await self._redis.delete(*keys) + + async def expire(self, key: str, seconds: int) -> int: + self._track("expire") + return await self._redis.expire(key, seconds) + + async def ttl(self, key: str) -> int: + self._track("ttl") + return await self._redis.ttl(key) + + async def scan( + self, cursor: int, *, match: str, count: int + ) -> tuple[int, list[str]]: + self._track("scan") + return await self._redis.scan(cursor, match=match, count=count) + + def scan_iter(self, *, match: str) -> AsyncIterator[str]: + self._track("scan_iter") + return self._redis.scan_iter(match=match) + + # ------------------------------------------------------------------------- + # String methods + # ------------------------------------------------------------------------- + + async def get(self, key: str) -> str | None: + self._track("get") + return await self._redis.get(key) + + async def set(self, key: str, value: str) -> bool | str: + self._track("set") + return await self._redis.set(key, value) + + async def setex(self, key: str, 
seconds: int, value: str) -> bool | str: + self._track("setex") + return await self._redis.setex(key, seconds, value) + + async def incr(self, key: str) -> int: + self._track("incr") + return await self._redis.incr(key) + + async def decr(self, key: str) -> int: + self._track("decr") + return await self._redis.decr(key) + + async def publish(self, channel: str, message: str) -> int: + self._track("publish") + return await self._redis.publish(channel, message) + + def pubsub(self) -> PubSubProtocol: + self._track("pubsub") + return self._redis.pubsub() + + # ------------------------------------------------------------------------- + # Hash methods + # ------------------------------------------------------------------------- + + async def hget(self, key: str, field: str) -> str | None: + self._track("hget") + return await self._redis.hget(key, field) + + async def hset(self, key: str, *, mapping: dict[str, str]) -> int: + self._track("hset") + return await self._redis.hset(key, mapping=mapping) + + async def hgetall(self, key: str) -> dict[str, str]: + self._track("hgetall") + return await self._redis.hgetall(key) + + async def hdel(self, key: str, *fields: str) -> int: + self._track("hdel") + return await self._redis.hdel(key, *fields) + + # ------------------------------------------------------------------------- + # Set methods + # ------------------------------------------------------------------------- + + async def sadd(self, key: str, *members: str) -> int: + self._track("sadd") + return await self._redis.sadd(key, *members) + + async def srem(self, key: str, *members: str) -> int: + self._track("srem") + return await self._redis.srem(key, *members) + + async def smembers(self, key: str) -> Set[str]: + self._track("smembers") + return await self._redis.smembers(key) + + async def sismember(self, key: str, member: str) -> bool: + self._track("sismember") + return await self._redis.sismember(key, member) + + # ------------------------------------------------------------------------- + # Sorted set methods + # ------------------------------------------------------------------------- + + async def zadd(self, key: str, mapping: dict[str, float]) -> int: + self._track("zadd") + return await self._redis.zadd(key, mapping) + + async def zcard(self, key: str) -> int: + self._track("zcard") + return await self._redis.zcard(key) + + async def zrem(self, key: str, *members: str) -> int: + self._track("zrem") + return await self._redis.zrem(key, *members) + + async def zremrangebyscore( + self, key: str, min: float | str, max: float | str + ) -> int: + self._track("zremrangebyscore") + return await self._redis.zremrangebyscore(key, min, max) + + async def zcount(self, key: str, min: float | str, max: float | str) -> int: + self._track("zcount") + return await self._redis.zcount(key, min, max) + + async def zscore(self, key: str, member: str) -> float | None: + self._track("zscore") + return await self._redis.zscore(key, member) + + @overload + async def zrange( + self, key: str, start: int, end: int, *, withscores: Literal[True] + ) -> list[tuple[str, float]]: ... + + @overload + async def zrange( + self, key: str, start: int, end: int, *, withscores: Literal[False] = ... + ) -> list[str]: ... 
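+
+    # These overloads only refine the return type for type-checkers
+    # (withscores=True yields (member, score) tuples); the implementation
+    # below forwards to the real client and records a single "zrange" call.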
+ + async def zrange( + self, key: str, start: int, end: int, *, withscores: bool = False + ) -> list[tuple[str, float]] | list[str]: + self._track("zrange") + if withscores: + return await self._redis.zrange(key, start, end, withscores=True) + return await self._redis.zrange(key, start, end, withscores=False) + + # ------------------------------------------------------------------------- + # Stream methods + # ------------------------------------------------------------------------- + + async def xadd(self, name: str, fields: dict[str, str]) -> str: + self._track("xadd") + return await self._redis.xadd(name, fields) + + async def xrange( + self, name: str, min: str, max: str, count: int | None = None + ) -> list[tuple[str, dict[str, str]]]: + self._track("xrange") + return await self._redis.xrange(name, min, max, count) + + async def xlen(self, name: str) -> int: + self._track("xlen") + return await self._redis.xlen(name) + + async def xgroup_create( + self, name: str, groupname: str, id: str = "$", mkstream: bool = False + ) -> str: + self._track("xgroup_create") + return await self._redis.xgroup_create(name, groupname, id, mkstream) + + async def xreadgroup( + self, + groupname: str, + consumername: str, + *, + streams: dict[str, str], + count: int | None = None, + block: int | None = None, + ) -> list[tuple[str, list[tuple[str, dict[str, str]]]]] | None: + self._track("xreadgroup") + return await self._redis.xreadgroup( + groupname, consumername, streams=streams, count=count, block=block + ) + + async def xack(self, name: str, groupname: str, *ids: str) -> int: + self._track("xack") + return await self._redis.xack(name, groupname, *ids) + + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture +async def counter(redis) -> RedisCallCounter: + """Redis client with call counting.""" + return RedisCallCounter(redis) + + +@pytest.fixture +async def setup_test_agents(redis) -> list[str]: + """Setup multiple test agents for discovery testing.""" + agent_ids = [f"efficiency_agent_{i}" for i in range(10)] + for agent_id in agent_ids: + await redis.hset( + f"agent:{agent_id}", + mapping={ + "id": agent_id, + "capabilities": json.dumps(["efficiency_test"]), + "metadata": "{}", + "status": "ACTIVE", + "registered_at": "1234567890", + }, + ) + yield agent_ids + # Cleanup + for agent_id in agent_ids: + await redis.delete(f"agent:{agent_id}") + + +# ----------------------------------------------------------------------------- +# Rate Limit Efficiency Tests +# ----------------------------------------------------------------------------- + + +class TestRateLimitEfficiency: + """ + Verify rate limit Redis call counts. 
+ + BEFORE optimization: + - get_limits(): 2 hget calls + - _check_window() for minute: zremrangebyscore + zcard + zadd + expire = 4 calls + - _check_window() for hour: zremrangebyscore + zcard + zadd + expire = 4 calls + - Total: ~10 calls per rate limit check + + AFTER optimization (Lua script): + - Single eval call: 1 call + """ + + async def test_optimized_single_check(self, counter: RedisCallCounter): + """Optimized: Single Lua script call for rate limit check.""" + from mas.gateway.rate_limit import RateLimitModule + + module = RateLimitModule(counter, default_per_minute=100, default_per_hour=1000) + + result = await module.check_rate_limit("test_agent", "msg_1") + + print(f"\n{counter.get_summary()}") + print(f"Rate limit result: allowed={result.allowed}") + + # AFTER optimization: expect exactly 1 eval call + assert counter.call_counts["eval"] == 1, ( + f"Expected 1 eval call, got {counter.call_counts['eval']}" + ) + assert counter.total_calls == 1, ( + f"Expected 1 total call (eval), got {counter.total_calls}" + ) + + async def test_optimized_multiple_checks(self, counter: RedisCallCounter): + """Optimized: Multiple checks use 1 call each.""" + from mas.gateway.rate_limit import RateLimitModule + + module = RateLimitModule(counter, default_per_minute=100, default_per_hour=1000) + + num_checks = 5 + for i in range(num_checks): + await module.check_rate_limit("test_agent", f"msg_{i}") + + print(f"\n{counter.get_summary()}") + + # AFTER optimization: 1 eval call per check + calls_per_check = counter.total_calls / num_checks + print(f"Calls per check: {calls_per_check:.1f}") + + assert counter.call_counts["eval"] == num_checks, ( + f"Expected {num_checks} eval calls, got {counter.call_counts['eval']}" + ) + assert calls_per_check == 1.0, ( + f"Expected 1 call per check, got {calls_per_check}" + ) + + +# ----------------------------------------------------------------------------- +# Discovery Efficiency Tests +# ----------------------------------------------------------------------------- + + +class TestDiscoveryEfficiency: + """ + Verify discovery Redis call counts. 
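+    A pipeline_execute counts as one call, however many operations the
+    batch contains.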
+ + BEFORE optimization: + - 1 scan_iter call + - N hgetall calls (one per agent found) + - Total: O(N+1) calls + + AFTER optimization (pipeline batching): + - 1 scan_iter call + - 1 pipeline_execute call (batched hgetall) + - Total: O(2) calls + """ + + async def test_optimized_discovery( + self, counter: RedisCallCounter, setup_test_agents: list[str] + ): + """Optimized: Pipeline batches all hgetall calls.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(counter) + + agents = await registry.discover(capabilities=["efficiency_test"]) + + print(f"\n{counter.get_summary()}") + print(f"Discovered {len(agents)} agents") + + num_agents = len(setup_test_agents) + assert len(agents) == num_agents, f"Expected {num_agents} agents" + + # AFTER optimization: 1 scan_iter + 1 pipeline_execute = 2 calls + assert counter.call_counts["scan_iter"] == 1, ( + f"Expected 1 scan_iter call, got {counter.call_counts['scan_iter']}" + ) + assert counter.call_counts["pipeline_execute"] == 1, ( + f"Expected 1 pipeline_execute call, got {counter.call_counts['pipeline_execute']}" + ) + assert counter.total_calls == 2, ( + f"Expected 2 total calls (scan_iter + pipeline), got {counter.total_calls}" + ) + + # Verify pipeline contained N hgetall operations + assert len(counter.pipeline_ops) == 1, "Expected 1 pipeline batch" + assert len(counter.pipeline_ops[0]) == num_agents, ( + f"Expected {num_agents} ops in pipeline, got {len(counter.pipeline_ops[0])}" + ) + + async def test_optimized_discovery_no_filter( + self, counter: RedisCallCounter, setup_test_agents: list[str] + ): + """Optimized: Discovery without filter also uses pipeline.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(counter) + + # Discover all agents (no capability filter) + agents = await registry.discover() + + print(f"\n{counter.get_summary()}") + print(f"Discovered {len(agents)} agents (no filter)") + + # Should use pipeline (2 total calls) + assert counter.total_calls == 2, ( + f"Expected 2 total calls, got {counter.total_calls}" + ) + assert counter.call_counts["pipeline_execute"] == 1 + + +# ----------------------------------------------------------------------------- +# Audit Module Efficiency Tests +# ----------------------------------------------------------------------------- + + +class TestAuditEfficiency: + """ + Verify audit logging Redis call counts. 
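+    The initial get of previous_hash cannot be pipelined because each entry
+    chains off the hash of the entry before it.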
+ + BEFORE optimization: + - get (previous_hash): 1 call + - xadd (main stream): 1 call + - xadd (sender index): 1 call + - xadd (target index): 1 call + - set (update hash): 1 call + - Total: 5 sequential calls + + AFTER optimization (pipeline): + - get (previous_hash): 1 call (needed before pipeline) + - 1 pipeline_execute call (batched xadd + set) + - Total: 2 calls + """ + + async def test_optimized_log_message(self, counter: RedisCallCounter): + """Optimized: Pipeline batches xadd and set calls.""" + from mas.gateway.audit import AuditModule + + audit = AuditModule(counter) + + await audit.log_message( + message_id="msg_1", + sender_id="agent_a", + target_id="agent_b", + decision="ALLOWED", + latency_ms=10.0, + payload={"test": "data"}, + ) + + print(f"\n{counter.get_summary()}") + + # AFTER optimization: get + 1 pipeline_execute = 2 calls + assert counter.call_counts["get"] == 1, ( + f"Expected 1 get call, got {counter.call_counts['get']}" + ) + assert counter.call_counts["pipeline_execute"] == 1, ( + f"Expected 1 pipeline_execute call, got {counter.call_counts['pipeline_execute']}" + ) + assert counter.total_calls == 2, ( + f"Expected 2 total calls (get + pipeline), got {counter.total_calls}" + ) + + # Verify pipeline contained 4 operations (3 xadd + 1 set) + assert len(counter.pipeline_ops) == 1, "Expected 1 pipeline batch" + assert counter.pipeline_ops[0].count("xadd") == 3, ( + f"Expected 3 xadd ops in pipeline, got {counter.pipeline_ops[0]}" + ) + assert counter.pipeline_ops[0].count("set") == 1, ( + f"Expected 1 set op in pipeline, got {counter.pipeline_ops[0]}" + ) + + async def test_optimized_multiple_logs(self, counter: RedisCallCounter): + """Optimized: Multiple logs use 2 calls each.""" + from mas.gateway.audit import AuditModule + + audit = AuditModule(counter) + + num_logs = 5 + for i in range(num_logs): + await audit.log_message( + message_id=f"msg_{i}", + sender_id="agent_a", + target_id="agent_b", + decision="ALLOWED", + latency_ms=10.0, + payload={"index": i}, + ) + + print(f"\n{counter.get_summary()}") + + # AFTER optimization: 2 calls per log × 5 logs = 10 calls + calls_per_log = counter.total_calls / num_logs + print(f"Calls per log: {calls_per_log:.1f}") + + assert calls_per_log == 2.0, f"Expected 2 calls per log, got {calls_per_log}" + assert counter.call_counts["pipeline_execute"] == num_logs, ( + f"Expected {num_logs} pipeline_execute calls, got {counter.call_counts['pipeline_execute']}" + ) + + +# ----------------------------------------------------------------------------- +# Priority Queue Efficiency Tests +# ----------------------------------------------------------------------------- + + +class TestPriorityQueueEfficiency: + """ + Verify priority queue Redis call counts. 
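+    Dequeue cost scales with the number of MessagePriority levels (five),
+    since each priority maps to its own sorted set per target.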
+ + BEFORE optimization: + - Enqueue: setex + zadd + expire + zcount + zcard = ~6 calls + - Dequeue: iterates through priorities, up to 20 calls + + AFTER optimization (pipeline batching): + - Enqueue: 1 pipeline (setex + zadd + expire) + 1 pipeline (position) = 2 calls + - Dequeue: 1 pipeline (check all queues) + dequeue ops = ~3-4 calls + """ + + async def test_optimized_enqueue(self, counter: RedisCallCounter): + """Optimized: Enqueue uses pipeline for batched operations.""" + from mas.gateway.priority_queue import MessagePriority, PriorityQueueModule + + pq = PriorityQueueModule(counter) + + result = await pq.enqueue( + message_id="msg_1", + sender_id="agent_a", + target_id="test_target", + payload={"test": "data"}, + priority=MessagePriority.NORMAL, + ) + + print(f"\n{counter.get_summary()}") + print(f"Enqueue result: success={result.success}") + + # AFTER optimization: 2 pipeline_execute calls + # - First pipeline: setex + zadd + expire (enqueue) + # - Second pipeline: zcard + zcount (position estimation) + assert counter.call_counts["pipeline_execute"] == 2, ( + f"Expected 2 pipeline_execute calls, got {counter.call_counts['pipeline_execute']}" + ) + assert counter.total_calls == 2, ( + f"Expected 2 total calls (pipelines), got {counter.total_calls}" + ) + + async def test_optimized_dequeue_single_priority(self, counter: RedisCallCounter): + """Optimized: Dequeue uses pipeline to check all queues first.""" + from mas.gateway.priority_queue import MessagePriority, PriorityQueueModule + + # Use real redis for setup, then switch to counter + pq_setup = PriorityQueueModule(counter._redis) + await pq_setup.enqueue( + message_id="msg_1", + sender_id="agent_a", + target_id="dequeue_test", + payload={"test": "data"}, + priority=MessagePriority.NORMAL, + ) + + # Reset counter and dequeue + counter.reset_counts() + pq = PriorityQueueModule(counter) + + messages = await pq.dequeue("dequeue_test", max_messages=1) + + print(f"\n{counter.get_summary()}") + print(f"Dequeued {len(messages)} messages") + + # AFTER optimization: should use pipeline for initial queue check + assert counter.call_counts["pipeline_execute"] >= 1, ( + f"Expected at least 1 pipeline_execute call, got {counter.call_counts['pipeline_execute']}" + ) + # Total calls should be much lower than before (was up to 25) + assert counter.total_calls <= 6, f"Too many calls: {counter.total_calls}" + + +# ----------------------------------------------------------------------------- +# Circuit Breaker Efficiency Tests +# ----------------------------------------------------------------------------- + + +class TestCircuitBreakerEfficiency: + """ + Verify circuit breaker Redis call counts. 
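+    The combined check_and_record_* helpers return both the pre-check and
+    post-record status, so callers lose nothing over two separate calls.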
+ + BEFORE optimization: + - check_circuit: 1 hgetall call + - record_success/failure: 1 hgetall call + - Combined check+record: 2 hgetall calls (double-fetch) + + AFTER optimization: + - check_circuit: 1 hgetall call (unchanged) + - check_and_record_success/failure: 1 hgetall call (combined) + """ + + async def test_baseline_check_circuit(self, counter: RedisCallCounter): + """Baseline: Count Redis calls for circuit check.""" + from mas.gateway.circuit_breaker import CircuitBreakerModule + + cb = CircuitBreakerModule(counter) + + status = await cb.check_circuit("test_target") + + print(f"\n{counter.get_summary()}") + print(f"Circuit status: state={status.state}, allowed={status.allowed}") + + # Single hgetall for circuit state + assert counter.call_counts["hgetall"] == 1 + + async def test_baseline_check_then_record(self, counter: RedisCallCounter): + """Baseline: Separate check+record uses 2 hgetall calls.""" + from mas.gateway.circuit_breaker import CircuitBreakerModule + + cb = CircuitBreakerModule(counter) + + # Separate calls pattern (legacy) + status = await cb.check_circuit("test_target") + if status.allowed: + await cb.record_success("test_target") + + print(f"\n{counter.get_summary()}") + + # Separate calls: 2 hgetall calls (double-fetch) + assert counter.call_counts["hgetall"] == 2, ( + f"Expected 2 hgetall calls (double fetch), " + f"got {counter.call_counts['hgetall']}" + ) + + async def test_optimized_check_and_record_success(self, counter: RedisCallCounter): + """Optimized: Combined check+record uses 1 hgetall call.""" + from mas.gateway.circuit_breaker import CircuitBreakerModule + + cb = CircuitBreakerModule(counter) + + # Combined method avoids double-fetch + check_status, record_status = await cb.check_and_record_success( + "optimized_target" + ) + + print(f"\n{counter.get_summary()}") + print( + f"Check status: state={check_status.state}, allowed={check_status.allowed}" + ) + print(f"Record status: state={record_status.state}") + + # AFTER optimization: only 1 hgetall call (combined operation) + assert counter.call_counts["hgetall"] == 1, ( + f"Expected 1 hgetall call (combined), got {counter.call_counts['hgetall']}" + ) + assert counter.total_calls == 1, ( + f"Expected 1 total call, got {counter.total_calls}" + ) + + async def test_optimized_check_and_record_failure(self, counter: RedisCallCounter): + """Optimized: Combined check+failure uses 1 hgetall + 1 hset.""" + from mas.gateway.circuit_breaker import CircuitBreakerModule + + cb = CircuitBreakerModule(counter) + + # Combined method avoids double-fetch + check_status, record_status = await cb.check_and_record_failure( + "failure_target", "test_failure" + ) + + print(f"\n{counter.get_summary()}") + print(f"Check status: allowed={check_status.allowed}") + print(f"Record status: failure_count={record_status.failure_count}") + + # AFTER optimization: 1 hgetall (read) + 1 hset (write) + 1 expire + assert counter.call_counts["hgetall"] == 1, ( + f"Expected 1 hgetall call, got {counter.call_counts['hgetall']}" + ) + # Total: hgetall + hset + expire = 3 calls + assert counter.total_calls == 3, ( + f"Expected 3 total calls (hgetall + hset + expire), got {counter.total_calls}" + ) + + +# ----------------------------------------------------------------------------- +# Health Monitor Efficiency Tests +# ----------------------------------------------------------------------------- + + +class TestHealthMonitorEfficiency: + """ + Verify health monitor Redis call counts. 
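+    The test below exercises the scan-plus-pipeline pattern directly against
+    the counter rather than running the monitor loop itself.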
+ + BEFORE optimization: + - Two separate scan_iter calls (heartbeat keys, agent keys) + - Individual ttl, exists, hget, hset calls per key + - Total: O(2N) calls + + AFTER optimization: + - Single scan to collect all agent keys + - Pipeline batch-fetch TTLs and agent data + - Pipeline batch-update stale agents + - Total: O(3) calls (scan + 2 pipelines) + """ + + async def test_optimized_health_monitor_pattern( + self, counter: RedisCallCounter, setup_test_agents: list[str] + ): + """ + Optimized: Document the pipeline-based health monitoring pattern. + + The optimized _monitor_health uses: + 1. Single scan_iter to collect agent keys + 2. Pipeline to batch-fetch TTLs and status + 3. Pipeline to batch-update stale agents (if any) + """ + # Simulate the optimized pattern: single scan + pipeline + agent_keys: list[str] = [] + async for key in counter.scan_iter(match="agent:*"): + if key.count(":") == 1: # Only agent hashes, not heartbeat keys + agent_keys.append(key) + + print(f"\n{counter.get_summary()}") + print(f"Found {len(agent_keys)} agent keys") + + # Should only have 1 scan_iter call (optimized pattern) + assert counter.call_counts["scan_iter"] == 1, ( + f"Expected 1 scan_iter call, got {counter.call_counts['scan_iter']}" + ) + + # Now simulate the pipeline batch-fetch + if agent_keys: + pipe = counter.pipeline() + for agent_key in agent_keys: + agent_id = agent_key.split(":")[1] + hb_key = f"agent:{agent_id}:heartbeat" + pipe.ttl(hb_key) + pipe.hget(agent_key, "status") + pipe.hget(agent_key, "registered_at") + await pipe.execute() + + print(f"\nAfter batch fetch:\n{counter.get_summary()}") + + # Should have 1 scan + 1 pipeline = 2 total calls + assert counter.total_calls == 2, ( + f"Expected 2 total calls (scan + pipeline), got {counter.total_calls}" + ) + assert counter.call_counts["pipeline_execute"] == 1, ( + f"Expected 1 pipeline_execute call, got {counter.call_counts['pipeline_execute']}" + ) + + +# ----------------------------------------------------------------------------- +# Registry Deregister Efficiency Tests +# ----------------------------------------------------------------------------- + + +class TestDeregisterEfficiency: + """ + Verify deregister Redis call counts. + + BEFORE optimization: + - 2-3 separate delete calls + - Total: 2-3 calls + + AFTER optimization (pipeline): + - 1 pipeline_execute call (batched deletes) + - Total: 1 call + """ + + async def test_optimized_deregister(self, counter: RedisCallCounter): + """Optimized: Pipeline batches all delete calls.""" + from mas.registry import AgentRegistry + + # Setup agent first (with multi-instance support) + instance_id = "testinst" + await counter._redis.hset( + "agent:dereg_test", + mapping={ + "id": "dereg_test", + "capabilities": "[]", + "metadata": "{}", + "status": "ACTIVE", + "registered_at": "123", + }, + ) + # Set instance count to 1 (last instance) + await counter._redis.set("agent:dereg_test:instance_count", "1") + # New heartbeat format includes instance_id + await counter._redis.set(f"agent:dereg_test:heartbeat:{instance_id}", "123") + + counter.reset_counts() + registry = AgentRegistry(counter) + + await registry.deregister("dereg_test", instance_id, keep_state=True) + + print(f"\n{counter.get_summary()}") + + # Multi-instance deregister pattern: + # 1. DECR instance_count + # 2. DELETE instance heartbeat + # 3. 
Pipeline for cleanup (when last instance) + assert counter.call_counts["decr"] == 1, ( + f"Expected 1 decr call, got {counter.call_counts.get('decr', 0)}" + ) + assert counter.call_counts["delete"] == 1, ( + f"Expected 1 delete call for heartbeat, got {counter.call_counts.get('delete', 0)}" + ) + assert counter.call_counts["pipeline_execute"] == 1, ( + f"Expected 1 pipeline_execute call, got {counter.call_counts.get('pipeline_execute', 0)}" + ) + + async def test_optimized_deregister_with_state(self, counter: RedisCallCounter): + """Optimized: Deregister with state deletion uses single pipeline.""" + from mas.registry import AgentRegistry + + # Setup agent and state (with multi-instance support) + instance_id = "testinst" + await counter._redis.hset( + "agent:dereg_state_test", + mapping={ + "id": "dereg_state_test", + "capabilities": "[]", + "metadata": "{}", + "status": "ACTIVE", + "registered_at": "123", + }, + ) + # Set instance count to 1 (last instance) + await counter._redis.set("agent:dereg_state_test:instance_count", "1") + # New heartbeat format includes instance_id + await counter._redis.set( + f"agent:dereg_state_test:heartbeat:{instance_id}", "123" + ) + await counter._redis.hset( + "agent.state:dereg_state_test", mapping={"key": "value"} + ) + + counter.reset_counts() + registry = AgentRegistry(counter) + + await registry.deregister("dereg_state_test", instance_id, keep_state=False) + + print(f"\n{counter.get_summary()}") + + # Multi-instance deregister pattern: + # 1. DECR instance_count + # 2. DELETE instance heartbeat + # 3. Pipeline for cleanup (when last instance, includes state deletion) + assert counter.call_counts["decr"] == 1, ( + f"Expected 1 decr call, got {counter.call_counts.get('decr', 0)}" + ) + assert counter.call_counts["delete"] == 1, ( + f"Expected 1 delete call for heartbeat, got {counter.call_counts.get('delete', 0)}" + ) + assert counter.call_counts["pipeline_execute"] == 1, ( + f"Expected 1 pipeline_execute call, got {counter.call_counts.get('pipeline_execute', 0)}" + ) diff --git a/tests/test_regression.py b/tests/test_regression.py new file mode 100644 index 0000000..49f1d5d --- /dev/null +++ b/tests/test_regression.py @@ -0,0 +1,730 @@ +"""Regression tests to ensure optimizations don't break functionality. + +These tests verify that core behaviors remain correct after optimization. +Run these tests before and after each optimization to ensure no regressions. 
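+Call-count assertions live in tests/test_redis_efficiency.py and throughput
+numbers in tests/test_performance.py; the tests here assert only on
+observable behavior.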
+ +Usage: + uv run pytest tests/test_regression.py -v +""" + +import asyncio +import json +import time + +import pytest + +pytestmark = pytest.mark.asyncio + + +# ----------------------------------------------------------------------------- +# Rate Limit Regression Tests +# ----------------------------------------------------------------------------- + + +class TestRateLimitRegression: + """Ensure rate limiting behavior is preserved after optimization.""" + + async def test_allows_requests_within_limit(self, redis): + """Requests within the limit must be allowed.""" + from mas.gateway.rate_limit import RateLimitModule + + module = RateLimitModule(redis, default_per_minute=10, default_per_hour=100) + agent_id = "regression_rate_allow" + + for i in range(10): + result = await module.check_rate_limit(agent_id, f"msg_{i}") + assert result.allowed, f"Message {i} should be allowed (within limit)" + assert result.remaining >= 0 + + async def test_blocks_requests_over_limit(self, redis): + """Requests over the limit must be blocked.""" + from mas.gateway.rate_limit import RateLimitModule + + module = RateLimitModule(redis, default_per_minute=5, default_per_hour=100) + agent_id = "regression_rate_block" + + # First 5 should pass + for i in range(5): + result = await module.check_rate_limit(agent_id, f"msg_{i}") + assert result.allowed, f"Message {i} should be allowed" + + # 6th should be blocked + result = await module.check_rate_limit(agent_id, "msg_blocked") + assert not result.allowed, "Message over limit should be blocked" + assert result.remaining == 0 + + async def test_custom_limits_applied(self, redis): + """Custom per-agent limits must be respected.""" + from mas.gateway.rate_limit import RateLimitModule + + module = RateLimitModule(redis, default_per_minute=100, default_per_hour=1000) + agent_id = "regression_rate_custom" + + # Set custom low limit + await module.set_limits(agent_id, per_minute=3) + + # First 3 should pass + for i in range(3): + result = await module.check_rate_limit(agent_id, f"custom_msg_{i}") + assert result.allowed + + # 4th should be blocked (custom limit) + result = await module.check_rate_limit(agent_id, "custom_msg_blocked") + assert not result.allowed + + async def test_different_agents_independent(self, redis): + """Each agent has independent rate limits.""" + from mas.gateway.rate_limit import RateLimitModule + + module = RateLimitModule(redis, default_per_minute=2, default_per_hour=100) + + # Agent A uses its limit + for i in range(2): + result = await module.check_rate_limit("agent_a", f"msg_{i}") + assert result.allowed + + # Agent A is now blocked + result = await module.check_rate_limit("agent_a", "msg_blocked") + assert not result.allowed + + # Agent B should still be allowed (independent limit) + result = await module.check_rate_limit("agent_b", "msg_0") + assert result.allowed, "Agent B should have independent limit" + + async def test_remaining_count_accurate(self, redis): + """Remaining count must be accurate.""" + from mas.gateway.rate_limit import RateLimitModule + + module = RateLimitModule(redis, default_per_minute=5, default_per_hour=100) + agent_id = "regression_rate_remaining" + + result = await module.check_rate_limit(agent_id, "msg_0") + assert result.remaining == 4 # 5 - 1 = 4 + + result = await module.check_rate_limit(agent_id, "msg_1") + assert result.remaining == 3 # 5 - 2 = 3 + + +# ----------------------------------------------------------------------------- +# Discovery Regression Tests +# 
----------------------------------------------------------------------------- + + +class TestDiscoveryRegression: + """Ensure discovery behavior is preserved after optimization.""" + + @pytest.fixture + async def setup_discovery_agents(self, redis): + """Setup agents with various states and capabilities.""" + agents = [ + ("active_nlp", ["nlp", "text"], "ACTIVE"), + ("active_vision", ["vision", "image"], "ACTIVE"), + ("active_both", ["nlp", "vision"], "ACTIVE"), + ("inactive_nlp", ["nlp"], "INACTIVE"), + ("active_no_caps", [], "ACTIVE"), + ] + + for agent_id, caps, status in agents: + await redis.hset( + f"agent:{agent_id}", + mapping={ + "id": agent_id, + "capabilities": json.dumps(caps), + "metadata": json.dumps({"test": True}), + "status": status, + "registered_at": str(time.time()), + }, + ) + + yield agents + + # Cleanup + for agent_id, _, _ in agents: + await redis.delete(f"agent:{agent_id}") + + async def test_discovers_all_active_agents(self, redis, setup_discovery_agents): + """Discovery without filter returns all active agents.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + agents = await registry.discover() + + agent_ids = {a["id"] for a in agents} + + # Should find all active agents + assert "active_nlp" in agent_ids + assert "active_vision" in agent_ids + assert "active_both" in agent_ids + assert "active_no_caps" in agent_ids + + # Should NOT find inactive agents + assert "inactive_nlp" not in agent_ids + + async def test_discovers_by_capability(self, redis, setup_discovery_agents): + """Discovery filters by capability correctly.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + + # Find NLP agents + nlp_agents = await registry.discover(capabilities=["nlp"]) + nlp_ids = {a["id"] for a in nlp_agents} + + assert "active_nlp" in nlp_ids + assert "active_both" in nlp_ids + assert "active_vision" not in nlp_ids + assert "inactive_nlp" not in nlp_ids # Inactive excluded + + async def test_discovers_with_multiple_capabilities( + self, redis, setup_discovery_agents + ): + """Discovery with multiple capabilities uses OR logic.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + + # Find agents with nlp OR vision + agents = await registry.discover(capabilities=["nlp", "vision"]) + agent_ids = {a["id"] for a in agents} + + assert "active_nlp" in agent_ids + assert "active_vision" in agent_ids + assert "active_both" in agent_ids + + async def test_discovery_returns_correct_data(self, redis, setup_discovery_agents): + """Discovery returns complete and correct agent data.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + + agents = await registry.discover(capabilities=["nlp"]) + + # Find specific agent + nlp_agent = next(a for a in agents if a["id"] == "active_nlp") + + assert nlp_agent["id"] == "active_nlp" + assert "nlp" in nlp_agent["capabilities"] + assert "text" in nlp_agent["capabilities"] + assert nlp_agent["metadata"]["test"] is True + + async def test_empty_discovery_returns_empty_list(self, redis): + """Discovery with no matches returns empty list.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + + agents = await registry.discover(capabilities=["nonexistent_capability"]) + + assert agents == [] + + +# ----------------------------------------------------------------------------- +# Audit Regression Tests +# ----------------------------------------------------------------------------- + + +class TestAuditRegression: + 
"""Ensure audit logging behavior is preserved after optimization.""" + + async def test_log_message_creates_entry(self, redis): + """Logging a message creates an audit entry.""" + from mas.gateway.audit import AuditModule + + audit = AuditModule(redis) + + stream_id = await audit.log_message( + message_id="regression_msg_1", + sender_id="sender_a", + target_id="target_b", + decision="ALLOWED", + latency_ms=15.5, + payload={"test": "data"}, + violations=[], + ) + + assert stream_id is not None + assert isinstance(stream_id, str) + + async def test_query_by_sender(self, redis): + """Audit entries can be queried by sender.""" + from mas.gateway.audit import AuditModule + + audit = AuditModule(redis) + + # Log messages from different senders + await audit.log_message("msg_s1_1", "sender_one", "target", "ALLOWED", 10.0, {}) + await audit.log_message("msg_s1_2", "sender_one", "target", "ALLOWED", 10.0, {}) + await audit.log_message("msg_s2_1", "sender_two", "target", "ALLOWED", 10.0, {}) + + # Query by sender_one + entries = await audit.query_by_sender("sender_one") + + assert len(entries) >= 2 + assert all(e["sender_id"] == "sender_one" for e in entries) + + async def test_query_by_target(self, redis): + """Audit entries can be queried by target.""" + from mas.gateway.audit import AuditModule + + audit = AuditModule(redis) + + await audit.log_message("msg_t1", "sender", "target_one", "ALLOWED", 10.0, {}) + await audit.log_message("msg_t2", "sender", "target_two", "DENIED", 10.0, {}) + + entries = await audit.query_by_target("target_one") + + assert len(entries) >= 1 + assert all(e["target_id"] == "target_one" for e in entries) + + async def test_violations_stored_correctly(self, redis): + """Violations are stored and retrieved correctly.""" + from mas.gateway.audit import AuditModule + + audit = AuditModule(redis) + + await audit.log_message( + message_id="msg_violations", + sender_id="violation_sender", + target_id="target", + decision="DLP_BLOCKED", + latency_ms=10.0, + payload={"sensitive": "data"}, + violations=["ssn", "credit_card"], + ) + + entries = await audit.query_by_sender("violation_sender") + entry = next(e for e in entries if e["message_id"] == "msg_violations") + + assert "ssn" in entry["violations"] + assert "credit_card" in entry["violations"] + + async def test_security_events_logged(self, redis): + """Security events are logged to separate stream.""" + from mas.gateway.audit import AuditModule + + audit = AuditModule(redis) + + stream_id = await audit.log_security_event( + "AUTH_FAILURE", + {"agent_id": "bad_agent", "reason": "invalid_token"}, + ) + + assert stream_id is not None + + events = await audit.query_security_events() + assert len(events) >= 1 + + +# ----------------------------------------------------------------------------- +# Priority Queue Regression Tests +# ----------------------------------------------------------------------------- + + +class TestPriorityQueueRegression: + """Ensure priority queue behavior is preserved after optimization.""" + + async def test_enqueue_dequeue_preserves_data(self, redis): + """Enqueued data is preserved on dequeue.""" + from mas.gateway.priority_queue import MessagePriority, PriorityQueueModule + + pq = PriorityQueueModule(redis) + + await pq.enqueue( + message_id="preserve_msg", + sender_id="preserve_sender", + target_id="preserve_target", + payload={"key": "value", "number": 42}, + priority=MessagePriority.NORMAL, + ) + + messages = await pq.dequeue("preserve_target", max_messages=1) + + assert len(messages) == 1 + msg = 
messages[0] + assert msg.message_id == "preserve_msg" + assert msg.sender_id == "preserve_sender" + assert msg.target_id == "preserve_target" + assert msg.payload == {"key": "value", "number": 42} + + async def test_priority_ordering(self, redis): + """Higher priority messages are dequeued first.""" + from mas.gateway.priority_queue import MessagePriority, PriorityQueueModule + + pq = PriorityQueueModule(redis) + target = "priority_order_target" + + # Enqueue in reverse priority order + await pq.enqueue("msg_low", "s", target, {"p": "low"}, MessagePriority.LOW) + await pq.enqueue( + "msg_critical", "s", target, {"p": "critical"}, MessagePriority.CRITICAL + ) + await pq.enqueue( + "msg_normal", "s", target, {"p": "normal"}, MessagePriority.NORMAL + ) + + # Dequeue all + messages = await pq.dequeue(target, max_messages=3) + + # Critical should be first + assert messages[0].message_id == "msg_critical" + assert messages[1].message_id == "msg_normal" + assert messages[2].message_id == "msg_low" + + await pq.clear_queue(target) + + async def test_fifo_within_priority(self, redis): + """Messages with same priority are FIFO ordered.""" + from mas.gateway.priority_queue import MessagePriority, PriorityQueueModule + + pq = PriorityQueueModule(redis) + target = "fifo_target" + + # Enqueue in order + for i in range(3): + await pq.enqueue( + f"fifo_msg_{i}", "s", target, {"i": i}, MessagePriority.NORMAL + ) + await asyncio.sleep(0.01) # Ensure different timestamps + + messages = await pq.dequeue(target, max_messages=3) + + # Should be in FIFO order + assert messages[0].payload["i"] == 0 + assert messages[1].payload["i"] == 1 + assert messages[2].payload["i"] == 2 + + await pq.clear_queue(target) + + async def test_separate_queues_per_target(self, redis): + """Each target has independent queues.""" + from mas.gateway.priority_queue import MessagePriority, PriorityQueueModule + + pq = PriorityQueueModule(redis) + + await pq.enqueue("msg_a", "s", "target_a", {"t": "a"}, MessagePriority.NORMAL) + await pq.enqueue("msg_b", "s", "target_b", {"t": "b"}, MessagePriority.NORMAL) + + messages_a = await pq.dequeue("target_a", max_messages=1) + messages_b = await pq.dequeue("target_b", max_messages=1) + + assert len(messages_a) == 1 + assert messages_a[0].payload["t"] == "a" + + assert len(messages_b) == 1 + assert messages_b[0].payload["t"] == "b" + + await pq.clear_queue("target_a") + await pq.clear_queue("target_b") + + +# ----------------------------------------------------------------------------- +# Circuit Breaker Regression Tests +# ----------------------------------------------------------------------------- + + +class TestCircuitBreakerRegression: + """Ensure circuit breaker behavior is preserved after optimization.""" + + async def test_closed_by_default(self, redis): + """Circuit is closed (allowing traffic) by default.""" + from mas.gateway.circuit_breaker import CircuitBreakerModule, CircuitState + + cb = CircuitBreakerModule(redis) + + status = await cb.check_circuit("new_target") + + assert status.state == CircuitState.CLOSED + assert status.allowed is True + assert status.failure_count == 0 + + async def test_opens_after_failures(self, redis): + """Circuit opens after threshold failures.""" + from mas.gateway.circuit_breaker import ( + CircuitBreakerConfig, + CircuitBreakerModule, + CircuitState, + ) + + config = CircuitBreakerConfig(failure_threshold=3) + cb = CircuitBreakerModule(redis, config=config) + target = "failing_target" + + # Record failures + for i in range(3): + await 
cb.record_failure(target, f"failure_{i}") + + status = await cb.check_circuit(target) + + assert status.state == CircuitState.OPEN + assert status.allowed is False + + await cb.reset_circuit(target) + + async def test_success_resets_failure_count(self, redis): + """Success in closed state resets failure count.""" + from mas.gateway.circuit_breaker import ( + CircuitBreakerConfig, + CircuitBreakerModule, + ) + + config = CircuitBreakerConfig(failure_threshold=5) + cb = CircuitBreakerModule(redis, config=config) + target = "reset_target" + + # Record some failures (not enough to trip) + await cb.record_failure(target, "fail_1") + await cb.record_failure(target, "fail_2") + + status = await cb.check_circuit(target) + assert status.failure_count == 2 + + # Success should reset + await cb.record_success(target) + + status = await cb.check_circuit(target) + assert status.failure_count == 0 + + await cb.reset_circuit(target) + + async def test_half_open_closes_on_success(self, redis): + """Circuit closes from half-open after success threshold. + + Note: This test directly manipulates circuit state to avoid + timing-sensitive transitions that can be flaky in CI. + """ + from mas.gateway.circuit_breaker import ( + CircuitBreakerConfig, + CircuitBreakerModule, + CircuitState, + ) + + config = CircuitBreakerConfig( + failure_threshold=2, + success_threshold=2, + timeout_seconds=60.0, # Long timeout - we'll manipulate state directly + window_seconds=120.0, + ) + cb = CircuitBreakerModule(redis, config=config) + target = "halfopen_regression_test" + + try: + # Directly set circuit to HALF_OPEN state for testing + # This avoids timing-sensitive failure threshold + timeout transitions + await redis.hset( + f"circuit:{target}", + mapping={ + "state": CircuitState.HALF_OPEN.value, + "failure_count": "2", + "success_count": "0", + "last_failure_time": "0", + "opened_at": "0", + }, + ) + + # Verify we're in HALF_OPEN + status = await cb.check_circuit(target) + assert status.state == CircuitState.HALF_OPEN, ( + f"Circuit should be HALF_OPEN, got {status.state}" + ) + assert status.allowed is True, "HALF_OPEN should allow traffic" + + # Record first success + status1 = await cb.record_success(target) + assert status1.success_count == 1, "Should have 1 success" + assert status1.state == CircuitState.HALF_OPEN, ( + "Should still be HALF_OPEN after 1 success" + ) + + # Record second success - should close circuit + status2 = await cb.record_success(target) + assert status2.state == CircuitState.CLOSED, ( + f"Circuit should be CLOSED after 2 successes, got {status2.state}" + ) + assert status2.failure_count == 0, "Failure count should reset" + assert status2.success_count == 0, "Success count should reset" + + finally: + await cb.reset_circuit(target) + + +# ----------------------------------------------------------------------------- +# Registry Regression Tests +# ----------------------------------------------------------------------------- + + +class TestRegistryRegression: + """Ensure registry behavior is preserved after optimization.""" + + async def test_register_creates_agent(self, redis): + """Registration creates agent entry with correct data.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + instance_id = "test1234" + + token = await registry.register( + "regression_agent", + instance_id, + capabilities=["cap1", "cap2"], + metadata={"key": "value"}, + ) + + assert token is not None + assert len(token) > 0 + + agent = await registry.get_agent("regression_agent") + assert 
agent is not None + assert agent["id"] == "regression_agent" + assert "cap1" in agent["capabilities"] + assert "cap2" in agent["capabilities"] + + await registry.deregister("regression_agent", instance_id) + + async def test_deregister_removes_agent(self, redis): + """Deregistration removes agent entry (when last instance leaves).""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + instance_id = "test1234" + + await registry.register("dereg_agent", instance_id, capabilities=[]) + + # Verify exists + agent = await registry.get_agent("dereg_agent") + assert agent is not None + + # Deregister (last instance) + await registry.deregister("dereg_agent", instance_id) + + # Verify removed + agent = await registry.get_agent("dereg_agent") + assert agent is None + + async def test_deregister_preserves_state_by_default(self, redis): + """Deregistration preserves state by default.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + instance_id = "test1234" + + await registry.register("state_agent", instance_id, capabilities=[]) + + # Create some state + await redis.hset("agent.state:state_agent", mapping={"data": "preserved"}) + + # Deregister with default (keep_state=True) + await registry.deregister("state_agent", instance_id) + + # State should still exist + state = await redis.hgetall("agent.state:state_agent") + assert state.get("data") == "preserved" + + # Cleanup + await redis.delete("agent.state:state_agent") + + async def test_deregister_can_remove_state(self, redis): + """Deregistration can remove state when requested.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + instance_id = "test1234" + + await registry.register("state_remove_agent", instance_id, capabilities=[]) + await redis.hset("agent.state:state_remove_agent", mapping={"data": "remove"}) + + # Deregister with keep_state=False + await registry.deregister("state_remove_agent", instance_id, keep_state=False) + + # State should be removed + state = await redis.hgetall("agent.state:state_remove_agent") + assert state == {} + + async def test_heartbeat_updates(self, redis): + """Heartbeat creates/updates heartbeat key.""" + from mas.registry import AgentRegistry + + registry = AgentRegistry(redis) + instance_id = "test1234" + + await registry.register("heartbeat_agent", instance_id, capabilities=[]) + + # Update heartbeat (now requires instance_id) + await registry.update_heartbeat("heartbeat_agent", instance_id, ttl=60) + + # Verify heartbeat key exists (new format includes instance_id) + ttl = await redis.ttl(f"agent:heartbeat_agent:heartbeat:{instance_id}") + assert ttl > 0 + assert ttl <= 60 + + await registry.deregister("heartbeat_agent", instance_id) + + +# ----------------------------------------------------------------------------- +# State Manager Regression Tests +# ----------------------------------------------------------------------------- + + +class TestStateManagerRegression: + """Ensure state management behavior is preserved after optimization.""" + + async def test_state_persists_and_loads(self, redis): + """State is persisted and can be loaded.""" + from mas.state import StateManager + + # Create and save state + manager1 = StateManager("state_test_agent", redis) + await manager1.load() + await manager1.update({"counter": 42, "name": "test"}) + + # Create new manager and load + manager2 = StateManager("state_test_agent", redis) + await manager2.load() + + # Values should be preserved (as strings from Redis) + assert 
manager2.state["counter"] == "42" + assert manager2.state["name"] == "test" + + # Cleanup + await redis.delete("agent.state:state_test_agent") + + async def test_state_reset(self, redis): + """State reset clears data.""" + from mas.state import StateManager + + manager = StateManager("reset_agent", redis) + await manager.load() + await manager.update({"data": "value"}) + + # Reset + await manager.reset() + + # State should be empty + assert manager.state == {} + + # Redis key should be deleted + exists = await redis.exists("agent.state:reset_agent") + assert exists == 0 + + async def test_complex_values_serialized(self, redis): + """Complex values (dict, list) are JSON serialized.""" + from mas.state import StateManager + + manager = StateManager("complex_state_agent", redis) + await manager.load() + await manager.update( + { + "nested": {"a": 1, "b": 2}, + "items": [1, 2, 3], + } + ) + + # Load in new manager + manager2 = StateManager("complex_state_agent", redis) + await manager2.load() + + # Values are stored as JSON strings + assert manager2.state["nested"] == '{"a": 1, "b": 2}' + assert manager2.state["items"] == "[1, 2, 3]" + + # Cleanup + await redis.delete("agent.state:complex_state_agent") diff --git a/tests/test_runner.py b/tests/test_runner.py new file mode 100644 index 0000000..be5dc4b --- /dev/null +++ b/tests/test_runner.py @@ -0,0 +1,224 @@ +"""Tests for the config-driven agent runner.""" + +from __future__ import annotations + +from pathlib import Path +from typing import override + +import pytest + +from mas import Agent +from mas.runner import AgentRunner, AgentSpec, load_runner_settings + + +class NoopAgent(Agent[dict[str, object]]): + """Agent with no-op lifecycle for runner tests.""" + + @override + async def start(self) -> None: + self._running = True + + @override + async def stop(self) -> None: + self._running = False + + +class NotAnAgent: + """Dummy class for validation tests.""" + + +class FakeService: + """Stub MAS service for runner lifecycle tests.""" + + def __init__(self, redis_url: str) -> None: + self.redis_url = redis_url + self.started = False + self.stopped = False + + async def start(self) -> None: + self.started = True + + async def stop(self) -> None: + self.stopped = True + + +class FakeGateway: + """Stub gateway service for runner lifecycle tests.""" + + def __init__(self, settings) -> None: + self.settings = settings + self.started = False + self.stopped = False + + async def start(self) -> None: + self.started = True + + async def stop(self) -> None: + self.stopped = True + + +def _write_agents_yaml(path: Path) -> None: + path.write_text( + "\n".join( + [ + "agents:", + " - agent_id: test_agent", + " class_path: tests.test_runner:NoopAgent", + " instances: 1", + ] + ) + + "\n", + encoding="utf-8", + ) + + +def test_settings_loads_agents_yaml_from_parent(tmp_path, monkeypatch) -> None: + project_root = tmp_path / "project" + project_root.mkdir() + config_path = project_root / "agents.yaml" + _write_agents_yaml(config_path) + + nested = project_root / "apps" / "worker" + nested.mkdir(parents=True) + monkeypatch.chdir(nested) + + settings = load_runner_settings() + assert settings.config_file == str(config_path) + assert len(settings.agents) == 1 + assert settings.agents[0].agent_id == "test_agent" + + +def test_settings_requires_agents_yaml(tmp_path, monkeypatch) -> None: + project_root = tmp_path / "project" + project_root.mkdir() + monkeypatch.chdir(project_root) + + with pytest.raises(FileNotFoundError, match="agents.yaml not found"): + 
load_runner_settings() + + +def test_load_agent_class_validation() -> None: + with pytest.raises(ValueError, match="module:ClassName"): + AgentRunner._load_agent_class("tests.test_runner.NoopAgent") + + with pytest.raises(TypeError, match="mas.Agent subclass"): + AgentRunner._load_agent_class("tests.test_runner:NotAnAgent") + + loaded = AgentRunner._load_agent_class("tests.test_runner:NoopAgent") + assert loaded is NoopAgent + + +@pytest.mark.asyncio +async def test_runner_start_respects_instances() -> None: + settings = load_runner_settings( + agents=[ + AgentSpec( + agent_id="noop", + class_path="tests.test_runner:NoopAgent", + instances=2, + ) + ] + ) + runner = AgentRunner(settings) + + await runner._start_agents() + try: + assert len(runner._agents) == 2 + finally: + await runner._stop_agents() + + +@pytest.mark.asyncio +async def test_runner_rejects_reserved_kwargs() -> None: + settings = load_runner_settings( + agents=[ + AgentSpec( + agent_id="noop", + class_path="tests.test_runner:NoopAgent", + init_kwargs={"agent_id": "override"}, + ) + ] + ) + runner = AgentRunner(settings) + + with pytest.raises(ValueError, match="reserved keys"): + await runner._start_agents() + + +@pytest.mark.asyncio +async def test_runner_starts_and_stops_service(monkeypatch) -> None: + monkeypatch.setattr("mas.runner.MASService", FakeService) + monkeypatch.setattr("mas.runner.GatewayService", FakeGateway) + settings = load_runner_settings( + agents=[ + AgentSpec( + agent_id="noop", + class_path="tests.test_runner:NoopAgent", + ) + ], + service_redis_url="redis://example:6379", + ) + runner = AgentRunner(settings) + + await runner._start_service() + assert runner._service is not None + assert runner._service.redis_url == "redis://example:6379" + assert runner._service.started is True + + await runner._stop_service() + assert runner._service is None + + +@pytest.mark.asyncio +async def test_runner_starts_and_stops_gateway(monkeypatch) -> None: + monkeypatch.setattr("mas.runner.GatewayService", FakeGateway) + settings = load_runner_settings( + agents=[ + AgentSpec( + agent_id="noop", + class_path="tests.test_runner:NoopAgent", + ) + ] + ) + runner = AgentRunner(settings) + + await runner._start_gateway() + assert runner._gateway is not None + assert runner._gateway.started is True + + await runner._stop_gateway() + assert runner._gateway is None + + +@pytest.mark.asyncio +async def test_runner_rejects_use_gateway_kwarg() -> None: + settings = load_runner_settings( + agents=[ + AgentSpec( + agent_id="noop", + class_path="tests.test_runner:NoopAgent", + init_kwargs={"use_gateway": True}, + ) + ] + ) + runner = AgentRunner(settings) + + with pytest.raises(ValueError, match="use_gateway is not supported"): + await runner._start_agents() + + +@pytest.mark.asyncio +async def test_runner_requires_gateway_service() -> None: + settings = load_runner_settings( + agents=[ + AgentSpec( + agent_id="noop", + class_path="tests.test_runner:NoopAgent", + ) + ], + start_gateway=False, + ) + runner = AgentRunner(settings) + + with pytest.raises(RuntimeError, match="start_gateway must be true"): + await runner.run() diff --git a/tests/test_simple_messaging.py b/tests/test_simple_messaging.py index ecad77a..7db447b 100644 --- a/tests/test_simple_messaging.py +++ b/tests/test_simple_messaging.py @@ -1,4 +1,4 @@ -"""Test simplified peer-to-peer messaging.""" +"""Tests for basic gateway messaging.""" import asyncio from typing import override @@ -27,8 +27,8 @@ async def on_message(self, message: AgentMessage) -> None: 
@pytest.mark.asyncio -async def test_peer_to_peer_messaging(mas_service): - """Test direct peer-to-peer messaging between agents.""" +async def test_gateway_messaging(mas_service): + """Test messaging between agents through the gateway.""" # Start gateway (streams, signing disabled for tests) settings = GatewaySettings( features=FeaturesSettings(