diff --git a/README.md b/README.md index 5511d9e..da28009 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ This guide will walk you through setting up your API key, downloading a dataset, ### Step 1: Configure Your API Key -First, tell CARIBOU about your OpenAI or DeepSeek API key. This is a one-time setup. +First, tell CARIBOU about your OpenAI, Anthropic (Claude), or DeepSeek API key. This is a one-time setup. ```bash caribou config set-openai-key "sk-YourSecretKeyGoesHere" @@ -63,6 +63,12 @@ or caribou config set-deepseek-key "sk-YourSecretKeyGoesHere" ``` +or + +```bash +caribou config set-anthropic-key "sk-ant-YourSecretKeyGoesHere" +``` + Your key will be stored securely in a local `.env` file within the CARIBOU configuration directory. @@ -91,7 +97,7 @@ This will trigger a series of prompts: 2. **Select a driver agent:** Choose which agent in the system will receive the first instruction. 3. **Select Dataset:** Pick the dataset you downloaded in Step 2. 4. **Choose a sandbox backend:** Select `docker` or `singularity`. -5. **Choose an LLM backend:** Select `chatgpt` or `ollama`. +5. **Choose an LLM backend:** Select `chatgpt`, `claude`, `deepseek`, or `ollama`. After configuration, the session will begin, and you can start giving instructions to your agent team\! @@ -161,6 +167,10 @@ Manage your CARIBOU configuration. ```bash caribou config set-deepseek-key "sk-..." ``` + * **Set your Anthropic API key:** + ```bash + caribou config set-anthropic-key "sk-ant-..." + ``` ----- diff --git a/caribou/README.md b/caribou/README.md index 963bc0b..01da052 100644 --- a/caribou/README.md +++ b/caribou/README.md @@ -51,7 +51,7 @@ This guide will walk you through setting up your API key, downloading a dataset, ### Step 1: Configure Your API Key -First, tell CARIBOU about your OpenAI or DeepSeek API key. This is a one-time setup. +First, tell CARIBOU about your OpenAI, Anthropic (Claude), or DeepSeek API key. This is a one-time setup. ```bash caribou config set-openai-key "sk-YourSecretKeyGoesHere" @@ -63,6 +63,12 @@ or caribou config set-deepseek-key "sk-YourSecretKeyGoesHere" ``` +or + +```bash +caribou config set-anthropic-key "sk-ant-YourSecretKeyGoesHere" +``` + Your key will be stored securely in a local `.env` file within the CARIBOU configuration directory. @@ -91,7 +97,7 @@ This will trigger a series of prompts: 2. **Select a driver agent:** Choose which agent in the system will receive the first instruction. 3. **Select Dataset:** Pick the dataset you downloaded in Step 2. 4. **Choose a sandbox backend:** Select `docker` or `singularity`. -5. **Choose an LLM backend:** Select `chatgpt` or `ollama`. +5. **Choose an LLM backend:** Select `chatgpt`, `claude`, `deepseek`, or `ollama`. After configuration, the session will begin, and you can start giving instructions to your agent team\! @@ -161,6 +167,10 @@ Manage your CARIBOU configuration. ```bash caribou config set-deepseek-key "sk-..." ``` + * **Set your Anthropic API key:** + ```bash + caribou config set-anthropic-key "sk-ant-..." 
+ ``` ----- diff --git a/caribou/src/caribou/cli/config_cli.py b/caribou/src/caribou/cli/config_cli.py index 0f4ec1a..0ff269d 100644 --- a/caribou/src/caribou/cli/config_cli.py +++ b/caribou/src/caribou/cli/config_cli.py @@ -71,4 +71,36 @@ def set_deepseek_key( new_content = content.strip() + f"\n{key_to_set}\n" ENV_FILE.write_text(new_content) - console.print(f"[bold green]✅ DeepSeek API key has been set successfully in:[/bold green] {ENV_FILE}") \ No newline at end of file + console.print(f"[bold green]✅ DeepSeek API key has been set successfully in:[/bold green] {ENV_FILE}") + +@config_app.command("set-anthropic-key") +def set_anthropic_key( + ctx: typer.Context, + api_key: Optional[str] = typer.Argument(None, help="Your Anthropic API key (e.g., 'sk-ant-...')"), +): + """ + Saves your Anthropic API key to the Caribou environment file. + """ + if api_key is None: + console.print("[bold red]Error:[/bold red] You must provide an API key.\n") + typer.echo(ctx.parent.get_help()) + raise typer.Exit() + + if not api_key.startswith("sk-"): + console.print( + "[yellow]Warning: Key does not look like a standard Anthropic API key (should start with 'sk-').[/yellow]" + ) + + if not ENV_FILE.exists(): + ENV_FILE.touch() + + content = ENV_FILE.read_text() + key_to_set = f'ANTHROPIC_API_KEY="{api_key}"' + + if re.search(r"^ANTHROPIC_API_KEY=.*$", content, flags=re.MULTILINE): + new_content = re.sub(r"^ANTHROPIC_API_KEY=.*$", key_to_set, content, flags=re.MULTILINE) + else: + new_content = content.strip() + f"\n{key_to_set}\n" + + ENV_FILE.write_text(new_content.strip()) + console.print(f"[bold green]✅ Anthropic API key has been set successfully in:[/bold green] {ENV_FILE}") diff --git a/caribou/src/caribou/cli/run_cli.py b/caribou/src/caribou/cli/run_cli.py index 9fe742d..cdc706a 100644 --- a/caribou/src/caribou/cli/run_cli.py +++ b/caribou/src/caribou/cli/run_cli.py @@ -370,7 +370,7 @@ def initialize_context( # ---- LLM Backend ---- if llm_backend is None: - llm_backend = Prompt.ask("Choose an LLM backend", choices=["chatgpt", "ollama", "deepseek"], default="chatgpt") + llm_backend = Prompt.ask("Choose an LLM backend", choices=["chatgpt", "claude", "ollama", "deepseek"], default="chatgpt") console.print(f"[cyan]Initializing LLM backend: {llm_backend}[/cyan]") @@ -380,7 +380,14 @@ def initialize_context( raise typer.Exit(1) context.llm_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) context.model_name = "gpt-5.2" - + elif llm_backend == "claude": + anthropic_key = os.getenv("ANTHROPIC_API_KEY") + if not anthropic_key: + console.print("[bold red]Error: ANTHROPIC_API_KEY not set. Use 'caribou config set-anthropic-key'.[/bold red]") + raise typer.Exit(1) + from caribou.core.anthropic_wrapper import AnthropicClient + context.llm_client = AnthropicClient(api_key=anthropic_key) + context.model_name = "claude-sonnet-4-5-20250929" elif llm_backend == "deepseek": if not os.getenv("DEEPSEEK_API_KEY"): console.print("[bold red]Error: DEEPSEEK_API_KEY not set. 
Use 'caribou config set-deepseek-key'.[/bold red]") @@ -459,7 +466,7 @@ def main_run_callback( dataset: Optional[Path] = typer.Option(None, "--dataset", "-ds", help="Path to the primary dataset file (.h5ad).", readable=True), reference_dataset: Optional[Path] = typer.Option(None, "--reference-dataset", "-ref", help="Path to an optional reference dataset file (.h5ad).", readable=True), resources_dir: Optional[Path] = typer.Option(None, "--resources", help="Path to a directory of resource files to mount.", exists=True, file_okay=False), - llm_backend: Optional[str] = typer.Option(None, "--llm", help="LLM backend: 'chatgpt', 'ollama', or 'deepseek'."), + llm_backend: Optional[str] = typer.Option(None, "--llm", help="LLM backend: 'chatgpt', 'claude', 'ollama', or 'deepseek'."), ollama_host: str = typer.Option("http://localhost:11434", "--ollama-host", help="Base URL for Ollama backend."), sandbox: Optional[str] = typer.Option(None, "--sandbox", help="Sandbox backend: 'docker' or 'singularity'."), force_refresh: bool = typer.Option(False, "--force-refresh", help="Force refresh/rebuild of the sandbox environment."), @@ -515,7 +522,7 @@ def run_interactive( dataset: Path = typer.Option(None, "--dataset", "-ds", help="Path to the primary dataset file (.h5ad).", readable=True), reference_dataset: Path = typer.Option(None, "--reference-dataset", "-ref", help="Path to an optional reference dataset file (.h5ad).", readable=True), resources_dir: Path = typer.Option(None, "--resources", help="Path to a directory of resource files to mount.", exists=True, file_okay=False), - llm_backend: str = typer.Option(None, "--llm", help="LLM backend to use: 'chatgpt', 'ollama', or 'deepseek'."), + llm_backend: str = typer.Option(None, "--llm", help="LLM backend to use: 'chatgpt', 'claude', 'ollama', or 'deepseek'."), ollama_host: str = typer.Option("http://localhost:11434", "--ollama-host", help="Base URL for Ollama backend."), sandbox: str = typer.Option(None, "--sandbox", help="Sandbox backend to use: 'docker' or 'singularity'."), force_refresh: bool = typer.Option(False, "--force-refresh", help="Force refresh/rebuild of the sandbox environment."), @@ -560,7 +567,7 @@ def run_auto( dataset: Path = typer.Option(None, "--dataset", "-ds", help="Path to the primary dataset file (.h5ad).", readable=True), reference_dataset: Path = typer.Option(None, "--reference-dataset", "-ref", help="Path to an optional reference dataset file (.h5ad).", readable=True), resources_dir: Path = typer.Option(None, "--resources", help="Path to a directory of resource files to mount.", exists=True, file_okay=False), - llm_backend: str = typer.Option(None, "--llm", help="LLM backend to use: 'chatgpt', 'ollama', or 'deepseek'."), + llm_backend: str = typer.Option(None, "--llm", help="LLM backend to use: 'chatgpt', 'claude', 'ollama', or 'deepseek'."), ollama_host: str = typer.Option("http://localhost:11434", "--ollama-host", help="Base URL for Ollama backend."), sandbox: str = typer.Option(None, "--sandbox", help="Sandbox backend to use: 'docker' or 'singularity'."), force_refresh: bool = typer.Option(False, "--force-refresh", help="Force refresh/rebuild of the sandbox environment."), diff --git a/caribou/src/caribou/core/anthropic_wrapper.py b/caribou/src/caribou/core/anthropic_wrapper.py new file mode 100644 index 0000000..afbbcb1 --- /dev/null +++ b/caribou/src/caribou/core/anthropic_wrapper.py @@ -0,0 +1,75 @@ +""" +Lightweight Anthropic client wrapper that mimics the subset of the OpenAI +chat API used by CARIBOU. 
Exposes a `.chat.completions.create(...)` method
+that returns an object shaped like the OpenAI response:
+
+    resp.choices[0].message.content
+"""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from typing import Any, Dict, List, Optional
+
+import anthropic
+
+
+class AnthropicClient:
+    """
+    Example:
+        client = AnthropicClient(api_key="sk-ant-...", model="claude-sonnet-4-5-20250929")
+        resp = client.chat.completions.create(model="claude-sonnet-4-5-20250929", messages=[...])
+        print(resp.choices[0].message.content)
+    """
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        model: str = "claude-sonnet-4-5-20250929",
+        max_output_tokens: int = 1024,
+        base_url: Optional[str] = None,
+    ):
+        client_kwargs: Dict[str, Any] = {"api_key": api_key}
+        if base_url:
+            client_kwargs["base_url"] = base_url
+        self._client = anthropic.Anthropic(**client_kwargs)
+        self._default_model = model
+        self._max_output_tokens = max_output_tokens
+        self.chat = SimpleNamespace(completions=SimpleNamespace(create=self._chat_create))
+
+    def _chat_create(
+        self,
+        *,
+        messages: List[Dict[str, str]],
+        model: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_output_tokens: Optional[int] = None,
+        **_: Any,
+    ):
+        system_parts: List[str] = []
+        converted: List[Dict[str, str]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_parts.append(str(content))
+                continue
+            converted.append({"role": role if role in ("assistant", "user") else "user", "content": content})
+
+        system_prompt = "\n\n".join(system_parts) if system_parts else None
+
+        # Build the request, omitting optional fields that were not provided so the
+        # Anthropic API never receives explicit nulls for `system` or `temperature`.
+        request_kwargs: Dict[str, Any] = {
+            "model": model or self._default_model,
+            "messages": converted,
+            "max_tokens": max_output_tokens or self._max_output_tokens,
+        }
+        if system_prompt is not None:
+            request_kwargs["system"] = system_prompt
+        if temperature is not None:
+            request_kwargs["temperature"] = temperature
+
+        response = self._client.messages.create(**request_kwargs)
+
+        text_chunks = [block.text for block in response.content if getattr(block, "type", None) == "text"]
+        content = "".join(text_chunks)
+
+        message = SimpleNamespace(content=content, role="assistant")
+        choice = SimpleNamespace(message=message, index=0, finish_reason=getattr(response, "stop_reason", "stop"))
+        return SimpleNamespace(choices=[choice])
diff --git a/caribou/tests/QUICKSTART.md b/caribou/tests/QUICKSTART.md
new file mode 100644
index 0000000..0cf088f
--- /dev/null
+++ b/caribou/tests/QUICKSTART.md
@@ -0,0 +1,127 @@
+# CARIBOU Test Suite - Quick Start Guide
+
+## Installation
+
+1. **Activate your conda environment (if using conda):**
+
+```bash
+conda activate olaf # or your environment name
+```
+
+2. **Install test dependencies:**
+
+```bash
+cd /path/to/CARIBOU  # adjust to your local checkout
+python -m pip install pytest pytest-cov
+```
+
+Or install all requirements:
+
+```bash
+python -m pip install -r requirements.txt
+```
+
+**Important:** Make sure pytest is installed in the same Python environment where `anthropic`, `openai`, and other CARIBOU dependencies are installed.
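+
+A quick way to confirm this (an illustrative check, not part of the test suite) is to invoke pytest and the imports from the same interpreter:
+
+```bash
+python -m pytest --version && python -c "import anthropic, openai; print('dependencies OK')"
+```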
+ +## Running Tests + +### Option 1: Using the Test Runner Script (Recommended) + +```bash +cd caribou/tests +./run_tests.sh +``` + +Available options: +- `./run_tests.sh --unit` - Run only unit tests +- `./run_tests.sh --integration` - Run only integration tests +- `./run_tests.sh --verbose` - Verbose output +- `./run_tests.sh --coverage` - Generate coverage report + +### Option 2: Using pytest Directly + +From the project root directory: + +```bash +# Run all tests +pytest caribou/tests/ + +# Run specific test categories +pytest caribou/tests/unit/ +pytest caribou/tests/integration/ + +# Run specific test file +pytest caribou/tests/unit/test_message_utils.py + +# Run with verbose output +pytest caribou/tests/ -v + +# Run with coverage +pytest caribou/tests/ --cov=caribou --cov-report=html +``` + +## What Gets Tested + +✅ **LLM API Wrappers** +- AnthropicClient (OpenAI compatibility) +- OllamaClient (local models) + +✅ **Message Routing** +- Delegation detection (`delegate_to_agent`) +- RAG query detection (`query_rag_`) +- Artifact extraction (notes, TODOs) + +✅ **History Management** +- MemoryManager with episodic summarization +- Context assembly and compression + +✅ **Agent System** +- Multi-agent configuration +- Prompt generation +- Agent switching + +✅ **End-to-End Integration** +- Complete message flows +- Multi-agent conversations +- Error handling + +## Verifying the Setup + +Run a quick smoke test: + +```bash +pytest caribou/tests/unit/test_message_utils.py::TestDelegationDetection::test_detect_simple_delegation -v +``` + +Expected output: +``` +test_detect_simple_delegation PASSED +``` + +## Troubleshooting + +**Import errors?** +The tests automatically add `caribou/src` to the Python path via `conftest.py`. If you still get import errors, verify the directory structure: + +``` +CARIBOU/ +└── caribou/ + ├── src/ + │ └── caribou/ + │ ├── core/ + │ ├── execution/ + │ └── agents/ + └── tests/ + ├── conftest.py # ← Should add src/ to path + ├── unit/ + └── integration/ +``` + +**Tests hanging?** +All external API calls are mocked - tests should run quickly (< 10 seconds total). + +## Next Steps + +- See [README.md](README.md) for detailed documentation +- Run with coverage to see what's tested: `./run_tests.sh --coverage` +- Open `htmlcov/index.html` to view the coverage report diff --git a/caribou/tests/README.md b/caribou/tests/README.md new file mode 100644 index 0000000..b4cfc64 --- /dev/null +++ b/caribou/tests/README.md @@ -0,0 +1,309 @@ +# CARIBOU Test Suite + +Comprehensive unit and integration tests for CARIBOU's LLM API integration, message routing, history management, and system prompt handling. 
+ +## Overview + +This test suite ensures the reliability and correctness of: + +- **LLM API Wrappers**: AnthropicClient and OllamaClient for OpenAI-compatible interfaces +- **Message Routing**: Delegation detection, RAG query detection, and artifact extraction +- **History Management**: MemoryManager with episodic summarization +- **Agent System**: Multi-agent configuration, prompt generation, and agent switching +- **Integration Flows**: End-to-end message flows with all features combined + +## Test Structure + +``` +tests/ +├── unit/ # Unit tests for individual components +│ ├── test_anthropic_wrapper.py # AnthropicClient tests +│ ├── test_ollama_wrapper.py # OllamaClient tests +│ ├── test_message_utils.py # Message parsing and routing tests +│ ├── test_memory_manager.py # MemoryManager tests +│ └── test_agent_system.py # Agent system and management tests +├── integration/ # Integration tests +│ └── test_message_flow.py # End-to-end flow tests +├── fixtures/ # Test fixtures and data +├── conftest.py # Shared pytest fixtures +├── run_tests.sh # Test runner script +└── README.md # This file +``` + +## Prerequisites + +Install test dependencies: + +```bash +pip install pytest pytest-cov +``` + +Or install all CARIBOU dependencies (includes test deps): + +```bash +pip install -r requirements.txt +``` + +## Running Tests + +### Quick Start + +Run all tests: + +```bash +cd caribou/tests +./run_tests.sh +``` + +Or using pytest directly from project root: + +```bash +pytest caribou/tests/ +``` + +### Specific Test Categories + +**Unit tests only:** +```bash +./run_tests.sh --unit +# or +pytest caribou/tests/unit/ +``` + +**Integration tests only:** +```bash +./run_tests.sh --integration +# or +pytest caribou/tests/integration/ +``` + +**Specific test file:** +```bash +pytest caribou/tests/unit/test_anthropic_wrapper.py +``` + +**Specific test class or function:** +```bash +pytest caribou/tests/unit/test_message_utils.py::TestDelegationDetection +pytest caribou/tests/unit/test_message_utils.py::TestDelegationDetection::test_detect_simple_delegation +``` + +### Test Options + +**Verbose output:** +```bash +./run_tests.sh --verbose +# or +pytest caribou/tests/ -v +``` + +**Show print statements:** +```bash +pytest caribou/tests/ -s +``` + +**Run with coverage:** +```bash +./run_tests.sh --coverage +# or +pytest caribou/tests/ --cov=caribou --cov-report=html --cov-report=term +``` + +Coverage report will be in `htmlcov/index.html` + +**Stop on first failure:** +```bash +pytest caribou/tests/ -x +``` + +**Run only failed tests from last run:** +```bash +pytest caribou/tests/ --lf +``` + +## Test Coverage + +### Unit Tests + +#### test_anthropic_wrapper.py +- ✅ Client initialization (default and custom params) +- ✅ OpenAI-compatible interface structure +- ✅ System message extraction and combination +- ✅ Role filtering (assistant, user, system) +- ✅ API call parameter handling +- ✅ Response formatting to OpenAI structure +- ✅ Multiple text block handling +- ✅ Edge cases (empty messages, missing content, etc.) 
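+
+For orientation, a minimal sketch in the style of these tests (the test name and details are illustrative, not actual cases from the file):
+
+```python
+from types import SimpleNamespace
+from unittest.mock import patch
+
+from caribou.core.anthropic_wrapper import AnthropicClient
+
+
+@patch("caribou.core.anthropic_wrapper.anthropic.Anthropic")
+def test_response_is_openai_shaped(mock_anthropic):
+    # Mock the Anthropic SDK call and assert the wrapper returns an
+    # OpenAI-shaped response object.
+    mock_anthropic.return_value.messages.create.return_value = SimpleNamespace(
+        content=[SimpleNamespace(type="text", text="hello")],
+        stop_reason="end_turn",
+    )
+    client = AnthropicClient(api_key="sk-ant-test")
+    resp = client.chat.completions.create(messages=[{"role": "user", "content": "hi"}])
+    assert resp.choices[0].message.content == "hello"
+```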
+ +#### test_ollama_wrapper.py +- ✅ Client initialization with host variants +- ✅ OpenAI-compatible interface structure +- ✅ API call with temperature parameter +- ✅ ND-JSON response parsing +- ✅ Multi-line response handling +- ✅ Response structure validation +- ✅ Error handling (HTTP errors, invalid JSON, no message) + +#### test_message_utils.py +- ✅ Delegation command detection (various formats) +- ✅ RAG query detection +- ✅ Artifact extraction (notes, TODOs, checkboxes, code fences) +- ✅ Code block counting +- ✅ Code preview generation +- ✅ Edge cases (unicode, multiline, empty content) + +#### test_memory_manager.py +- ✅ Initialization with various parameters +- ✅ Message pinning strategy +- ✅ Adding messages and pivotal code +- ✅ System prompt updates +- ✅ Context assembly (pinned + pivotal + summaries + working) +- ✅ Summarization triggering logic +- ✅ Episodic summarization +- ✅ Multiple summarization rounds +- ✅ Error handling in summarization +- ✅ Context layout verification + +#### test_agent_system.py +- ✅ Command and Agent class creation +- ✅ Agent prompt generation (basic, with commands, with RAG, with samples) +- ✅ AgentSystem creation and agent retrieval +- ✅ Loading from JSON configuration +- ✅ Code sample loading from disk +- ✅ Extracting possible actions from agents +- ✅ Agent switching logic +- ✅ Memory manager updates during switch +- ✅ Action space updates during switch + +### Integration Tests + +#### test_message_flow.py +- ✅ Simple conversation flow +- ✅ Message flow with MemoryManager +- ✅ Delegation detection and agent switching +- ✅ Multi-agent conversation with memory +- ✅ RAG query detection and handling +- ✅ Artifact extraction from responses +- ✅ Complete workflow with all features +- ✅ Error resilience +- ✅ Long conversations with summarization + +## Writing New Tests + +### Test File Naming + +- Unit tests: `test_.py` in `unit/` +- Integration tests: `test_.py` in `integration/` + +### Test Class Naming + +Use descriptive class names grouped by functionality: + +```python +class TestComponentName: + """Test ComponentName functionality.""" + + def test_specific_behavior(self): + """Test that specific behavior works correctly.""" + pass +``` + +### Using Fixtures + +Common fixtures are defined in `conftest.py`: + +```python +def test_with_mock_client(mock_llm_client): + """Test using mock LLM client.""" + client = mock_llm_client(responses=["Response 1", "Response 2"]) + # Use client in test +``` + +Available fixtures: +- `mock_anthropic_response` - Create mock Anthropic responses +- `mock_openai_response` - Create mock OpenAI responses +- `sample_messages` - Pre-built message history +- `sample_agent_system` - Sample agent configuration +- `mock_llm_client` - Mock LLM client factory + +### Mocking External APIs + +Always mock external API calls in tests: + +```python +from unittest.mock import Mock, patch + +@patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") +def test_api_call(mock_anthropic): + mock_instance = Mock() + mock_response = Mock(content=[...]) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + # Test code here +``` + +## Continuous Integration + +To run tests in CI/CD: + +```bash +# Install dependencies +pip install -r requirements.txt +pip install pytest pytest-cov + +# Run tests with coverage +pytest caribou/tests/ --cov=caribou --cov-report=xml --cov-report=term + +# Fail if coverage is below threshold (optional) +pytest caribou/tests/ --cov=caribou --cov-fail-under=80 +``` 
+ +## Troubleshooting + +### Import Errors + +Make sure CARIBOU is installed or the path is set: + +```bash +# From project root +pip install -e . +# or +export PYTHONPATH="${PYTHONPATH}:$(pwd)" +``` + +### Tests Hanging + +Some tests may hang if they're waiting for actual API calls. Ensure all external calls are mocked. + +### Fixture Not Found + +If pytest can't find a fixture, check: +1. Is it defined in `conftest.py`? +2. Is `conftest.py` in the correct location? +3. Are you importing fixtures correctly? + +## Best Practices + +1. **One assertion per test** (generally) - Makes failures easier to diagnose +2. **Test names should be descriptive** - `test_delegation_with_underscores_in_name` not `test_1` +3. **Mock external dependencies** - Never make real API calls in tests +4. **Use fixtures for common setup** - Keeps tests DRY +5. **Test edge cases** - Empty inputs, None values, extreme values +6. **Test error conditions** - Don't just test the happy path +7. **Keep tests fast** - Unit tests should run in milliseconds +8. **Make tests independent** - Each test should be runnable in isolation + +## Contributing + +When adding new features to CARIBOU: + +1. Write tests first (TDD) or alongside the feature +2. Ensure all tests pass: `./run_tests.sh` +3. Check coverage: `./run_tests.sh --coverage` +4. Add new test cases for edge cases +5. Update this README if adding new test categories + +## Questions? + +For questions about the test suite, please open an issue on the CARIBOU GitHub repository. diff --git a/caribou/tests/__init__.py b/caribou/tests/__init__.py new file mode 100644 index 0000000..025f495 --- /dev/null +++ b/caribou/tests/__init__.py @@ -0,0 +1,6 @@ +""" +CARIBOU Test Suite + +Unit and integration tests for LLM API calls, message routing, +history management, and system prompt integration. +""" diff --git a/caribou/tests/conftest.py b/caribou/tests/conftest.py new file mode 100644 index 0000000..9a7f131 --- /dev/null +++ b/caribou/tests/conftest.py @@ -0,0 +1,119 @@ +""" +Pytest configuration and shared fixtures for CARIBOU tests. 
+""" +import sys +from pathlib import Path + +# Add the caribou/src directory to the Python path so imports work +caribou_src = Path(__file__).parent.parent / "src" +if str(caribou_src) not in sys.path: + sys.path.insert(0, str(caribou_src)) + +import pytest +from types import SimpleNamespace +from typing import List, Dict, Any + + +@pytest.fixture +def mock_anthropic_response(): + """Create a mock Anthropic API response.""" + def _create_response(content: str, stop_reason: str = "end_turn"): + text_block = SimpleNamespace(type="text", text=content) + return SimpleNamespace( + content=[text_block], + stop_reason=stop_reason, + id="msg_123", + model="claude-sonnet-4-5-20250929", + role="assistant", + type="message", + ) + return _create_response + + +@pytest.fixture +def mock_openai_response(): + """Create a mock OpenAI API response.""" + def _create_response(content: str, finish_reason: str = "stop"): + message = SimpleNamespace(content=content, role="assistant") + choice = SimpleNamespace(message=message, index=0, finish_reason=finish_reason) + return SimpleNamespace( + choices=[choice], + id="chatcmpl-123", + model="gpt-4", + object="chat.completion", + ) + return _create_response + + +@pytest.fixture +def sample_messages(): + """Sample message history for testing.""" + return [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you!"}, + {"role": "user", "content": "Can you help me with a task?"}, + ] + + +@pytest.fixture +def sample_agent_system(): + """Sample agent system configuration.""" + return { + "global_policy": "Always be helpful and accurate.", + "agents": { + "planner": { + "prompt": "You are a planning agent.", + "neighbors": { + "delegate_to_coder": { + "target_agent": "coder", + "description": "Delegate coding tasks" + } + }, + "code_samples": [], + "rag": {"enabled": False} + }, + "coder": { + "prompt": "You are a coding agent.", + "neighbors": { + "delegate_to_planner": { + "target_agent": "planner", + "description": "Go back to planning" + } + }, + "code_samples": [], + "rag": {"enabled": False} + } + } + } + + +class MockLLMClient: + """Mock LLM client for testing.""" + + def __init__(self, responses: List[str] = None): + self.responses = responses or ["Mock response"] + self.call_count = 0 + self.calls = [] + + # Mock the nested structure: client.chat.completions.create() + self.chat = SimpleNamespace( + completions=SimpleNamespace(create=self._create) + ) + + def _create(self, **kwargs): + """Mock the chat.completions.create method.""" + self.calls.append(kwargs) + + response_text = self.responses[min(self.call_count, len(self.responses) - 1)] + self.call_count += 1 + + message = SimpleNamespace(content=response_text, role="assistant") + choice = SimpleNamespace(message=message, index=0, finish_reason="stop") + return SimpleNamespace(choices=[choice]) + + +@pytest.fixture +def mock_llm_client(): + """Fixture for mock LLM client.""" + return lambda responses=None: MockLLMClient(responses) diff --git a/caribou/tests/integration/__init__.py b/caribou/tests/integration/__init__.py new file mode 100644 index 0000000..feaeed3 --- /dev/null +++ b/caribou/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests for CARIBOU components.""" diff --git a/caribou/tests/integration/test_message_flow.py b/caribou/tests/integration/test_message_flow.py new file mode 100644 index 0000000..c6c5d13 --- /dev/null +++ 
b/caribou/tests/integration/test_message_flow.py @@ -0,0 +1,553 @@ +""" +Integration tests for end-to-end message flow. + +Tests the complete flow: LLM API call -> message routing -> history update +-> delegation handling -> agent switching. +""" +import pytest +from types import SimpleNamespace +from unittest.mock import Mock, patch +import tempfile +import json + +from caribou.agents.AgentSystem import Agent, AgentSystem, Command +from caribou.execution.MemoryManager import MemoryManager +from caribou.execution.ActionSpace import AgentActionSpace +from caribou.execution.message_utils import detect_delegation, detect_rag, _extract_artifacts_from_msg +from caribou.execution.agent_management import _apply_agent_switch, _extract_possible_actions + + +class TestBasicMessageFlow: + """Test basic message flow without special features.""" + + def test_simple_conversation_flow(self, mock_llm_client): + """Test simple back-and-forth conversation.""" + client = mock_llm_client(responses=[ + "Hello! How can I help you?", + "Sure, I can help with that.", + "Task completed successfully." + ]) + + history = [ + {"role": "system", "content": "You are a helpful assistant."} + ] + + # Simulate conversation turns + history.append({"role": "user", "content": "Hello"}) + response1 = client.chat.completions.create(messages=history) + history.append({"role": "assistant", "content": response1.choices[0].message.content}) + + history.append({"role": "user", "content": "Can you help me?"}) + response2 = client.chat.completions.create(messages=history) + history.append({"role": "assistant", "content": response2.choices[0].message.content}) + + history.append({"role": "user", "content": "Do the task"}) + response3 = client.chat.completions.create(messages=history) + history.append({"role": "assistant", "content": response3.choices[0].message.content}) + + # Verify conversation structure + assert len(history) == 7 # 1 system + 3 user + 3 assistant + assert history[-1]["role"] == "assistant" + assert "completed successfully" in history[-1]["content"] + assert client.call_count == 3 + + def test_message_flow_with_memory_manager(self, mock_llm_client): + """Test message flow using MemoryManager.""" + client = mock_llm_client(responses=["Response 1", "Response 2"]) + + initial_history = [ + {"role": "system", "content": "System prompt"} + ] + + memory_manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history, + working_history_size=4 + ) + + # Add messages + memory_manager.add_message("user", "Question 1") + context = memory_manager.get_context() + response1 = client.chat.completions.create(messages=context) + memory_manager.add_message("assistant", response1.choices[0].message.content) + + memory_manager.add_message("user", "Question 2") + context = memory_manager.get_context() + response2 = client.chat.completions.create(messages=context) + memory_manager.add_message("assistant", response2.choices[0].message.content) + + # Verify history + assert len(memory_manager._full_history) == 5 # 1 system + 2 user + 2 assistant + final_context = memory_manager.get_context() + assert len(final_context) >= 1 + + +class TestDelegationFlow: + """Test message flow with delegation between agents.""" + + def test_delegation_detection_and_switch(self, mock_llm_client): + """Test detecting delegation and switching agents.""" + # Create agent system + planner = Agent( + name="planner", + prompt="You are a planner.", + commands={ + "delegate_to_coder": Command("delegate_to_coder", "coder", "Send 
to coder") + }, + code_samples={} + ) + + coder = Agent( + name="coder", + prompt="You are a coder.", + commands={}, + code_samples={} + ) + + agent_system = AgentSystem( + global_policy="Be helpful", + agents={"planner": planner, "coder": coder} + ) + + # Simulate conversation with delegation + client = mock_llm_client(responses=[ + "I'll plan this task first.", + "Now I will delegate_to_coder to implement it.", + "Implementation complete." + ]) + + history = [ + {"role": "system", "content": agent_system.global_policy}, + {"role": "system", "content": planner.get_full_prompt()} + ] + + current_agent = planner + + # Turn 1: Normal response + history.append({"role": "user", "content": "Plan and implement feature X"}) + response1 = client.chat.completions.create(messages=history) + msg1 = response1.choices[0].message.content + history.append({"role": "assistant", "content": msg1}) + + # No delegation detected + assert detect_delegation(msg1) is None + + # Turn 2: Delegation response + history.append({"role": "user", "content": "Continue"}) + response2 = client.chat.completions.create(messages=history) + msg2 = response2.choices[0].message.content + history.append({"role": "assistant", "content": msg2}) + + # Delegation detected + cmd = detect_delegation(msg2) + assert cmd == "delegate_to_coder" + + # Apply agent switch + if cmd and cmd in current_agent.commands: + target_agent_name = current_agent.commands[cmd].target_agent + new_agent = agent_system.get_agent(target_agent_name) + + _apply_agent_switch( + new_agent_prompt=new_agent.get_full_prompt(), + analysis_context="", + history=history, + memory_manager=None, + action_space=None, + new_agent=new_agent + ) + + current_agent = new_agent + + # Verify switch occurred + assert current_agent.name == "coder" + assert any("coder" in msg.get("content", "").lower() for msg in history) + + def test_multi_agent_flow_with_memory(self, mock_llm_client): + """Test multi-agent conversation with memory management.""" + # Create agents + planner = Agent( + name="planner", + prompt="You are a planner.", + commands={ + "delegate_to_coder": Command("delegate_to_coder", "coder", "Send to coder") + }, + code_samples={} + ) + + coder = Agent( + name="coder", + prompt="You are a coder.", + commands={ + "delegate_to_planner": Command("delegate_to_planner", "planner", "Back to planner") + }, + code_samples={} + ) + + agent_system = AgentSystem( + global_policy="Be helpful", + agents={"planner": planner, "coder": coder} + ) + + client = mock_llm_client(responses=[ + "Planning the task.", + "delegate_to_coder to implement.", + "Implementing the feature.", + "delegate_to_planner to review.", + "Review complete." 
+ ]) + + initial_history = [ + {"role": "system", "content": agent_system.global_policy}, + {"role": "system", "content": planner.get_full_prompt()} + ] + + memory_manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history + ) + + current_agent = planner + action_space = AgentActionSpace(agent_name=current_agent.name) + action_space.set_possible_actions(_extract_possible_actions(current_agent)) + + # Simulate multi-turn conversation with agent switches + for turn in range(5): + memory_manager.add_message("user", f"Turn {turn}") + context = memory_manager.get_context() + response = client.chat.completions.create(messages=context) + msg = response.choices[0].message.content + memory_manager.add_message("assistant", msg) + + # Check for delegation + cmd = detect_delegation(msg) + if cmd and cmd in current_agent.commands: + target_name = current_agent.commands[cmd].target_agent + new_agent = agent_system.get_agent(target_name) + + _apply_agent_switch( + new_agent_prompt=new_agent.get_full_prompt(), + analysis_context="", + history=memory_manager._full_history, + memory_manager=memory_manager, + action_space=action_space, + new_agent=new_agent + ) + + current_agent = new_agent + + # Verify multiple switches occurred + assert client.call_count == 5 + # Should have switched between planner and coder + assert action_space.agent_name in ["planner", "coder"] + + +class TestRAGIntegration: + """Test message flow with RAG queries.""" + + def test_rag_query_detection(self, mock_llm_client): + """Test detecting and handling RAG queries.""" + agent = Agent( + name="researcher", + prompt="You are a researcher.", + commands={}, + code_samples={}, + is_rag_enabled=True + ) + + client = mock_llm_client(responses=[ + "Let me search for information.", + "I will query_rag_ to find details.", + "Based on the documentation, here's the answer." + ]) + + history = [ + {"role": "system", "content": agent.get_full_prompt()} + ] + + # Turn 1: Normal response + history.append({"role": "user", "content": "How does the API work?"}) + response1 = client.chat.completions.create(messages=history) + msg1 = response1.choices[0].message.content + history.append({"role": "assistant", "content": msg1}) + assert detect_rag(msg1) is None + + # Turn 2: RAG query + history.append({"role": "user", "content": "Continue"}) + response2 = client.chat.completions.create(messages=history) + msg2 = response2.choices[0].message.content + history.append({"role": "assistant", "content": msg2}) + + rag_query = detect_rag(msg2) + assert rag_query == "API documentation" + + # Simulate RAG retrieval (mock) + rag_result = "RAG retrieved: API uses REST endpoints..." + history.append({"role": "system", "content": f"RAG RESULT: {rag_result}"}) + + # Turn 3: Response with RAG context + history.append({"role": "user", "content": "Continue with RAG results"}) + response3 = client.chat.completions.create(messages=history) + msg3 = response3.choices[0].message.content + history.append({"role": "assistant", "content": msg3}) + + # Verify RAG result was added to history + assert any("RAG RESULT" in msg.get("content", "") for msg in history) + + +class TestArtifactExtraction: + """Test message flow with artifact extraction.""" + + def test_notes_and_todos_extraction(self, mock_llm_client): + """Test extracting notes and TODOs from responses.""" + client = mock_llm_client(responses=[ + """ + I'll start working on this task. 
+ + NOTE: Using Python 3.11 for this project + TODO: Install dependencies + """, + """ + Progress update: + + - [x] Dependencies installed + - [ ] Write main function + - [ ] Add tests + + NOTE: Found a good library for this + """ + ]) + + history = [{"role": "system", "content": "You are a coder."}] + + all_notes = [] + all_todos = [] + + # Turn 1 + history.append({"role": "user", "content": "Start the task"}) + response1 = client.chat.completions.create(messages=history) + msg1 = response1.choices[0].message.content + history.append({"role": "assistant", "content": msg1}) + + notes1, todos1 = _extract_artifacts_from_msg(msg1) + all_notes.extend(notes1) + all_todos.extend(todos1) + + # Turn 2 + history.append({"role": "user", "content": "Continue"}) + response2 = client.chat.completions.create(messages=history) + msg2 = response2.choices[0].message.content + history.append({"role": "assistant", "content": msg2}) + + notes2, todos2 = _extract_artifacts_from_msg(msg2) + all_notes.extend(notes2) + all_todos.extend(todos2) + + # Verify artifacts collected + assert len(all_notes) == 2 + assert any("Python 3.11" in note for note in all_notes) + assert any("good library" in note for note in all_notes) + + assert len(all_todos) == 4 + assert any("Install dependencies" in todo for todo in all_todos) + assert any("Write main function" in todo for todo in all_todos) + + +class TestCompleteIntegrationFlow: + """Test complete integration scenarios.""" + + def test_full_workflow_with_all_features(self, mock_llm_client): + """Test complete workflow with agents, memory, RAG, and artifacts.""" + # Create comprehensive agent system + planner = Agent( + name="planner", + prompt="You are a planner.", + commands={ + "delegate_to_researcher": Command("delegate_to_researcher", "researcher", "Research") + }, + code_samples={} + ) + + researcher = Agent( + name="researcher", + prompt="You are a researcher.", + commands={ + "delegate_to_planner": Command("delegate_to_planner", "planner", "Back to planning") + }, + code_samples={}, + is_rag_enabled=True + ) + + agent_system = AgentSystem( + global_policy="Be thorough and accurate", + agents={"planner": planner, "researcher": researcher} + ) + + client = mock_llm_client(responses=[ + "I will plan the research task.\nNOTE: Starting with literature review", + "delegate_to_researcher to find information", + "I will query_rag_ for context", + "Based on research: ML uses algorithms.\nTODO: Compile findings\n- [ ] Write report", + "delegate_to_planner to finalize", + "Final plan complete.\n- [x] Research done\n- [x] Report written" + ]) + + initial_history = [ + {"role": "system", "content": agent_system.global_policy}, + {"role": "system", "content": planner.get_full_prompt()} + ] + + memory_manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history, + working_history_size=4 + ) + + current_agent = planner + action_space = AgentActionSpace(agent_name=current_agent.name) + all_notes = [] + all_todos = [] + + # Simulate complete workflow + for turn in range(6): + memory_manager.add_message("user", f"Continue with task (turn {turn})") + context = memory_manager.get_context() + response = client.chat.completions.create(messages=context) + msg = response.choices[0].message.content + memory_manager.add_message("assistant", msg) + + # Extract artifacts + notes, todos = _extract_artifacts_from_msg(msg) + all_notes.extend(notes) + all_todos.extend(todos) + + # Check for delegation + cmd = detect_delegation(msg) + if cmd and cmd in 
current_agent.commands: + target_name = current_agent.commands[cmd].target_agent + new_agent = agent_system.get_agent(target_name) + + _apply_agent_switch( + new_agent_prompt=new_agent.get_full_prompt(), + analysis_context="", + history=memory_manager._full_history, + memory_manager=memory_manager, + action_space=action_space, + new_agent=new_agent + ) + + current_agent = new_agent + + # Check for RAG + rag_query = detect_rag(msg) + if rag_query: + rag_result = f"RAG results for: {rag_query}" + memory_manager.add_message("system", f"RAG RESULT: {rag_result}") + + # Verify complete workflow + assert client.call_count == 6 + assert len(all_notes) >= 1 # Should have collected notes + assert len(all_todos) >= 1 # Should have collected TODOs + assert len(memory_manager._full_history) > 10 # Should have substantial history + + def test_error_resilience_in_flow(self, mock_llm_client): + """Test that flow continues gracefully when components have issues.""" + client = mock_llm_client(responses=[ + "Normal response", + "Invalid delegation: delegate_to_nonexistent", + "Recovery response" + ]) + + agent = Agent( + name="agent", + prompt="Prompt", + commands={ + "delegate_to_other": Command("delegate_to_other", "other", "Delegate") + }, + code_samples={} + ) + + agent_system = AgentSystem( + global_policy="Global", + agents={"agent": agent, "other": agent} + ) + + history = [ + {"role": "system", "content": agent.get_full_prompt()} + ] + + current_agent = agent + + # Turn 1: Normal + history.append({"role": "user", "content": "Start"}) + response1 = client.chat.completions.create(messages=history) + msg1 = response1.choices[0].message.content + history.append({"role": "assistant", "content": msg1}) + + # Turn 2: Invalid delegation (should not crash) + history.append({"role": "user", "content": "Continue"}) + response2 = client.chat.completions.create(messages=history) + msg2 = response2.choices[0].message.content + history.append({"role": "assistant", "content": msg2}) + + cmd = detect_delegation(msg2) + assert cmd == "delegate_to_nonexistent" + + # Try to switch (should fail gracefully) + if cmd and cmd in current_agent.commands: + # This won't execute because command doesn't exist + pass + else: + # Flow continues without switching + pass + + # Turn 3: Recovery + history.append({"role": "user", "content": "Recover"}) + response3 = client.chat.completions.create(messages=history) + msg3 = response3.choices[0].message.content + history.append({"role": "assistant", "content": msg3}) + + # Should complete without crashing + assert client.call_count == 3 + assert current_agent.name == "agent" # No switch occurred + + +class TestMemoryAndContextManagement: + """Test memory management during long conversations.""" + + def test_long_conversation_with_summarization(self, mock_llm_client): + """Test that long conversations trigger summarization.""" + # Create many responses + responses = [f"Response {i}" for i in range(30)] + responses.append("This is a summary of previous messages.") # For summarization call + responses.extend([f"Response {i}" for i in range(30, 35)]) + + client = mock_llm_client(responses=responses) + + initial_history = [{"role": "system", "content": "System"}] + + memory_manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history, + working_history_size=3, + summarization_threshold=5, + chunk_size_to_summarize=5 + ) + + # Add many messages + for i in range(30): + memory_manager.add_message("user", f"Question {i}") + context = 
memory_manager.get_context() + response = client.chat.completions.create(messages=context) + memory_manager.add_message("assistant", response.choices[0].message.content) + + # Should have triggered summarization + assert len(memory_manager._summarized_log) >= 1 + + # Context should be manageable size + context = memory_manager.get_context() + # Should be much smaller than full history + assert len(context) < len(memory_manager._full_history) diff --git a/caribou/tests/pytest.ini b/caribou/tests/pytest.ini new file mode 100644 index 0000000..6a6bb1b --- /dev/null +++ b/caribou/tests/pytest.ini @@ -0,0 +1,40 @@ +[pytest] +# Pytest configuration for CARIBOU test suite + +# Test discovery patterns +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Output options +addopts = + --strict-markers + --tb=short + --disable-warnings + -ra + +# Test paths +testpaths = . + +# Markers for categorizing tests +markers = + unit: Unit tests for individual components + integration: Integration tests for component interactions + slow: Tests that take a long time to run + api: Tests that interact with external APIs (should be mocked) + +# Coverage options (when using --cov) +[coverage:run] +source = caribou +omit = + */tests/* + */test_*.py + */__pycache__/* + +[coverage:report] +precision = 2 +show_missing = True +skip_covered = False + +[coverage:html] +directory = htmlcov diff --git a/caribou/tests/run_tests.sh b/caribou/tests/run_tests.sh new file mode 100755 index 0000000..ab48e1a --- /dev/null +++ b/caribou/tests/run_tests.sh @@ -0,0 +1,119 @@ +#!/bin/bash + +# Test runner script for CARIBOU test suite +# Usage: ./run_tests.sh [options] + +set -e # Exit on error + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +echo "======================================" +echo "CARIBOU Test Suite Runner" +echo "======================================" +echo "" + +# Get the directory where this script is located +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +PROJECT_ROOT="$(dirname "$(dirname "$SCRIPT_DIR")")" + +cd "$PROJECT_ROOT" + +# Parse arguments +TEST_TYPE="all" +VERBOSE="" +COVERAGE="" + +while [[ $# -gt 0 ]]; do + case $1 in + --unit) + TEST_TYPE="unit" + shift + ;; + --integration) + TEST_TYPE="integration" + shift + ;; + -v|--verbose) + VERBOSE="-v" + shift + ;; + --coverage) + COVERAGE="--cov=caribou --cov-report=html --cov-report=term" + shift + ;; + -h|--help) + echo "Usage: $0 [options]" + echo "" + echo "Options:" + echo " --unit Run only unit tests" + echo " --integration Run only integration tests" + echo " -v, --verbose Verbose output" + echo " --coverage Generate coverage report" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo -e "${RED}Unknown option: $1${NC}" + exit 1 + ;; + esac +done + +# Detect the correct Python and pytest to use +# Priority: python in current environment > python3 > python +if command -v python &> /dev/null; then + PYTHON_CMD="python" +elif command -v python3 &> /dev/null; then + PYTHON_CMD="python3" +else + echo -e "${RED}Error: No Python interpreter found${NC}" + exit 1 +fi + +# Check if pytest is installed for the current Python +if ! 
$PYTHON_CMD -m pytest --version &> /dev/null; then + echo -e "${RED}Error: pytest is not installed for $PYTHON_CMD${NC}" + echo "Install with: $PYTHON_CMD -m pip install pytest pytest-cov" + exit 1 +fi + +echo "Using Python: $($PYTHON_CMD --version)" +echo "Using pytest: $($PYTHON_CMD -m pytest --version | head -1)" +echo "" + +# Run tests based on TEST_TYPE +case $TEST_TYPE in + unit) + echo -e "${YELLOW}Running unit tests...${NC}" + echo "" + $PYTHON_CMD -m pytest caribou/tests/unit/ $VERBOSE $COVERAGE + ;; + integration) + echo -e "${YELLOW}Running integration tests...${NC}" + echo "" + $PYTHON_CMD -m pytest caribou/tests/integration/ $VERBOSE $COVERAGE + ;; + all) + echo -e "${YELLOW}Running all tests...${NC}" + echo "" + $PYTHON_CMD -m pytest caribou/tests/ $VERBOSE $COVERAGE + ;; +esac + +# Check exit code +if [ $? -eq 0 ]; then + echo "" + echo -e "${GREEN}======================================" + echo "All tests passed! ✓" + echo -e "======================================${NC}" +else + echo "" + echo -e "${RED}======================================" + echo "Some tests failed! ✗" + echo -e "======================================${NC}" + exit 1 +fi diff --git a/caribou/tests/unit/__init__.py b/caribou/tests/unit/__init__.py new file mode 100644 index 0000000..e152f5f --- /dev/null +++ b/caribou/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for CARIBOU components.""" diff --git a/caribou/tests/unit/test_agent_system.py b/caribou/tests/unit/test_agent_system.py new file mode 100644 index 0000000..8743b3e --- /dev/null +++ b/caribou/tests/unit/test_agent_system.py @@ -0,0 +1,635 @@ +""" +Unit tests for Agent system and agent management. + +Tests agent configuration, command handling, prompt generation, +and agent switching logic. +""" +import pytest +import json +import tempfile +from pathlib import Path +from typing import Dict +from unittest.mock import Mock, patch + +from caribou.agents.AgentSystem import Agent, AgentSystem, Command +from caribou.execution.agent_management import _extract_possible_actions, _apply_agent_switch +from caribou.execution.ActionSpace import AgentActionSpace + + +class TestCommand: + """Test Command class.""" + + def test_command_creation(self): + """Test creating a command.""" + cmd = Command( + name="delegate_to_coder", + target_agent="coder", + description="Send task to coder" + ) + + assert cmd.name == "delegate_to_coder" + assert cmd.target_agent == "coder" + assert cmd.description == "Send task to coder" + + def test_command_repr(self): + """Test command string representation.""" + cmd = Command( + name="test_cmd", + target_agent="test_agent", + description="A" * 50 # Long description + ) + + repr_str = repr(cmd) + assert "test_cmd" in repr_str + assert "test_agent" in repr_str + assert "..." in repr_str # Truncated description + + +class TestAgent: + """Test Agent class.""" + + def test_agent_creation(self): + """Test creating an agent.""" + commands = { + "delegate_to_other": Command("delegate_to_other", "other", "Delegate to other") + } + samples = {"sample1.py": "print('hello')"} + + agent = Agent( + name="test_agent", + prompt="You are a test agent.", + commands=commands, + code_samples=samples, + is_rag_enabled=True + ) + + assert agent.name == "test_agent" + assert agent.prompt == "You are a test agent." 
+ assert len(agent.commands) == 1 + assert len(agent.code_samples) == 1 + assert agent.is_rag_enabled is True + + def test_agent_repr(self): + """Test agent string representation.""" + commands = {"cmd1": Command("cmd1", "target", "desc")} + samples = {"sample.py": "code"} + + agent = Agent( + name="agent1", + prompt="Prompt", + commands=commands, + code_samples=samples + ) + + repr_str = repr(agent) + assert "agent1" in repr_str + assert "cmd1" in repr_str + assert "sample.py" in repr_str + + def test_get_full_prompt_basic(self): + """Test basic prompt generation.""" + agent = Agent( + name="basic", + prompt="Basic prompt", + commands={}, + code_samples={} + ) + + prompt = agent.get_full_prompt() + assert "Basic prompt" in prompt + assert "NOTE:" in prompt # Notes instruction + assert "TODO:" in prompt # TODO instruction + + def test_get_full_prompt_with_global_policy(self): + """Test prompt with global policy.""" + agent = Agent( + name="agent", + prompt="Agent prompt", + commands={}, + code_samples={} + ) + + prompt = agent.get_full_prompt(global_policy="Be helpful") + assert "**GLOBAL POLICY**: Be helpful" in prompt + assert "Agent prompt" in prompt + + def test_get_full_prompt_with_commands(self): + """Test prompt with commands.""" + commands = { + "delegate_to_coder": Command( + "delegate_to_coder", + "coder", + "Send coding tasks" + ), + "delegate_to_planner": Command( + "delegate_to_planner", + "planner", + "Send planning tasks" + ) + } + + agent = Agent( + name="agent", + prompt="Agent prompt", + commands=commands, + code_samples={} + ) + + prompt = agent.get_full_prompt() + assert "delegate_to_coder" in prompt + assert "delegate_to_planner" in prompt + assert "Send coding tasks" in prompt + assert "coder" in prompt + assert "YOU MUST USE THESE EXACT COMMANDS" in prompt + + def test_get_full_prompt_with_rag(self): + """Test prompt with RAG enabled.""" + agent = Agent( + name="agent", + prompt="Agent prompt", + commands={}, + code_samples={}, + is_rag_enabled=True + ) + + prompt = agent.get_full_prompt() + assert "query_rag_" in prompt + assert "knowledge base" in prompt + assert "query_rag_" in prompt # Example + + def test_get_full_prompt_with_code_samples(self): + """Test prompt with code samples.""" + samples = { + "example1.py": "code1", + "example2.py": "code2" + } + + agent = Agent( + name="agent", + prompt="Agent prompt", + commands={}, + code_samples=samples + ) + + prompt = agent.get_full_prompt() + assert "Code Samples Available" in prompt + assert "example1.py" in prompt + assert "example2.py" in prompt + assert "MUST BE REWRITTEN TO BE USED" in prompt + + def test_get_full_prompt_complete(self): + """Test prompt with all features.""" + commands = { + "delegate_to_other": Command("delegate_to_other", "other", "Delegate") + } + samples = {"sample.py": "code"} + + agent = Agent( + name="complete", + prompt="Complete agent", + commands=commands, + code_samples=samples, + is_rag_enabled=True + ) + + prompt = agent.get_full_prompt(global_policy="Global") + + # Check all sections present + assert "**GLOBAL POLICY**: Global" in prompt + assert "Complete agent" in prompt + assert "delegate_to_other" in prompt + assert "query_rag_" in prompt + assert "sample.py" in prompt + assert "NOTE:" in prompt + assert "TODO:" in prompt + + +class TestAgentSystem: + """Test AgentSystem class.""" + + def test_agent_system_creation(self): + """Test creating an agent system.""" + agents = { + "agent1": Agent("agent1", "Prompt 1", {}, {}), + "agent2": Agent("agent2", "Prompt 2", {}, {}) + } + + 
system = AgentSystem(global_policy="Be helpful", agents=agents) + + assert system.global_policy == "Be helpful" + assert len(system.agents) == 2 + assert "agent1" in system.agents + assert "agent2" in system.agents + + def test_get_agent(self): + """Test retrieving agent by name.""" + agent1 = Agent("agent1", "Prompt", {}, {}) + agents = {"agent1": agent1} + + system = AgentSystem(global_policy="Policy", agents=agents) + + retrieved = system.get_agent("agent1") + assert retrieved is agent1 + + not_found = system.get_agent("nonexistent") + assert not_found is None + + def test_get_all_agents(self): + """Test getting all agents.""" + agents = { + "agent1": Agent("agent1", "P1", {}, {}), + "agent2": Agent("agent2", "P2", {}, {}) + } + + system = AgentSystem(global_policy="Policy", agents=agents) + + all_agents = system.get_all_agents() + assert len(all_agents) == 2 + assert "agent1" in all_agents + assert "agent2" in all_agents + + def test_get_instructions(self): + """Test generating system instructions.""" + commands = { + "delegate_to_agent2": Command("delegate_to_agent2", "agent2", "Delegate") + } + agents = { + "agent1": Agent("agent1", "Agent 1 prompt", commands, {}) + } + + system = AgentSystem(global_policy="Be accurate", agents=agents) + + instructions = system.get_instructions() + assert "Be accurate" in instructions + assert "agent1" in instructions + assert "Agent 1 prompt" in instructions + assert "delegate_to_agent2" in instructions + + def test_agent_system_repr(self): + """Test agent system string representation.""" + agents = { + "agent1": Agent("agent1", "P1", {}, {}), + "agent2": Agent("agent2", "P2", {}, {}) + } + + system = AgentSystem(global_policy="A" * 50, agents=agents) + + repr_str = repr(system) + assert "AgentSystem" in repr_str + assert "agent1" in repr_str + assert "agent2" in repr_str + assert "..." in repr_str # Truncated policy + + +class TestAgentSystemLoadFromJSON: + """Test loading agent system from JSON.""" + + def test_load_basic_config(self): + """Test loading basic configuration.""" + config = { + "global_policy": "Be helpful and accurate.", + "agents": { + "planner": { + "prompt": "You are a planning agent.", + "neighbors": {}, + "code_samples": [], + "rag": {"enabled": False} + } + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config, f) + temp_path = f.name + + try: + system = AgentSystem.load_from_json(temp_path) + + assert system.global_policy == "Be helpful and accurate." + assert len(system.agents) == 1 + assert "planner" in system.agents + + planner = system.get_agent("planner") + assert planner.name == "planner" + assert planner.prompt == "You are a planning agent." 
+ assert planner.is_rag_enabled is False + finally: + Path(temp_path).unlink() + + def test_load_with_commands(self): + """Test loading configuration with commands.""" + config = { + "global_policy": "Global", + "agents": { + "planner": { + "prompt": "Planner prompt", + "neighbors": { + "delegate_to_coder": { + "target_agent": "coder", + "description": "Send to coder" + } + }, + "code_samples": [], + "rag": {"enabled": False} + }, + "coder": { + "prompt": "Coder prompt", + "neighbors": {}, + "code_samples": [], + "rag": {"enabled": False} + } + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config, f) + temp_path = f.name + + try: + system = AgentSystem.load_from_json(temp_path) + + planner = system.get_agent("planner") + assert len(planner.commands) == 1 + assert "delegate_to_coder" in planner.commands + + cmd = planner.commands["delegate_to_coder"] + assert cmd.target_agent == "coder" + assert cmd.description == "Send to coder" + finally: + Path(temp_path).unlink() + + def test_load_with_rag_enabled(self): + """Test loading configuration with RAG enabled.""" + config = { + "global_policy": "Global", + "agents": { + "researcher": { + "prompt": "Research prompt", + "neighbors": {}, + "code_samples": [], + "rag": {"enabled": True} + } + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config, f) + temp_path = f.name + + try: + system = AgentSystem.load_from_json(temp_path) + + researcher = system.get_agent("researcher") + assert researcher.is_rag_enabled is True + finally: + Path(temp_path).unlink() + + @patch('caribou.agents.AgentSystem.USER_CODE_SAMPLES_DIR') + @patch('caribou.agents.AgentSystem.PACKAGE_CODE_SAMPLES_DIR') + def test_load_with_code_samples(self, mock_package_dir, mock_user_dir): + """Test loading configuration with code samples.""" + # Create temporary directories for code samples + with tempfile.TemporaryDirectory() as temp_user_dir, \ + tempfile.TemporaryDirectory() as temp_package_dir: + + mock_user_dir.__truediv__ = lambda self, x: Path(temp_user_dir) / x + mock_package_dir.__truediv__ = lambda self, x: Path(temp_package_dir) / x + + # Create a sample file + sample_path = Path(temp_package_dir) / "sample.py" + sample_path.write_text("print('hello')") + + config = { + "global_policy": "Global", + "agents": { + "coder": { + "prompt": "Coder prompt", + "neighbors": {}, + "code_samples": ["sample.py"], + "rag": {"enabled": False} + } + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config, f) + temp_config_path = f.name + + try: + # Mock the path resolution + with patch.object(Path, 'exists', return_value=True), \ + patch.object(Path, 'read_text', return_value="print('hello')"): + system = AgentSystem.load_from_json(temp_config_path) + + coder = system.get_agent("coder") + assert len(coder.code_samples) == 1 + assert "sample.py" in coder.code_samples + finally: + Path(temp_config_path).unlink() + + +class TestExtractPossibleActions: + """Test extracting possible actions from agents.""" + + def test_extract_from_basic_agent(self): + """Test extracting actions from agent with no special features.""" + agent = Agent(name="basic", prompt="Prompt", commands={}, code_samples={}) + + actions = _extract_possible_actions(agent) + + # Should always have "continue" action + assert len(actions) >= 1 + assert any(action["name"] == "continue" for action in actions) + + def test_extract_with_commands(self): + """Test extracting 
actions including commands.""" + commands = { + "delegate_to_coder": Command("delegate_to_coder", "coder", "Send to coder"), + "delegate_to_planner": Command("delegate_to_planner", "planner", "Send to planner") + } + agent = Agent(name="agent", prompt="Prompt", commands=commands, code_samples={}) + + actions = _extract_possible_actions(agent) + + assert len(actions) >= 3 # continue + 2 commands + assert any(action["name"] == "delegate_to_coder" for action in actions) + assert any(action["name"] == "delegate_to_planner" for action in actions) + + def test_extract_with_rag(self): + """Test extracting actions with RAG enabled.""" + agent = Agent( + name="agent", + prompt="Prompt", + commands={}, + code_samples={}, + is_rag_enabled=True + ) + + actions = _extract_possible_actions(agent) + + # Should include RAG action + rag_actions = [a for a in actions if "query_rag" in a["name"]] + assert len(rag_actions) == 1 + assert "knowledge base" in rag_actions[0]["detail"] + + def test_extract_all_features(self): + """Test extracting actions with all features.""" + commands = { + "delegate_to_other": Command("delegate_to_other", "other", "Delegate") + } + agent = Agent( + name="agent", + prompt="Prompt", + commands=commands, + code_samples={}, + is_rag_enabled=True + ) + + actions = _extract_possible_actions(agent) + + # Should have: continue + rag + command + assert len(actions) >= 3 + action_names = [a["name"] for a in actions] + assert "continue" in action_names + assert "delegate_to_other" in action_names + assert any("query_rag" in name for name in action_names) + + +class TestApplyAgentSwitch: + """Test applying agent switches.""" + + def test_apply_agent_switch_updates_history(self, mock_llm_client): + """Test that agent switch updates history.""" + new_agent = Agent( + name="new_agent", + prompt="New agent prompt", + commands={}, + code_samples={} + ) + + history = [ + {"role": "system", "content": "Global policy"}, + {"role": "system", "content": "Old agent prompt"}, + {"role": "user", "content": "Hello"} + ] + + _apply_agent_switch( + new_agent_prompt="New agent prompt", + analysis_context="", + history=history, + memory_manager=None, + action_space=None, + new_agent=new_agent + ) + + # Second message should be updated + assert history[1]["content"] == "New agent prompt\n\n" + + def test_apply_agent_switch_with_memory_manager(self, mock_llm_client): + """Test agent switch with memory manager.""" + client = mock_llm_client() + memory_manager = Mock() + + new_agent = Agent( + name="new_agent", + prompt="New prompt", + commands={}, + code_samples={} + ) + + history = [ + {"role": "system", "content": "Global"}, + {"role": "system", "content": "Old prompt"} + ] + + _apply_agent_switch( + new_agent_prompt="New prompt", + analysis_context="", + history=history, + memory_manager=memory_manager, + action_space=None, + new_agent=new_agent + ) + + # Should update memory manager + memory_manager.update_system_prompt.assert_called_once() + memory_manager.add_message.assert_called() + + # Should add reminder to history + assert any("REMINDER" in msg.get("content", "") for msg in history) + + def test_apply_agent_switch_with_action_space(self): + """Test agent switch with action space.""" + action_space = AgentActionSpace(agent_name="old_agent") + + new_agent = Agent( + name="new_agent", + prompt="New prompt", + commands={ + "delegate_to_other": Command("delegate_to_other", "other", "Delegate") + }, + code_samples={} + ) + + history = [ + {"role": "system", "content": "Global"}, + {"role": "system", 
"content": "Old prompt"} + ] + + _apply_agent_switch( + new_agent_prompt="New prompt", + analysis_context="", + history=history, + memory_manager=None, + action_space=action_space, + new_agent=new_agent + ) + + # Should update action space + assert action_space.agent_name == "new_agent" + + # Should add action space message to history + action_space_msgs = [msg for msg in history if "ACTION SPACE UPDATE" in msg.get("content", "")] + assert len(action_space_msgs) >= 1 + + def test_apply_agent_switch_with_analysis_context(self): + """Test agent switch includes analysis context.""" + new_agent = Agent(name="agent", prompt="Prompt", commands={}, code_samples={}) + + history = [ + {"role": "system", "content": "Global"}, + {"role": "system", "content": "Old"} + ] + + analysis_context = "\nAdditional context about the task." + + _apply_agent_switch( + new_agent_prompt="New prompt", + analysis_context=analysis_context, + history=history, + memory_manager=None, + action_space=None, + new_agent=new_agent + ) + + # Should include analysis context in updated prompt + assert "Additional context about the task." in history[1]["content"] + + def test_apply_agent_switch_inserts_if_needed(self): + """Test agent switch inserts prompt if history is short.""" + new_agent = Agent(name="agent", prompt="Prompt", commands={}, code_samples={}) + + history = [{"role": "system", "content": "Global"}] + + _apply_agent_switch( + new_agent_prompt="New prompt", + analysis_context="", + history=history, + memory_manager=None, + action_space=None, + new_agent=new_agent + ) + + # Should insert at index 1 + assert len(history) == 2 + assert history[1]["content"] == "New prompt\n\n" diff --git a/caribou/tests/unit/test_anthropic_wrapper.py b/caribou/tests/unit/test_anthropic_wrapper.py new file mode 100644 index 0000000..0fa9fab --- /dev/null +++ b/caribou/tests/unit/test_anthropic_wrapper.py @@ -0,0 +1,365 @@ +""" +Unit tests for AnthropicClient wrapper. + +Tests the OpenAI API compatibility layer and message conversion. 
+""" +import pytest +from types import SimpleNamespace +from unittest.mock import Mock, patch, MagicMock + +from caribou.core.anthropic_wrapper import AnthropicClient + + +class TestAnthropicClientInitialization: + """Test AnthropicClient initialization.""" + + def test_init_with_minimal_params(self): + """Test initialization with only required parameters.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + client = AnthropicClient(api_key="test-key") + + assert client._default_model == "claude-sonnet-4-5-20250929" + assert client._max_output_tokens == 1024 + mock_anthropic.assert_called_once_with(api_key="test-key") + + def test_init_with_custom_params(self): + """Test initialization with custom parameters.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + client = AnthropicClient( + api_key="test-key", + model="claude-opus-4", + max_output_tokens=2048, + base_url="https://custom.api.com" + ) + + assert client._default_model == "claude-opus-4" + assert client._max_output_tokens == 2048 + mock_anthropic.assert_called_once_with( + api_key="test-key", + base_url="https://custom.api.com" + ) + + def test_chat_completions_interface_exists(self): + """Test that the OpenAI-compatible interface exists.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic"): + client = AnthropicClient(api_key="test-key") + + assert hasattr(client, "chat") + assert hasattr(client.chat, "completions") + assert hasattr(client.chat.completions, "create") + assert callable(client.chat.completions.create) + + +class TestAnthropicClientMessageConversion: + """Test message format conversion from OpenAI to Anthropic.""" + + def test_system_message_extraction(self): + """Test that system messages are extracted and combined.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[SimpleNamespace(type="text", text="Response")], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [ + {"role": "system", "content": "Global policy"}, + {"role": "system", "content": "Agent prompt"}, + {"role": "user", "content": "Hello"}, + ] + + client.chat.completions.create(messages=messages) + + call_args = mock_instance.messages.create.call_args + assert call_args[1]["system"] == "Global policy\n\nAgent prompt" + assert len(call_args[1]["messages"]) == 1 + assert call_args[1]["messages"][0] == {"role": "user", "content": "Hello"} + + def test_system_message_absent(self): + """Test handling when no system messages present.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[SimpleNamespace(type="text", text="Response")], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [ + {"role": "user", "content": "Hello"}, + ] + + client.chat.completions.create(messages=messages) + + call_args = mock_instance.messages.create.call_args + assert call_args[1]["system"] is None + + def test_role_filtering(self): + """Test that only assistant and user roles are preserved.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + 
mock_response = Mock( + content=[SimpleNamespace(type="text", text="Response")], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [ + {"role": "user", "content": "Message 1"}, + {"role": "assistant", "content": "Message 2"}, + {"role": "function", "content": "Message 3"}, # Should be converted to user + {"role": "tool", "content": "Message 4"}, # Should be converted to user + ] + + client.chat.completions.create(messages=messages) + + call_args = mock_instance.messages.create.call_args + converted_messages = call_args[1]["messages"] + + assert len(converted_messages) == 4 + assert converted_messages[0]["role"] == "user" + assert converted_messages[1]["role"] == "assistant" + assert converted_messages[2]["role"] == "user" # function -> user + assert converted_messages[3]["role"] == "user" # tool -> user + + +class TestAnthropicClientAPICall: + """Test actual API call behavior.""" + + def test_api_call_with_default_params(self): + """Test API call with default parameters.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[SimpleNamespace(type="text", text="Test response")], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [{"role": "user", "content": "Hello"}] + client.chat.completions.create(messages=messages) + + mock_instance.messages.create.assert_called_once() + call_args = mock_instance.messages.create.call_args[1] + + assert call_args["model"] == "claude-sonnet-4-5-20250929" + assert call_args["max_tokens"] == 1024 + assert call_args["temperature"] is None + + def test_api_call_with_override_params(self): + """Test API call with parameter overrides.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[SimpleNamespace(type="text", text="Test response")], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key", model="default-model") + + messages = [{"role": "user", "content": "Hello"}] + client.chat.completions.create( + messages=messages, + model="override-model", + temperature=0.7, + max_output_tokens=512 + ) + + call_args = mock_instance.messages.create.call_args[1] + + assert call_args["model"] == "override-model" + assert call_args["max_tokens"] == 512 + assert call_args["temperature"] == 0.7 + + def test_extra_kwargs_ignored(self): + """Test that extra unknown kwargs are ignored.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[SimpleNamespace(type="text", text="Test response")], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [{"role": "user", "content": "Hello"}] + # Should not raise error for unknown kwargs + client.chat.completions.create( + messages=messages, + unknown_param="should_be_ignored", + another_unknown=123 + ) + + mock_instance.messages.create.assert_called_once() + + +class 
TestAnthropicClientResponseFormatting: + """Test response formatting to OpenAI-compatible structure.""" + + def test_response_structure(self): + """Test that response has correct OpenAI-compatible structure.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[SimpleNamespace(type="text", text="Test response")], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [{"role": "user", "content": "Hello"}] + response = client.chat.completions.create(messages=messages) + + # Check OpenAI-compatible structure + assert hasattr(response, "choices") + assert len(response.choices) == 1 + assert hasattr(response.choices[0], "message") + assert hasattr(response.choices[0].message, "content") + assert hasattr(response.choices[0].message, "role") + assert response.choices[0].message.content == "Test response" + assert response.choices[0].message.role == "assistant" + + def test_multiple_text_blocks(self): + """Test combining multiple text blocks from Anthropic response.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[ + SimpleNamespace(type="text", text="Part 1 "), + SimpleNamespace(type="text", text="Part 2 "), + SimpleNamespace(type="text", text="Part 3"), + ], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [{"role": "user", "content": "Hello"}] + response = client.chat.completions.create(messages=messages) + + assert response.choices[0].message.content == "Part 1 Part 2 Part 3" + + def test_non_text_blocks_ignored(self): + """Test that non-text blocks are filtered out.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[ + SimpleNamespace(type="text", text="Text content"), + SimpleNamespace(type="tool_use", id="tool_123"), # Should be ignored + SimpleNamespace(type="image", source="data:image"), # Should be ignored + ], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [{"role": "user", "content": "Hello"}] + response = client.chat.completions.create(messages=messages) + + assert response.choices[0].message.content == "Text content" + + def test_finish_reason_mapping(self): + """Test that stop_reason is mapped to finish_reason.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[SimpleNamespace(type="text", text="Response")], + stop_reason="max_tokens" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [{"role": "user", "content": "Hello"}] + response = client.chat.completions.create(messages=messages) + + assert response.choices[0].finish_reason == "max_tokens" + + +class TestAnthropicClientEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_messages_list(self): + """Test handling of empty messages list.""" + with 
patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[SimpleNamespace(type="text", text="Response")], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + # Should not raise error, just pass empty list + client.chat.completions.create(messages=[]) + + call_args = mock_instance.messages.create.call_args[1] + assert call_args["messages"] == [] + + def test_missing_content_key(self): + """Test handling of messages missing content key.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[SimpleNamespace(type="text", text="Response")], + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [ + {"role": "user"}, # Missing content + {"role": "assistant", "content": "Hello"} + ] + + client.chat.completions.create(messages=messages) + + call_args = mock_instance.messages.create.call_args[1] + assert call_args["messages"][0]["content"] == "" + + def test_empty_response_content(self): + """Test handling of empty response from Anthropic.""" + with patch("caribou.core.anthropic_wrapper.anthropic.Anthropic") as mock_anthropic: + mock_instance = Mock() + mock_response = Mock( + content=[], # Empty content + stop_reason="end_turn" + ) + mock_instance.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_instance + + client = AnthropicClient(api_key="test-key") + + messages = [{"role": "user", "content": "Hello"}] + response = client.chat.completions.create(messages=messages) + + assert response.choices[0].message.content == "" diff --git a/caribou/tests/unit/test_memory_manager.py b/caribou/tests/unit/test_memory_manager.py new file mode 100644 index 0000000..f1de81b --- /dev/null +++ b/caribou/tests/unit/test_memory_manager.py @@ -0,0 +1,533 @@ +""" +Unit tests for MemoryManager. + +Tests conversation history management, episodic summarization, +and context assembly. 
+""" +import pytest +from unittest.mock import Mock, MagicMock +from types import SimpleNamespace + +from caribou.execution.MemoryManager import MemoryManager + + +class TestMemoryManagerInitialization: + """Test MemoryManager initialization.""" + + def test_init_with_default_params(self, mock_llm_client): + """Test initialization with default parameters.""" + client = mock_llm_client() + initial_history = [ + {"role": "system", "content": "Global policy"}, + {"role": "system", "content": "Agent prompt"}, + ] + + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history + ) + + assert manager.config["working_history_size"] == 4 + assert manager.config["summarization_threshold"] == 20 + assert manager.config["chunk_size_to_summarize"] == 10 + assert len(manager._full_history) == 2 + assert len(manager._pinned_messages) == 2 + + def test_init_with_custom_params(self, mock_llm_client): + """Test initialization with custom parameters.""" + client = mock_llm_client() + initial_history = [{"role": "system", "content": "System"}] + + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history, + working_history_size=8, + summarization_threshold=30, + chunk_size_to_summarize=15 + ) + + assert manager.config["working_history_size"] == 8 + assert manager.config["summarization_threshold"] == 30 + assert manager.config["chunk_size_to_summarize"] == 15 + + def test_init_pins_early_messages(self, mock_llm_client): + """Test that early messages are pinned.""" + client = mock_llm_client() + initial_history = [ + {"role": "system", "content": "Message 1"}, + {"role": "system", "content": "Message 2"}, + {"role": "system", "content": "Message 3"}, + {"role": "user", "content": "Message 4"}, + ] + + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history + ) + + # Should pin first 3 messages + assert len(manager._pinned_messages) == 3 + assert manager._pinned_messages[0]["content"] == "Message 1" + assert manager._pinned_messages[1]["content"] == "Message 2" + assert manager._pinned_messages[2]["content"] == "Message 3" + + def test_init_with_short_history(self, mock_llm_client): + """Test initialization with history shorter than pin count.""" + client = mock_llm_client() + initial_history = [{"role": "system", "content": "Only one"}] + + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history + ) + + # Should only pin what exists + assert len(manager._pinned_messages) == 1 + + +class TestMemoryManagerMessageOperations: + """Test adding and updating messages.""" + + def test_add_message(self, mock_llm_client): + """Test adding a new message.""" + client = mock_llm_client() + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "Init"}] + ) + + manager.add_message("user", "Hello") + manager.add_message("assistant", "Hi there") + + assert len(manager._full_history) == 3 + assert manager._full_history[-2] == {"role": "user", "content": "Hello"} + assert manager._full_history[-1] == {"role": "assistant", "content": "Hi there"} + + def test_add_pivotal_code(self, mock_llm_client): + """Test adding pivotal code.""" + client = mock_llm_client() + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "Init"}] + ) + + code = "def foo():\n return 42" + manager.add_pivotal_code(code) + + assert 
len(manager._pivotal_code) == 1 + pivotal_msg = manager._pivotal_code[0] + assert pivotal_msg["role"] == "system" + assert "PIVOTAL CODE" in pivotal_msg["content"] + assert code in pivotal_msg["content"] + assert "```python" in pivotal_msg["content"] + + def test_update_system_prompt(self, mock_llm_client): + """Test updating system prompt.""" + client = mock_llm_client() + initial_history = [ + {"role": "system", "content": "Global policy"}, + {"role": "system", "content": "Old agent prompt"}, + ] + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history + ) + + new_prompt = "New agent prompt" + manager.update_system_prompt(new_prompt) + + # Should update second pinned message + assert manager._pinned_messages[1]["content"] == new_prompt + # Should also update in full history + assert manager._full_history[1]["content"] == new_prompt + + def test_update_system_prompt_with_short_history(self, mock_llm_client): + """Test updating system prompt when history is too short.""" + client = mock_llm_client() + initial_history = [{"role": "system", "content": "Only one"}] + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history + ) + + new_prompt = "New prompt" + manager.update_system_prompt(new_prompt) + + # Should insert at index 1 + assert len(manager._pinned_messages) == 2 + assert manager._pinned_messages[1]["content"] == new_prompt + + +class TestMemoryManagerContextAssembly: + """Test context assembly without summarization.""" + + def test_get_context_basic(self, mock_llm_client): + """Test basic context retrieval.""" + client = mock_llm_client() + initial_history = [ + {"role": "system", "content": "System"}, + {"role": "user", "content": "Hello"}, + ] + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history, + working_history_size=4 + ) + + context = manager.get_context() + + # Should return pinned + recent working history + assert len(context) >= 2 + assert context[0]["content"] == "System" + + def test_get_context_with_working_history(self, mock_llm_client): + """Test context includes working history.""" + client = mock_llm_client() + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + working_history_size=3 + ) + + # Add several messages + for i in range(10): + manager.add_message("user", f"Message {i}") + + context = manager.get_context() + + # Should include pinned + last 3 messages + # Last 3 should be messages 7, 8, 9 + working_history_part = [msg for msg in context if "Message" in msg.get("content", "")] + assert len(working_history_part) == 3 + assert "Message 8" in working_history_part[-2]["content"] + assert "Message 9" in working_history_part[-1]["content"] + + def test_get_context_with_pivotal_code(self, mock_llm_client): + """Test context includes pivotal code.""" + client = mock_llm_client() + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}] + ) + + manager.add_pivotal_code("x = 1") + manager.add_pivotal_code("y = 2") + + context = manager.get_context() + + # Count system messages with PIVOTAL CODE + pivotal_count = sum(1 for msg in context if "PIVOTAL CODE" in msg.get("content", "")) + assert pivotal_count == 2 + + +class TestMemoryManagerSummarization: + """Test episodic summarization logic.""" + + def test_should_summarize_false_initially(self, mock_llm_client): + """Test that 
summarization is not triggered initially.""" + client = mock_llm_client() + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + summarization_threshold=20, + working_history_size=4 + ) + + # Add a few messages (not enough to trigger) + for i in range(10): + manager.add_message("user", f"Message {i}") + + assert not manager._should_summarize() + + def test_should_summarize_true_when_threshold_exceeded(self, mock_llm_client): + """Test that summarization is triggered when threshold exceeded.""" + client = mock_llm_client() + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + summarization_threshold=5, + working_history_size=4 + ) + + # Add enough messages to exceed threshold + # Pinned: 1, Working tail: 4, Threshold: 5 + # Need at least 1 (pinned) + 5 (threshold) + 4 (working) = 10 messages + for i in range(15): + manager.add_message("user", f"Message {i}") + + assert manager._should_summarize() + + def test_summarize_chunk_creates_summary(self, mock_llm_client): + """Test that summarization creates a summary message.""" + client = mock_llm_client(responses=["This is a summary of the conversation."]) + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + chunk_size_to_summarize=5, + summarization_threshold=5, + working_history_size=2 + ) + + # Add enough messages to allow summarization + for i in range(12): + manager.add_message("user", f"Message {i}") + + # Manually trigger summarization + manager._summarize_chunk() + + # Should have created one summary + assert len(manager._summarized_log) == 1 + assert "EPISODIC SUMMARY" in manager._summarized_log[0]["content"] + assert "This is a summary" in manager._summarized_log[0]["content"] + + def test_summarize_chunk_calls_llm(self, mock_llm_client): + """Test that summarization calls the LLM.""" + client = mock_llm_client(responses=["Summary"]) + manager = MemoryManager( + llm_client=client, + model_name="test-model", + initial_history=[{"role": "system", "content": "System"}], + chunk_size_to_summarize=3, + working_history_size=2 + ) + + for i in range(10): + manager.add_message("user", f"Message {i}") + + manager._summarize_chunk() + + # Check that LLM was called + assert client.call_count >= 1 + call_kwargs = client.calls[0] + assert call_kwargs["model"] == "test-model" + assert call_kwargs["temperature"] == 0.0 + + def test_get_context_triggers_summarization(self, mock_llm_client): + """Test that get_context triggers summarization when needed.""" + client = mock_llm_client(responses=["Summary"]) + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + summarization_threshold=5, + chunk_size_to_summarize=5, + working_history_size=2 + ) + + # Add enough messages to trigger summarization + for i in range(15): + manager.add_message("user", f"Message {i}") + + context = manager.get_context() + + # Should have triggered summarization + assert len(manager._summarized_log) >= 1 + + # Summary should be in context + summary_in_context = any("EPISODIC SUMMARY" in msg.get("content", "") for msg in context) + assert summary_in_context + + def test_summarization_preserves_working_history(self, mock_llm_client): + """Test that working history is not summarized.""" + client = mock_llm_client(responses=["Summary"]) + manager = MemoryManager( + 
llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + chunk_size_to_summarize=5, + working_history_size=3 + ) + + for i in range(15): + manager.add_message("user", f"Message {i}") + + # Force summarization + while manager._should_summarize(): + manager._summarize_chunk() + + # Full history should still contain all messages + assert len(manager._full_history) == 16 # 1 initial + 15 added + + # Last 3 messages should be preserved in context + context = manager.get_context() + working_msgs = [msg for msg in context if "Message" in msg.get("content", "")] + assert len(working_msgs) >= 3 + assert "Message 14" in working_msgs[-1]["content"] + + def test_multiple_summarization_rounds(self, mock_llm_client): + """Test multiple rounds of summarization.""" + client = mock_llm_client(responses=["Summary 1", "Summary 2", "Summary 3"]) + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + summarization_threshold=3, + chunk_size_to_summarize=5, + working_history_size=2 + ) + + # Add many messages to trigger multiple summarizations + for i in range(30): + manager.add_message("user", f"Message {i}") + # Trigger summarization periodically + if manager._should_summarize(): + manager._summarize_chunk() + + # Should have multiple summaries + assert len(manager._summarized_log) >= 2 + + +class TestMemoryManagerErrorHandling: + """Test error handling in summarization.""" + + def test_summarization_error_does_not_crash(self, mock_llm_client, capsys): + """Test that summarization errors are handled gracefully.""" + client = mock_llm_client() + # Make the LLM call raise an error + client.chat.completions.create = Mock(side_effect=Exception("API Error")) + + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + chunk_size_to_summarize=3, + working_history_size=2 + ) + + for i in range(10): + manager.add_message("user", f"Message {i}") + + # Should not raise exception + try: + manager._summarize_chunk() + except Exception: + pytest.fail("Summarization error should be caught") + + # Should print warning + captured = capsys.readouterr() + assert "Warning" in captured.out or "Could not summarize" in captured.out + + +class TestMemoryManagerContextLayout: + """Test the context layout structure.""" + + def test_context_layout_order(self, mock_llm_client): + """Test that context has correct order: pinned + pivotal + summaries + working.""" + client = mock_llm_client(responses=["Summary"]) + initial_history = [ + {"role": "system", "content": "Global"}, + {"role": "system", "content": "Agent"}, + ] + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=initial_history, + chunk_size_to_summarize=3, + summarization_threshold=3, + working_history_size=2 + ) + + # Add pivotal code + manager.add_pivotal_code("important_code()") + + # Add messages to trigger summarization + for i in range(10): + manager.add_message("user", f"Message {i}") + + context = manager.get_context() + + # Find indices of different sections + pinned_idx = None + pivotal_idx = None + summary_idx = None + working_idx = None + + for i, msg in enumerate(context): + content = msg.get("content", "") + if "Global" in content: + pinned_idx = i + elif "PIVOTAL CODE" in content: + pivotal_idx = i + elif "EPISODIC SUMMARY" in content: + summary_idx = i + elif "Message" in content: + if working_idx is None: + working_idx = i + + # 
Verify order (indices should increase) + assert pinned_idx is not None + if pivotal_idx is not None: + assert pivotal_idx > pinned_idx + if summary_idx is not None: + assert summary_idx > pivotal_idx if pivotal_idx else summary_idx > pinned_idx + + def test_empty_sections_not_included(self, mock_llm_client): + """Test that empty sections are not included in context.""" + client = mock_llm_client() + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + working_history_size=2 + ) + + manager.add_message("user", "Hello") + + context = manager.get_context() + + # Should not have pivotal code or summaries + pivotal_count = sum(1 for msg in context if "PIVOTAL CODE" in msg.get("content", "")) + summary_count = sum(1 for msg in context if "EPISODIC SUMMARY" in msg.get("content", "")) + + assert pivotal_count == 0 + assert summary_count == 0 + + +class TestMemoryManagerEdgeCases: + """Test edge cases.""" + + def test_zero_working_history_size(self, mock_llm_client): + """Test with zero working history size.""" + client = mock_llm_client() + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + working_history_size=0 + ) + + manager.add_message("user", "Test") + context = manager.get_context() + + # Should still work, just no working history included + assert len(context) >= 1 + + def test_very_large_working_history(self, mock_llm_client): + """Test with working history larger than actual history.""" + client = mock_llm_client() + manager = MemoryManager( + llm_client=client, + model_name="gpt-4", + initial_history=[{"role": "system", "content": "System"}], + working_history_size=1000 + ) + + manager.add_message("user", "Test") + context = manager.get_context() + + # Should include all messages without error + assert len(context) >= 2 diff --git a/caribou/tests/unit/test_message_utils.py b/caribou/tests/unit/test_message_utils.py new file mode 100644 index 0000000..f6fea5a --- /dev/null +++ b/caribou/tests/unit/test_message_utils.py @@ -0,0 +1,463 @@ +""" +Unit tests for message parsing and routing utilities. + +Tests delegation detection, RAG query detection, and artifact extraction. +""" +import pytest + +from caribou.execution.message_utils import ( + detect_delegation, + detect_rag, + _extract_artifacts_from_msg, + _count_code_blocks, + _code_preview, +) + + +class TestDelegationDetection: + """Test delegation command detection.""" + + def test_detect_simple_delegation(self): + """Test detecting simple delegation command.""" + msg = "I will delegate_to_coder to handle this task." + result = detect_delegation(msg) + assert result == "delegate_to_coder" + + def test_detect_delegation_with_underscores(self): + """Test detecting delegation with underscores in agent name.""" + msg = "Let's delegate_to_my_special_agent for this." + result = detect_delegation(msg) + assert result == "delegate_to_my_special_agent" + + def test_detect_delegation_with_numbers(self): + """Test detecting delegation with numbers in agent name.""" + msg = "Time to delegate_to_agent123 now." + result = detect_delegation(msg) + assert result == "delegate_to_agent123" + + def test_no_delegation(self): + """Test when no delegation command present.""" + msg = "This is just a regular message without delegation." 
+        result = detect_delegation(msg)
+        assert result is None
+
+    def test_delegation_at_start(self):
+        """Test delegation command at message start."""
+        msg = "delegate_to_planner - we need to plan this better."
+        result = detect_delegation(msg)
+        assert result == "delegate_to_planner"
+
+    def test_delegation_at_end(self):
+        """Test delegation command at message end."""
+        msg = "Let me handle this by using delegate_to_executor"
+        result = detect_delegation(msg)
+        assert result == "delegate_to_executor"
+
+    def test_multiple_delegations(self):
+        """Test that the first delegation is found when multiple are present."""
+        msg = "First delegate_to_agent1 then delegate_to_agent2"
+        result = detect_delegation(msg)
+        assert result == "delegate_to_agent1"
+
+    def test_delegation_case_sensitive(self):
+        """Test that delegation detection is case-sensitive."""
+        msg = "Should not match DELEGATE_TO_CODER or Delegate_To_Coder"
+        result = detect_delegation(msg)
+        assert result is None
+
+    def test_delegation_with_special_chars_fails(self):
+        """Test that special characters break the pattern."""
+        msg = "This delegate_to_agent-name should not match"
+        result = detect_delegation(msg)
+        # Should match "delegate_to_agent" only (stops at hyphen)
+        assert result == "delegate_to_agent"
+
+    def test_empty_message(self):
+        """Test with empty message."""
+        result = detect_delegation("")
+        assert result is None
+
+
+class TestRAGDetection:
+    """Test RAG query detection."""
+
+    def test_detect_simple_rag_query(self):
+        """Test detecting a simple RAG query."""
+        msg = "Let me query_rag_<search for documentation> to find info."
+        result = detect_rag(msg)
+        assert result == "search for documentation"
+
+    def test_detect_rag_query_with_spaces(self):
+        """Test RAG query with spaces."""
+        msg = "I need to query_rag_<API authentication methods>"
+        result = detect_rag(msg)
+        assert result == "API authentication methods"
+
+    def test_detect_rag_query_at_start(self):
+        """Test RAG query at message start."""
+        msg = "query_rag_<database schema> should help us here."
+        result = detect_rag(msg)
+        assert result == "database schema"
+
+    def test_no_rag_query(self):
+        """Test when no RAG query is present."""
+        msg = "This is just a regular message."
+        result = detect_rag(msg)
+        assert result is None
+
+    def test_rag_query_empty_brackets(self):
+        """Test RAG query with empty brackets won't match (requires at least one char)."""
+        msg = "What about query_rag_<> this?"
+        result = detect_rag(msg)
+        # The [^>]+ pattern requires at least one character, so empty brackets don't match
+        assert result is None
+
+    def test_multiple_rag_queries(self):
+        """Test that the first RAG query is found when multiple are present."""
+        msg = "First query_rag_<query1> then query_rag_<query2>"
+        result = detect_rag(msg)
+        assert result == "query1"
+
+    def test_rag_query_with_newlines(self):
+        """Test RAG query can match content with newlines."""
+        msg = "query_rag_<this has\nnewlines>"
+        result = detect_rag(msg)
+        # The [^>]+ pattern matches any char except >, including newlines
+        assert result == "this has\nnewlines"
+
+    def test_empty_message(self):
+        """Test with empty message."""
+        result = detect_rag("")
+        assert result is None
+
+
+class TestArtifactExtraction:
+    """Test extraction of notes and TODOs from messages."""
+
+    def test_extract_note_prefix(self):
+        """Test extracting NOTE: prefix."""
+        msg = "NOTE: This is an important observation."
+        notes, todos = _extract_artifacts_from_msg(msg)
+        assert len(notes) == 1
+        assert notes[0] == "This is an important observation."
+ assert len(todos) == 0 + + def test_extract_todo_prefix(self): + """Test extracting TODO: prefix.""" + msg = "TODO: Implement error handling." + notes, todos = _extract_artifacts_from_msg(msg) + assert len(notes) == 0 + assert len(todos) == 1 + assert todos[0] == "Implement error handling." + + def test_extract_checkbox_unchecked(self): + """Test extracting unchecked checkbox.""" + msg = "- [ ] Complete the documentation" + notes, todos = _extract_artifacts_from_msg(msg) + assert len(todos) == 1 + assert todos[0] == "Complete the documentation" + + def test_extract_checkbox_checked(self): + """Test extracting checked checkbox.""" + msg = "- [x] Write unit tests" + notes, todos = _extract_artifacts_from_msg(msg) + assert len(todos) == 1 + assert todos[0] == "Write unit tests" + + def test_extract_checkbox_checked_uppercase(self): + """Test extracting checked checkbox with uppercase X.""" + msg = "- [X] Deploy to production" + notes, todos = _extract_artifacts_from_msg(msg) + assert len(todos) == 1 + assert todos[0] == "Deploy to production" + + def test_extract_notes_code_fence(self): + """Test extracting notes from code fence.""" + msg = """ +```notes +First note +Second note +Third note +``` + """ + notes, todos = _extract_artifacts_from_msg(msg) + assert len(notes) == 3 + assert "First note" in notes + assert "Second note" in notes + assert "Third note" in notes + + def test_extract_todos_code_fence(self): + """Test extracting TODOs from code fence.""" + msg = """ +```todo +Task 1 +Task 2 +``` + """ + notes, todos = _extract_artifacts_from_msg(msg) + assert len(todos) == 2 + assert "Task 1" in todos + assert "Task 2" in todos + + def test_extract_todos_code_fence_plural(self): + """Test extracting TODOs from code fence with plural.""" + msg = """ +```todos +Task A +Task B +``` + """ + notes, todos = _extract_artifacts_from_msg(msg) + assert len(todos) == 2 + assert "Task A" in todos + assert "Task B" in todos + + def test_extract_mixed_artifacts(self): + """Test extracting mixed notes and TODOs.""" + msg = """ +NOTE: Configuration loaded successfully +TODO: Add validation +- [ ] Write tests +- [x] Update documentation + +```notes +System initialized +``` + +```todo +Refactor code +``` + """ + notes, todos = _extract_artifacts_from_msg(msg) + + assert len(notes) == 2 + assert "Configuration loaded successfully" in notes + assert "System initialized" in notes + + assert len(todos) == 4 + assert "Add validation" in todos + assert "Write tests" in todos + assert "Update documentation" in todos + assert "Refactor code" in todos + + def test_extract_case_insensitive_note(self): + """Test that NOTE: is case-insensitive.""" + msg = "note: Lowercase note" + notes, todos = _extract_artifacts_from_msg(msg) + assert len(notes) == 1 + assert notes[0] == "Lowercase note" + + def test_extract_case_insensitive_todo(self): + """Test that TODO: is case-insensitive.""" + msg = "todo: Lowercase todo" + notes, todos = _extract_artifacts_from_msg(msg) + assert len(todos) == 1 + assert todos[0] == "Lowercase todo" + + def test_extract_empty_code_fence_ignored(self): + """Test that empty code fences are ignored.""" + msg = """ +```notes +``` + +```todo +``` + """ + notes, todos = _extract_artifacts_from_msg(msg) + assert len(notes) == 0 + assert len(todos) == 0 + + def test_extract_with_empty_lines(self): + """Test that empty lines are skipped.""" + msg = """ +NOTE: First note + +NOTE: Second note + + +TODO: First task + +TODO: Second task + """ + notes, todos = _extract_artifacts_from_msg(msg) + assert 
len(notes) == 2 + assert len(todos) == 2 + + def test_no_artifacts(self): + """Test message with no artifacts.""" + msg = "This is just a regular message with no notes or TODOs." + notes, todos = _extract_artifacts_from_msg(msg) + assert len(notes) == 0 + assert len(todos) == 0 + + def test_empty_message(self): + """Test with empty message.""" + notes, todos = _extract_artifacts_from_msg("") + assert len(notes) == 0 + assert len(todos) == 0 + + +class TestCodeBlockCounting: + """Test code block counting.""" + + def test_count_single_code_block(self): + """Test counting single code block.""" + msg = """ +Here is some code: +```python +print("Hello") +``` + """ + count = _count_code_blocks(msg) + assert count == 1 + + def test_count_multiple_code_blocks(self): + """Test counting multiple code blocks.""" + msg = """ +First block: +```python +x = 1 +``` + +Second block: +``` +y = 2 +``` + +Third block: +```python +z = 3 +``` + """ + count = _count_code_blocks(msg) + assert count == 3 + + def test_count_code_block_without_language(self): + """Test counting code block without language specifier.""" + msg = """ +``` +generic code +``` + """ + count = _count_code_blocks(msg) + assert count == 1 + + def test_count_no_code_blocks(self): + """Test counting when no code blocks present.""" + msg = "This message has no code blocks." + count = _count_code_blocks(msg) + assert count == 0 + + def test_count_empty_message(self): + """Test counting with empty message.""" + count = _count_code_blocks("") + assert count == 0 + + def test_count_none_message(self): + """Test counting with None message.""" + count = _count_code_blocks(None) + assert count == 0 + + def test_count_inline_code_ignored(self): + """Test that inline code is not counted.""" + msg = "This has `inline code` but no blocks." + count = _count_code_blocks(msg) + assert count == 0 + + +class TestCodePreview: + """Test code preview generation.""" + + def test_preview_short_code(self): + """Test preview of short code snippet.""" + code = "x = 1\ny = 2" + preview = _code_preview(code) + assert preview == "x = 1\ny = 2" + + def test_preview_truncate_long_code(self): + """Test that long code is truncated.""" + code = "a" * 300 + preview = _code_preview(code, max_chars=200) + assert len(preview) <= 203 # 200 + "..." 
+        assert preview.endswith("...")
+
+    def test_preview_limit_lines(self):
+        """Test that the number of lines is limited."""
+        code = "\n".join([f"line {i}" for i in range(10)])
+        preview = _code_preview(code, max_lines=4)
+        lines = preview.split("\n")
+        assert len(lines) <= 4
+
+    def test_preview_strips_empty_lines(self):
+        """Test that empty lines are stripped."""
+        code = "\n\nline 1\n\nline 2\n\n"
+        preview = _code_preview(code)
+        assert preview == "line 1\nline 2"
+
+    def test_preview_empty_code(self):
+        """Test preview of empty code."""
+        preview = _code_preview("")
+        assert preview == "(empty code block)"
+
+    def test_preview_whitespace_only(self):
+        """Test preview of whitespace-only code."""
+        preview = _code_preview(" \n \n ")
+        assert preview == "(empty code block)"
+
+    def test_preview_default_params(self):
+        """Test preview with default parameters."""
+        code = "\n".join([f"line {i}" for i in range(10)])
+        preview = _code_preview(code)
+        # Should limit to 4 lines by default
+        lines = [ln for ln in preview.split("\n") if ln.strip()]
+        assert len(lines) <= 4
+
+
+class TestEdgeCases:
+    """Test edge cases across all utilities."""
+
+    def test_unicode_in_delegation(self):
+        """Test handling unicode in messages."""
+        msg = "Let's delegate_to_agent with 🎯 emoji"
+        result = detect_delegation(msg)
+        assert result == "delegate_to_agent"
+
+    def test_unicode_in_artifacts(self):
+        """Test unicode in artifacts."""
+        msg = "NOTE: Handle UTF-8 like café and 日本語"
+        notes, todos = _extract_artifacts_from_msg(msg)
+        assert len(notes) == 1
+        assert "café" in notes[0]
+        assert "日本語" in notes[0]
+
+    def test_multiline_mixed_content(self):
+        """Test a complex multiline message."""
+        msg = """
+I think we should delegate_to_planner first.
+
+NOTE: Current approach is working
+TODO: Add error handling
+
+```python
+def foo():
+    pass
+```
+
+Then query_rag_<find examples> for more info.
+
+- [ ] Review code
+- [x] Write tests
+        """
+
+        delegation = detect_delegation(msg)
+        assert delegation == "delegate_to_planner"
+
+        rag = detect_rag(msg)
+        assert rag == "find examples"
+
+        notes, todos = _extract_artifacts_from_msg(msg)
+        assert len(notes) == 1
+        assert len(todos) == 3
+
+        count = _count_code_blocks(msg)
+        assert count == 1
diff --git a/caribou/tests/unit/test_ollama_wrapper.py b/caribou/tests/unit/test_ollama_wrapper.py
new file mode 100644
index 0000000..1f518dd
--- /dev/null
+++ b/caribou/tests/unit/test_ollama_wrapper.py
@@ -0,0 +1,330 @@
+"""
+Unit tests for OllamaClient wrapper.
+
+Tests the OpenAI API compatibility layer and ND-JSON response parsing.
+""" +import pytest +import json +from types import SimpleNamespace +from unittest.mock import Mock, patch, MagicMock + +from caribou.core.ollama_wrapper import OllamaClient + + +class TestOllamaClientInitialization: + """Test OllamaClient initialization.""" + + def test_init_with_default_params(self): + """Test initialization with default parameters.""" + client = OllamaClient() + + assert client._host == "http://localhost:11434" + assert client._default_model == "deepseek-r1:70b" + + def test_init_with_custom_host(self): + """Test initialization with custom host.""" + client = OllamaClient(host="http://ollama.local:8080", model="llama3") + + assert client._host == "http://ollama.local:8080" + assert client._default_model == "llama3" + + def test_init_adds_http_prefix(self): + """Test that http:// is added if missing.""" + client = OllamaClient(host="localhost:11434") + + assert client._host == "http://localhost:11434" + + def test_init_strips_trailing_slash(self): + """Test that trailing slashes are removed.""" + client = OllamaClient(host="http://localhost:11434/") + + assert client._host == "http://localhost:11434" + + def test_chat_completions_interface_exists(self): + """Test that the OpenAI-compatible interface exists.""" + client = OllamaClient() + + assert hasattr(client, "chat") + assert hasattr(client.chat, "completions") + assert hasattr(client.chat.completions, "create") + assert callable(client.chat.completions.create) + + +class TestOllamaClientAPICall: + """Test Ollama API call behavior.""" + + @patch("caribou.core.ollama_wrapper.requests.post") + def test_basic_api_call(self, mock_post): + """Test a basic API call with default parameters.""" + # Mock response with ND-JSON format + mock_response = Mock() + mock_response.text = '{"message": {"role": "assistant", "content": "Hello!"}}\n' + mock_response.raise_for_status = Mock() + mock_post.return_value = mock_response + + client = OllamaClient(host="http://localhost:11434", model="llama3") + messages = [{"role": "user", "content": "Hi"}] + + response = client.chat.completions.create(messages=messages) + + # Verify API call + mock_post.assert_called_once() + call_args = mock_post.call_args + + assert call_args[0][0] == "http://localhost:11434/api/chat" + assert call_args[1]["json"]["model"] == "llama3" + assert call_args[1]["json"]["messages"] == messages + assert call_args[1]["json"]["stream"] is False + assert call_args[1]["timeout"] == 300 + + # Verify response structure + assert hasattr(response, "choices") + assert len(response.choices) == 1 + assert response.choices[0].message.content == "Hello!" 
+        assert response.choices[0].message.role == "assistant"
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_api_call_with_temperature(self, mock_post):
+        """Test API call with temperature parameter."""
+        mock_response = Mock()
+        mock_response.text = '{"message": {"role": "assistant", "content": "Response"}}\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        messages = [{"role": "user", "content": "Test"}]
+
+        client.chat.completions.create(messages=messages, temperature=0.7)
+
+        call_args = mock_post.call_args
+        payload = call_args[1]["json"]
+
+        assert "options" in payload
+        assert payload["options"]["temperature"] == 0.7
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_api_call_without_temperature(self, mock_post):
+        """Test API call without temperature (should not include options)."""
+        mock_response = Mock()
+        mock_response.text = '{"message": {"role": "assistant", "content": "Response"}}\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        messages = [{"role": "user", "content": "Test"}]
+
+        client.chat.completions.create(messages=messages)
+
+        call_args = mock_post.call_args
+        payload = call_args[1]["json"]
+
+        assert "options" not in payload
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_extra_kwargs_ignored(self, mock_post):
+        """Test that extra unknown kwargs don't break the call."""
+        mock_response = Mock()
+        mock_response.text = '{"message": {"role": "assistant", "content": "Response"}}\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        messages = [{"role": "user", "content": "Test"}]
+
+        # Should not raise error for unknown kwargs
+        client.chat.completions.create(
+            messages=messages,
+            unknown_param="ignored",
+            another_param=123
+        )
+
+        mock_post.assert_called_once()
+
+
+class TestOllamaClientResponseParsing:
+    """Test ND-JSON response parsing."""
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_parse_single_line_response(self, mock_post):
+        """Test parsing single-line ND-JSON response."""
+        mock_response = Mock()
+        mock_response.text = '{"message": {"role": "assistant", "content": "Single line"}}\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        response = client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+        assert response.choices[0].message.content == "Single line"
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_parse_multiline_ndjson_response(self, mock_post):
+        """Test parsing multi-line ND-JSON response (takes first message)."""
+        mock_response = Mock()
+        mock_response.text = (
+            '{"other": "data"}\n'
+            '{"message": {"role": "assistant", "content": "First message"}}\n'
+            '{"message": {"role": "assistant", "content": "Second message"}}\n'
+        )
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        response = client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+        # Should take the first message found
+        assert response.choices[0].message.content == "First message"
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_parse_response_with_metadata(self, mock_post):
+        """Test parsing response that includes metadata lines."""
+        mock_response = Mock()
+        mock_response.text = (
+            '{"model": "llama3", "created_at": "2024-01-01T00:00:00Z"}\n'
+            '{"message": {"role": "assistant", "content": "Hello!"}}\n'
+            '{"done": true}\n'
+        )
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        response = client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+        assert response.choices[0].message.content == "Hello!"
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_response_structure_matches_openai(self, mock_post):
+        """Test that response structure matches OpenAI format."""
+        mock_response = Mock()
+        mock_response.text = '{"message": {"role": "assistant", "content": "Test"}}\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        response = client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+        # Check OpenAI-compatible structure
+        assert hasattr(response, "choices")
+        assert len(response.choices) == 1
+        assert hasattr(response.choices[0], "message")
+        assert hasattr(response.choices[0], "index")
+        assert hasattr(response.choices[0], "finish_reason")
+        assert hasattr(response.choices[0].message, "content")
+        assert hasattr(response.choices[0].message, "role")
+
+        assert response.choices[0].index == 0
+        assert response.choices[0].finish_reason == "stop"
+        assert response.choices[0].message.role == "assistant"
+
+
+class TestOllamaClientErrorHandling:
+    """Test error handling scenarios."""
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_http_error_propagates(self, mock_post):
+        """Test that HTTP errors are propagated."""
+        mock_response = Mock()
+        mock_response.raise_for_status.side_effect = Exception("HTTP 500")
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+
+        with pytest.raises(Exception, match="HTTP 500"):
+            client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_no_message_in_response(self, mock_post):
+        """Test error when no message object found in response."""
+        mock_response = Mock()
+        mock_response.text = '{"model": "llama3"}\n{"done": true}\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+
+        with pytest.raises(ValueError, match="No message object found"):
+            client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_invalid_json_in_response(self, mock_post):
+        """Test error when response contains invalid JSON."""
+        mock_response = Mock()
+        mock_response.text = 'not valid json\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+
+        with pytest.raises(json.JSONDecodeError):
+            client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_empty_response(self, mock_post):
+        """Test error when response is empty."""
+        mock_response = Mock()
+        mock_response.text = ''
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+
+        with pytest.raises(ValueError, match="No message object found"):
+            client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_timeout_parameter(self, mock_post):
+        """Test that timeout is set correctly."""
+        mock_response = Mock()
+        mock_response.text = '{"message": {"role": "assistant", "content": "Test"}}\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+        call_args = mock_post.call_args
+        assert call_args[1]["timeout"] == 300
+
+
+class TestOllamaClientEdgeCases:
+    """Test edge cases."""
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_empty_messages_list(self, mock_post):
+        """Test with empty messages list."""
+        mock_response = Mock()
+        mock_response.text = '{"message": {"role": "assistant", "content": "Response"}}\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        client.chat.completions.create(messages=[])
+
+        call_args = mock_post.call_args
+        assert call_args[1]["json"]["messages"] == []
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_empty_content_in_message(self, mock_post):
+        """Test parsing message with empty content."""
+        mock_response = Mock()
+        mock_response.text = '{"message": {"role": "assistant", "content": ""}}\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        response = client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+        assert response.choices[0].message.content == ""
+
+    @patch("caribou.core.ollama_wrapper.requests.post")
+    def test_whitespace_in_response(self, mock_post):
+        """Test parsing response with extra whitespace."""
+        mock_response = Mock()
+        mock_response.text = '\n\n {"message": {"role": "assistant", "content": "Test"}} \n\n'
+        mock_response.raise_for_status = Mock()
+        mock_post.return_value = mock_response
+
+        client = OllamaClient()
+        response = client.chat.completions.create(messages=[{"role": "user", "content": "Test"}])
+
+        assert response.choices[0].message.content == "Test"
diff --git a/requirements.txt b/requirements.txt
index b90eaa9..cebeb5e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,10 @@
 numpy
 docker
 dotenv
 openai
+anthropic
 jupyter_client
 nbformat
-sentence_transformers
\ No newline at end of file
+sentence_transformers
+pytest>=7.0.0
+pytest-cov>=4.0.0
+requests