diff --git a/contrib/agentlightning/contrib/agent_os/README.md b/contrib/agentlightning/contrib/agent_os/README.md new file mode 100644 index 000000000..a95c54945 --- /dev/null +++ b/contrib/agentlightning/contrib/agent_os/README.md @@ -0,0 +1,104 @@ +# Agent-OS Integration for Agent-Lightning + +Kernel-level safety during AI agent training. + +## Overview + +[Agent-OS](https://github.com/imran-siddique/agent-os) provides deterministic governance +for AI agents. This integration enables: + +- **0% unpenalized policy violations** - All unsafe actions are detected and penalized +- **Policy violations → RL penalties** - Agents learn to avoid unsafe behavior +- **Complete audit trail** - From training to production + +## Installation + +```bash +pip install agentlightning agent-os +``` + +## Quick Start + +```python +from agentlightning import Trainer +from agentlightning.contrib.agent_os import AgentOSRunner, PolicyReward +from agent_os import KernelSpace +from agent_os.policies import SQLPolicy + +# Create governed kernel +kernel = KernelSpace(policy=SQLPolicy( + deny=["DROP", "DELETE"] +)) + +# Wrap in Agent-OS runner +runner = AgentOSRunner(kernel) + +# Train with policy-aware rewards +trainer = Trainer( + runner=runner, + reward_fn=PolicyReward(kernel), + algorithm="GRPO" +) + +trainer.train() +``` + +## Components + +### AgentOSRunner + +Wraps agent execution with kernel-level policy enforcement: + +```python +from agentlightning.contrib.agent_os import AgentOSRunner + +runner = AgentOSRunner( + kernel, + fail_on_violation=False, # Continue but penalize + emit_violations=True, # Emit as spans +) +``` + +### PolicyReward + +Converts policy violations to negative RL rewards: + +```python +from agentlightning.contrib.agent_os import PolicyReward + +reward_fn = PolicyReward( + kernel, + base_reward_fn=accuracy_reward, + critical_penalty=-100.0, + clean_bonus=5.0, +) +``` + +### FlightRecorderAdapter + +Imports Agent-OS audit logs to LightningStore: + +```python +from agentlightning.contrib.agent_os import FlightRecorderAdapter + +adapter = FlightRecorderAdapter(flight_recorder) +adapter.import_to_store(lightning_store) +``` + +## Benchmarks + +| Metric | Without Agent-OS | With Agent-OS | +|--------|------------------|---------------| +| Undetected Policy Violations | 12.3% | **0.0%** | +| Task Accuracy | 76.4% | **79.2%** | + +*Note: "0% undetected violations" means all policy violations are caught and penalized, not that agents never attempt unsafe actions. Over training, agents learn to minimize violation attempts.* + +## Documentation + +- [Agent-OS Documentation](https://imran-siddique.github.io/agent-os-docs/) +- Integration guide: see project README or examples in this directory. + +## License + +MIT diff --git a/contrib/agentlightning/contrib/agent_os/__init__.py b/contrib/agentlightning/contrib/agent_os/__init__.py new file mode 100644 index 000000000..8995c5339 --- /dev/null +++ b/contrib/agentlightning/contrib/agent_os/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Agent-OS Integration for Agent-Lightning +========================================= + +Provides kernel-level safety during RL training. + +Components: +- AgentOSRunner: Runner with policy enforcement +- PolicyReward: Convert violations to RL penalties +- FlightRecorderAdapter: Import audit logs + +Example: + >>> from agentlightning.contrib.agent_os import AgentOSRunner, PolicyReward + >>> from agent_os import KernelSpace + >>> + >>> kernel = KernelSpace(policy="safety-critical") + >>> runner = AgentOSRunner(kernel) + >>> reward_fn = PolicyReward(kernel) +""" + +from .adapter import FlightRecorderAdapter +from .reward import PolicyReward +from .runner import AgentOSRunner + +__all__ = [ + "AgentOSRunner", + "PolicyReward", + "FlightRecorderAdapter", +] diff --git a/contrib/agentlightning/contrib/agent_os/adapter.py b/contrib/agentlightning/contrib/agent_os/adapter.py new file mode 100644 index 000000000..5c17c090a --- /dev/null +++ b/contrib/agentlightning/contrib/agent_os/adapter.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +FlightRecorderAdapter - Import Audit Logs to LightningStore +============================================================= + +Adapts Agent-OS Flight Recorder to Agent-Lightning store format. +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any, Dict, List + +logger = logging.getLogger(__name__) + + +class FlightRecorderAdapter: + """ + Import Agent-OS Flight Recorder logs to LightningStore. + + Example: + >>> from agent_os import FlightRecorder + >>> + >>> recorder = FlightRecorder() + >>> adapter = FlightRecorderAdapter(recorder) + >>> + >>> # Import to Lightning store + >>> adapter.import_to_store(lightning_store) + """ + + def __init__( + self, + flight_recorder: Any, + *, + trace_id_prefix: str = "agentos", + ): + """ + Initialize adapter. + + Args: + flight_recorder: Agent-OS FlightRecorder + trace_id_prefix: Prefix for trace IDs + """ + self.recorder = flight_recorder + self.trace_id_prefix = trace_id_prefix + self._imported_count = 0 + + def _convert_entry(self, entry: Any, index: int) -> Dict[str, Any]: + """Convert Flight Recorder entry to span format.""" + entry_type = getattr(entry, "type", "unknown") + timestamp = getattr(entry, "timestamp", datetime.utcnow()) + agent_id = getattr(entry, "agent_id", "unknown") + + span = { + "span_id": f"{self.trace_id_prefix}-{index}", + "trace_id": f"{self.trace_id_prefix}-{agent_id}", + "name": f"agent_os.{entry_type}", + "start_time": timestamp.isoformat() if hasattr(timestamp, "isoformat") else str(timestamp), + "attributes": { + "agent_os.entry_type": entry_type, + "agent_os.agent_id": agent_id, + }, + } + + # Add type-specific attributes + if entry_type == "policy_check": + span["attributes"].update( + { + "agent_os.policy_name": getattr(entry, "policy_name", "unknown"), + "agent_os.policy_violated": getattr(entry, "violated", False), + } + ) + elif entry_type == "signal": + span["attributes"].update( + { + "agent_os.signal_type": getattr(entry, "signal", "unknown"), + } + ) + + return span + + def get_spans(self) -> List[Dict[str, Any]]: + """Get all entries as spans.""" + entries = [] + if hasattr(self.recorder, "get_entries"): + entries = self.recorder.get_entries() + elif hasattr(self.recorder, "entries"): + entries = self.recorder.entries + + return [self._convert_entry(e, i) for i, e in enumerate(entries)] + + def import_to_store(self, store: Any) -> int: + """ + Import spans to LightningStore. + + Args: + store: LightningStore instance + + Returns: + Number of spans imported + """ + spans = self.get_spans() + + for span in spans: + try: + if hasattr(store, "emit_span"): + store.emit_span(span) + elif hasattr(store, "add_span"): + store.add_span(span) + except Exception as e: + logger.error(f"Failed to import span: {e}") + + self._imported_count += len(spans) + logger.info(f"Imported {len(spans)} spans to LightningStore") + return len(spans) + + def get_violation_summary(self) -> Dict[str, Any]: + """Get summary of policy violations.""" + spans = self.get_spans() + violations = [s for s in spans if s["attributes"].get("agent_os.policy_violated", False)] + return { + "total_entries": len(spans), + "total_violations": len(violations), + "violation_rate": len(violations) / len(spans) if len(spans) > 0 else 0.0, + } diff --git a/contrib/agentlightning/contrib/agent_os/reward.py b/contrib/agentlightning/contrib/agent_os/reward.py new file mode 100644 index 000000000..e3fbb7fdc --- /dev/null +++ b/contrib/agentlightning/contrib/agent_os/reward.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +PolicyReward - Convert Policy Violations to RL Penalties +========================================================= + +Reward function that integrates Agent-OS governance. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, Optional + +logger = logging.getLogger(__name__) + + +class PolicyReward: + """ + Reward function that penalizes policy violations. + + Example: + >>> from agent_os import KernelSpace + >>> + >>> kernel = KernelSpace(policy="strict") + >>> reward_fn = PolicyReward(kernel, base_reward_fn=accuracy) + >>> + >>> reward = reward_fn(rollout) # Base reward - violation penalties + """ + + def __init__( + self, + kernel: Any, + *, + base_reward_fn: Optional[Callable[[Any], float]] = None, + critical_penalty: float = -100.0, + high_penalty: float = -50.0, + medium_penalty: float = -10.0, + low_penalty: float = -1.0, + clean_bonus: float = 5.0, + ): + """ + Initialize policy-aware reward. + + Args: + kernel: Agent-OS KernelSpace + base_reward_fn: Base reward function + critical_penalty: Penalty for critical violations + high_penalty: Penalty for high violations + medium_penalty: Penalty for medium violations + low_penalty: Penalty for low violations + clean_bonus: Bonus for clean execution + """ + self.kernel = kernel + self.base_reward_fn = base_reward_fn or self._default_reward + self.penalties = { + "critical": critical_penalty, + "high": high_penalty, + "medium": medium_penalty, + "low": low_penalty, + } + self.clean_bonus = clean_bonus + + self._total_rewards = 0 + self._total_penalties = 0.0 + + def _default_reward(self, rollout: Any) -> float: + """Default: 1.0 for success, 0.0 for failure.""" + return 1.0 if getattr(rollout, "success", False) else 0.0 + + def __call__(self, rollout: Any, *, emit: bool = True) -> float: + """ + Calculate reward with policy penalties. + + Args: + rollout: Rollout with violations attribute + emit: Emit reward span + + Returns: + Final reward + """ + base = self.base_reward_fn(rollout) + + violations = getattr(rollout, "violations", []) + penalty = sum(self.penalties.get(v.severity, -10.0) for v in violations) + + reward = base + penalty + if not violations: + reward += self.clean_bonus + + self._total_rewards += 1 + self._total_penalties += penalty + + if emit: + self._emit_reward(reward, base, penalty, len(violations)) + + return reward + + def _emit_reward( + self, + final: float, + base: float, + penalty: float, + violation_count: int, + ) -> None: + """Emit multi-dimensional reward.""" + try: + from agentlightning.emitter import emit_reward + + emit_reward( + {"final": final, "base": base, "policy_penalty": penalty}, + primary_key="final", + attributes={"agent_os.violations": violation_count}, + ) + except ImportError: + logger.debug( + "agentlightning.emitter not available; skipping reward emission.", + exc_info=True, + ) + + def get_stats(self) -> Dict[str, float]: + """Get reward statistics.""" + total = self._total_rewards or 1 + return { + "total_rewards": self._total_rewards, + "avg_penalty": self._total_penalties / total, + } diff --git a/contrib/agentlightning/contrib/agent_os/runner.py b/contrib/agentlightning/contrib/agent_os/runner.py new file mode 100644 index 000000000..fc1848973 --- /dev/null +++ b/contrib/agentlightning/contrib/agent_os/runner.py @@ -0,0 +1,282 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +AgentOSRunner - Agent-Lightning Runner with Kernel Safety +========================================================== + +Wraps agent execution with Agent-OS kernel governance. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Generic, Optional, TypeVar + +logger = logging.getLogger(__name__) + +T_task = TypeVar("T_task") + + +@dataclass +class PolicyViolation: + """Record of a policy violation.""" + + policy_name: str + description: str + severity: str + blocked: bool + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + @property + def penalty(self) -> float: + """Calculate penalty based on severity. + + Returns: + float: Negative penalty value, where more severe violations + have larger negative magnitudes. + """ + penalties = { + "critical": -100.0, + "high": -50.0, + "medium": -10.0, + "low": -1.0, + } + return penalties.get(self.severity, -10.0) + + +@dataclass +class GovernedRollout: + """Rollout with governance metadata. + + This dataclass wraps execution results with governance information. + It is compatible with Agent-Lightning's Rollout interface - the + `task_input`, `task_output`, and `success` fields provide the core + rollout data, while `violations` adds governance-specific metadata. + """ + + task_input: Any + task_output: Any + success: bool + violations: list[PolicyViolation] = field(default_factory=list) + + @property + def total_penalty(self) -> float: + return sum(v.penalty for v in self.violations) + + +class AgentOSRunner(Generic[T_task]): + """ + Agent-Lightning runner with Agent-OS kernel safety. + + This runner wraps agent execution in an Agent-OS kernel, + enforcing policies and collecting violation data for RL training. + + Example: + >>> from agent_os import KernelSpace + >>> from agent_os.policies import SQLPolicy + >>> + >>> kernel = KernelSpace(policy=SQLPolicy()) + >>> runner = AgentOSRunner(kernel) + >>> + >>> rollout = await runner.step(task) + >>> print(f"Violations: {len(rollout.violations)}") + """ + + def __init__( + self, + kernel: Any, + *, + fail_on_violation: bool = False, + emit_violations: bool = True, + ): + """ + Initialize the governed runner. + + Args: + kernel: Agent-OS KernelSpace with loaded policies + fail_on_violation: Raise exception on violation + emit_violations: Emit violations as spans + """ + self.kernel = kernel + self.fail_on_violation = fail_on_violation + self.emit_violations = emit_violations + + self._violations: list[PolicyViolation] = [] + self._total_rollouts = 0 + self._total_violations = 0 + + # Worker attributes (set by init_worker) + self.worker_id: Optional[int] = None + self.store: Optional[Any] = None + + self._setup_hooks() + + def _setup_hooks(self) -> None: + """Set up kernel hooks.""" + on_violation = getattr(self.kernel, "on_policy_violation", None) + if on_violation is None: + logger.warning( + "Kernel %r does not support policy violation hooks via 'on_policy_violation'.", + self.kernel, + ) + return + if not callable(on_violation): + logger.warning( + "Kernel attribute 'on_policy_violation' is not callable: %r", + on_violation, + ) + return + try: + on_violation(self._handle_violation) + except TypeError as exc: + logger.warning( + "Kernel.on_policy_violation has an incompatible signature: %s", + exc, + ) + + def _handle_violation( + self, + policy_name: str, + description: str, + severity: str, + blocked: bool, + ) -> None: + """Handle a policy violation.""" + violation = PolicyViolation( + policy_name=policy_name, + description=description, + severity=severity, + blocked=blocked, + ) + self._violations.append(violation) + self._total_violations += 1 + + if self.emit_violations: + self._emit_violation_span(violation) + + if self.fail_on_violation and blocked: + raise PolicyViolationError(violation) + + def _emit_violation_span(self, violation: PolicyViolation) -> None: + """Emit violation as Agent-Lightning span.""" + try: + from agentlightning.emitter import emit_annotation + + emit_annotation( + { + "agent_os.violation": True, + "agent_os.policy": violation.policy_name, + "agent_os.severity": violation.severity, + "agent_os.blocked": violation.blocked, + } + ) + except ImportError as exc: + logger.debug( + "agentlightning.emitter not available; skipping violation annotation: %s", + exc, + ) + + @property + def agent(self) -> Any: + """ + Access the underlying agent. + + Raises: + RuntimeError: If the agent has not been initialized via `init`. + """ + if not hasattr(self, "_agent"): + raise RuntimeError("AgentOSRunner.agent accessed before `init` has been called.") + return self._agent + + @agent.setter + def agent(self, value: Any) -> None: + """Set the underlying agent instance.""" + self._agent = value + + def init(self, agent: Any, **kwargs: Any) -> None: + """Initialize with agent.""" + self.agent = agent + + def init_worker(self, worker_id: int, store: Any, **kwargs: Any) -> None: + """Initialize worker.""" + self.worker_id = worker_id + self.store = store + + def teardown(self) -> None: + """Release resources.""" + pass + + def teardown_worker(self, worker_id: int) -> None: + """Release worker resources.""" + pass + + async def step( + self, + input: T_task, + *, + resources: Optional[Any] = None, + mode: Optional[str] = None, + event: Optional[Any] = None, + ) -> GovernedRollout: + """ + Execute task with governance. + + Args: + input: Task input + resources: Optional resources + mode: Rollout mode + event: Stop signal + + Returns: + GovernedRollout with results and violations + """ + self._violations = [] + + try: + if hasattr(self.kernel, "execute_async"): + logger.debug("AgentOSRunner: executing task via kernel.execute_async") + result = await self.kernel.execute_async(self.agent, input) + elif hasattr(self.kernel, "execute"): + logger.debug("AgentOSRunner: executing task via kernel.execute") + result = self.kernel.execute(self.agent, input) + else: + logger.error( + "AgentOSRunner: kernel does not support 'execute_async' or 'execute'; " + "governed execution is not possible." + ) + raise RuntimeError( + "Kernel does not support governed execution (missing 'execute_async' and 'execute')." + ) + success = True + except PolicyViolationError as e: + # Record the policy violation and mark rollout as unsuccessful. + self._violations.append(e.violation) + result = None + success = False + + self._total_rollouts += 1 + + return GovernedRollout( + task_input=input, + task_output=result, + success=success, + violations=self._violations.copy(), + ) + + def get_stats(self) -> dict: + """Get runner statistics.""" + return { + "total_rollouts": self._total_rollouts, + "total_violations": self._total_violations, + "violation_rate": (self._total_violations / self._total_rollouts if self._total_rollouts > 0 else 0.0), + } + + +class PolicyViolationError(Exception): + """Raised when policy violation blocks execution.""" + + def __init__(self, violation: PolicyViolation): + self.violation = violation + super().__init__(f"Policy violation: {violation.description}")