diff --git a/README.md b/README.md index 1cc109c..88fde16 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ > **⚠️ WARNING: Large-scale migrations (especially logs/experiments) can be extremely expensive and operationally risky. This tool includes streaming + resumable migration for high-volume event streams, but TB-scale migrations have not been fully soak-tested in production-like conditions. Use with caution and test on a subset first.** -A Python CLI & library for migrating Braintrust organizations with maximum fidelity, leveraging the official `braintrust-api-py` SDK. +A Python CLI & library for migrating Braintrust organizations with maximum fidelity, using direct HTTP requests (via `httpx`) against the Braintrust REST API. ## Overview diff --git a/braintrust_migrate/client.py b/braintrust_migrate/client.py index 4cfab2c..cab5fdc 100644 --- a/braintrust_migrate/client.py +++ b/braintrust_migrate/client.py @@ -12,7 +12,6 @@ import httpx import structlog -from braintrust_api import AsyncBraintrust from braintrust_migrate.config import BraintrustOrgConfig, MigrationConfig @@ -41,7 +40,7 @@ class BraintrustAPIError(BraintrustClientError): class BraintrustClient: - """Thin wrapper around braintrust-api-py AsyncClient with additional features. + """Thin wrapper around httpx.AsyncClient with additional features. Provides: - Health checks and connectivity validation @@ -66,7 +65,6 @@ def __init__( self.org_config = org_config self.migration_config = migration_config self.org_name = org_name - self._client: AsyncBraintrust | None = None self._http_client: httpx.AsyncClient | None = None self._logger = logger.bind(org=org_name, url=str(org_config.url)) self._org_id: str | None = None @@ -86,25 +84,18 @@ async def connect(self) -> None: Raises: BraintrustConnectionError: If connection fails. """ - if self._client is not None: + if self._http_client is not None: return try: self._logger.info("Connecting to Braintrust API") - # Create HTTP client for auxiliary requests + # Create HTTP client for requests self._http_client = httpx.AsyncClient( timeout=httpx.Timeout(30.0), limits=httpx.Limits(max_connections=20, max_keepalive_connections=5), ) - # Create Braintrust API client - self._client = AsyncBraintrust( - api_key=self.org_config.api_key, - base_url=str(self.org_config.url), - http_client=self._http_client, - ) - # Perform health check await self.health_check() @@ -119,14 +110,6 @@ async def connect(self) -> None: async def close(self) -> None: """Close the connection to Braintrust API.""" - if self._client is not None: - try: - await self._client.close() - except Exception as e: - self._logger.warning("Error closing Braintrust client", error=str(e)) - finally: - self._client = None - if self._http_client is not None: try: await self._http_client.aclose() @@ -137,6 +120,83 @@ async def close(self) -> None: self._logger.info("Closed connection to Braintrust API") + async def list_projects( + self, + *, + limit: int | None = None, + project_name: str | None = None, + org_name: str | None = None, + page_size: int = 100, + ) -> list[dict[str, Any]]: + """List projects in this organization. + + Uses GET /v1/project and paginates via `starting_after`. + + Args: + limit: Optional max number of projects to return (client-side cap). + project_name: Optional exact-name filter (server-side). + org_name: Optional org name filter (server-side). + page_size: Page size to request when paginating. + """ + effective_page_size = page_size + if limit is not None: + effective_page_size = max(1, min(page_size, limit)) + + projects: list[dict[str, Any]] = [] + starting_after: str | None = None + + while True: + params: dict[str, Any] = {"limit": effective_page_size} + if starting_after is not None: + params["starting_after"] = starting_after + if project_name is not None: + params["project_name"] = project_name + if org_name is not None: + params["org_name"] = org_name + + resp = await self.raw_request("GET", "/v1/project", params=params) + if not isinstance(resp, dict): + raise BraintrustAPIError(f"Unexpected project list response: {type(resp)}") + objs = resp.get("objects") + if not isinstance(objs, list): + raise BraintrustAPIError( + f"Unexpected project list response shape: {resp!r}" + ) + + batch: list[dict[str, Any]] = [] + for obj in objs: + if isinstance(obj, dict): + batch.append(obj) + + projects.extend(batch) + + if limit is not None and len(projects) >= limit: + return projects[:limit] + + # No more pages. + if len(objs) < effective_page_size or not objs: + return projects + + # Continue with pagination cursor. + last = batch[-1] if batch else None + last_id = last.get("id") if isinstance(last, dict) else None + if not isinstance(last_id, str) or not last_id: + return projects + starting_after = last_id + + async def create_project( + self, *, name: str, description: str | None = None + ) -> dict[str, Any]: + """Create (or return existing) project by name via POST /v1/project.""" + payload: dict[str, Any] = {"name": name} + if description: + payload["description"] = description + + resp = await self.raw_request("POST", "/v1/project", json=payload) + if not isinstance(resp, dict): + raise BraintrustAPIError(f"Unexpected create project response: {type(resp)}") + return resp + async def raw_request( self, method: str, @@ -148,9 +208,9 @@ async def raw_request( ) -> Any: """Perform a raw HTTP request against the Braintrust API. - This is useful for endpoints that are not ergonomically exposed by the - generated `braintrust-api` client, or when we need tight control over - request/response behavior (e.g. cursor-pagination for large logs). + This is useful when we need tight control over request/response behavior + (e.g. cursor-pagination for large logs) or want to avoid additional SDK + dependencies. Args: method: HTTP method (GET/POST/etc). @@ -269,20 +329,6 @@ async def get_org_id(self) -> str: f"{last_err}" ) - @property - def client(self) -> AsyncBraintrust: - """Get the underlying Braintrust API client. - - Returns: - The AsyncBraintrust client instance. - - Raises: - BraintrustConnectionError: If not connected. - """ - if self._client is None: - raise BraintrustConnectionError(f"Not connected to {self.org_name}") - return self._client - async def health_check(self) -> dict[str, Any]: """Perform health check against the Braintrust API. @@ -292,12 +338,12 @@ async def health_check(self) -> dict[str, Any]: Raises: BraintrustConnectionError: If health check fails. """ - if self._client is None: + if self._http_client is None: raise BraintrustConnectionError(f"Not connected to {self.org_name}") try: - # Try to list projects as a health check - await self._client.projects.list(limit=1) + # List projects as a lightweight health check. + await self.list_projects(limit=1) health_data = { "status": "healthy", diff --git a/braintrust_migrate/orchestration.py b/braintrust_migrate/orchestration.py index 27e64e3..36614f3 100644 --- a/braintrust_migrate/orchestration.py +++ b/braintrust_migrate/orchestration.py @@ -8,7 +8,6 @@ from typing import Any, ClassVar, TypeVar, cast import structlog -from braintrust_api.types import Project from braintrust_migrate.client import BraintrustClient, create_client_pair from braintrust_migrate.config import Config @@ -304,21 +303,26 @@ async def _discover_projects( """ self._logger.info("Discovering projects") - # List projects from source - source_projects = await source_client.with_retry( - "list_source_projects", lambda: source_client.client.projects.list() + # List projects from source (REST: GET /v1/project) + raw_projects = cast( + list[dict[str, Any]], + await source_client.with_retry( + "list_source_projects", lambda: source_client.list_projects() + ), ) - - projects = [] - - if source_projects is None: - projects = [] - # Convert to list if it's an async iterator - elif hasattr(source_projects, "__aiter__"): - async for project in source_projects: - projects.append(project) - else: - projects = list(source_projects) + # Ensure we have basic required fields. + projects: list[dict[str, Any]] = [] + for p in raw_projects: + pid = p.get("id") + name = p.get("name") + if isinstance(pid, str) and pid and isinstance(name, str) and name: + projects.append(p) + else: + self._logger.warning( + "Skipping malformed project record", + project_id=pid, + project_name=name, + ) # Filter projects if project_names is specified if self.config.project_names: @@ -326,18 +330,18 @@ async def _discover_projects( project_names_set = set(self.config.project_names) for project in projects: - if project.name in project_names_set: + if project.get("name") in project_names_set: filtered_projects.append(project) # Log which projects were found and which were not - found_names = {project.name for project in filtered_projects} + found_names = {cast(str, project.get("name")) for project in filtered_projects} missing_names = project_names_set - found_names if missing_names: self._logger.warning( "Some specified projects were not found in source organization", missing_projects=list(missing_names), - available_projects=[p.name for p in projects], + available_projects=[p.get("name") for p in projects], ) self._logger.info( @@ -356,10 +360,10 @@ async def _discover_projects( dest_project_id = await self._ensure_project_exists(project, dest_client) project_mappings.append( { - "source_id": project.id, + "source_id": cast(str, project.get("id")), "dest_id": dest_project_id, - "name": project.name, - "description": getattr(project, "description", None), + "name": cast(str, project.get("name")), + "description": project.get("description"), } ) @@ -367,7 +371,7 @@ async def _discover_projects( async def _ensure_project_exists( self, - source_project: Project, + source_project: dict[str, Any], dest_client: BraintrustClient, ) -> str: """Ensure a project exists in the destination organization. @@ -381,56 +385,50 @@ async def _ensure_project_exists( """ try: # Check if project already exists - dest_projects = await dest_client.with_retry( - "list_dest_projects", lambda: dest_client.client.projects.list() + dest_projects = cast( + list[dict[str, Any]], + await dest_client.with_retry( + "list_dest_projects", lambda: dest_client.list_projects() + ), ) - # Convert to list and check if project exists - existing_project = None - if dest_projects is None: - existing_project = None - elif hasattr(dest_projects, "__aiter__"): - async for dest_project in dest_projects: - if dest_project.name == source_project.name: - existing_project = dest_project - break - else: - for dest_project in dest_projects: - if dest_project.name == source_project.name: - existing_project = dest_project - break + existing_project: dict[str, Any] | None = None + for dest_project in dest_projects: + if dest_project.get("name") == source_project.get("name"): + existing_project = dest_project + break if existing_project: self._logger.debug( "Project already exists in destination", - project_name=source_project.name, - dest_id=existing_project.id, + project_name=source_project.get("name"), + dest_id=existing_project.get("id"), ) - return existing_project.id + return cast(str, existing_project.get("id")) # Create project in destination - create_params = {"name": source_project.name} - description = cast(str | None, getattr(source_project, "description", None)) + create_params = {"name": source_project.get("name")} + description = cast(str | None, source_project.get("description")) if description: create_params["description"] = description new_project = cast( - Any, + dict[str, Any], await dest_client.with_retry( "create_project", - # braintrust-api client is dynamically generated; use Any to avoid type noise - lambda: cast(Any, dest_client.client.projects).create( - **create_params + lambda: dest_client.create_project( + name=cast(str, create_params["name"]), + description=cast(str | None, create_params.get("description")), ), ), ) - new_project_id = cast(str, new_project.id) + new_project_id = cast(str, new_project.get("id")) self._logger.info( "Created project in destination", - project_name=source_project.name, - source_id=source_project.id, + project_name=source_project.get("name"), + source_id=source_project.get("id"), dest_id=new_project_id, ) @@ -439,7 +437,7 @@ async def _ensure_project_exists( except Exception as e: self._logger.error( "Failed to ensure project exists", - project_name=source_project.name, + project_name=source_project.get("name"), error=str(e), ) raise diff --git a/pyproject.toml b/pyproject.toml index 6577d24..79ae75b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ classifiers = [ ] dependencies = [ - "braintrust-api>=0.6.0", "typer[all]>=0.9.0", "pydantic>=2.5.0", "structlog>=23.2.0", diff --git a/tests/conftest.py b/tests/conftest.py index 8fef682..c315e39 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,34 +3,18 @@ from unittest.mock import AsyncMock, Mock import pytest -from braintrust_api import AsyncBraintrust -from braintrust_api.types import ( - ACL, - Dataset, - Experiment, - Group, - Project, - ProjectScore, - ProjectTag, - Prompt, - Role, - SpanIFrame, - View, -) from braintrust_migrate.config import BraintrustOrgConfig, MigrationConfig @pytest.fixture -def org_config(): +def org_config() -> BraintrustOrgConfig: """Create a test organization configuration.""" - return BraintrustOrgConfig( - api_key="test-api-key", url="https://test.braintrust.dev" - ) + return BraintrustOrgConfig(api_key="test-api-key", url="https://test.braintrust.dev") @pytest.fixture -def migration_config(): +def migration_config() -> MigrationConfig: """Create a test migration configuration.""" return MigrationConfig( batch_size=50, @@ -41,94 +25,23 @@ def migration_config(): ) -@pytest.fixture -def mock_source_client(): - """Create a mock source client with common API endpoints.""" +def _make_mock_client() -> Mock: client = Mock() client.with_retry = AsyncMock() - - # Mock the underlying braintrust client - client.client = Mock(spec=AsyncBraintrust) - client.client.projects = Mock() - client.client.projects.list = AsyncMock() - client.client.datasets = Mock() - client.client.datasets.list = AsyncMock() - client.client.experiments = Mock() - client.client.experiments.list = AsyncMock() - client.client.prompts = Mock() - client.client.prompts.list = AsyncMock() - client.client.functions = Mock() - client.client.functions.list = AsyncMock() - client.client.roles = Mock() - client.client.roles.list = AsyncMock() - client.client.groups = Mock() - client.client.groups.list = AsyncMock() - client.client.acls = Mock() - client.client.acls.list_org = AsyncMock() - client.client.span_iframes = Mock() - client.client.span_iframes.list = AsyncMock() - client.client.views = Mock() - client.client.views.list = AsyncMock() - client.client.logs = Mock() - client.client.logs.list = AsyncMock() - client.client.project_tags = Mock() - client.client.project_tags.list = AsyncMock() - client.client.project_tags.create = AsyncMock() - client.client.project_scores = Mock() - client.client.project_scores.list = AsyncMock() - + client.raw_request = AsyncMock() return client @pytest.fixture -def mock_dest_client(): - """Create a mock destination client with common API endpoints.""" - client = Mock() - client.with_retry = AsyncMock() +def mock_source_client() -> Mock: + """Mock source `BraintrustClient`-like object.""" + return _make_mock_client() - # Mock the underlying braintrust client - client.client = Mock(spec=AsyncBraintrust) - client.client.projects = Mock() - client.client.projects.list = AsyncMock() - client.client.projects.create = AsyncMock() - client.client.datasets = Mock() - client.client.datasets.list = AsyncMock() - client.client.datasets.create = AsyncMock() - client.client.experiments = Mock() - client.client.experiments.list = AsyncMock() - client.client.experiments.create = AsyncMock() - client.client.prompts = Mock() - client.client.prompts.list = AsyncMock() - client.client.prompts.create = AsyncMock() - client.client.functions = Mock() - client.client.functions.list = AsyncMock() - client.client.functions.create = AsyncMock() - client.client.roles = Mock() - client.client.roles.list = AsyncMock() - client.client.roles.create = AsyncMock() - client.client.groups = Mock() - client.client.groups.list = AsyncMock() - client.client.groups.create = AsyncMock() - client.client.acls = Mock() - client.client.acls.list_org = AsyncMock() - client.client.acls.create = AsyncMock() - client.client.span_iframes = Mock() - client.client.span_iframes.list = AsyncMock() - client.client.span_iframes.create = AsyncMock() - client.client.views = Mock() - client.client.views.list = AsyncMock() - client.client.views.create = AsyncMock() - client.client.logs = Mock() - client.client.logs.list = AsyncMock() - client.client.logs.create = AsyncMock() - client.client.project_tags = Mock() - client.client.project_tags.list = AsyncMock() - client.client.project_tags.create = AsyncMock() - client.client.project_scores = Mock() - client.client.project_scores.list = AsyncMock() - client.client.project_scores.create = AsyncMock() - return client +@pytest.fixture +def mock_dest_client() -> Mock: + """Mock destination `BraintrustClient`-like object.""" + return _make_mock_client() @pytest.fixture @@ -139,250 +52,6 @@ def temp_checkpoint_dir(tmp_path): return checkpoint_dir -@pytest.fixture -def sample_project(): - """Create a sample Project for testing.""" - return Project( - id="project-123", - name="Test Project", - description="A test project", - user_id="user-456", - created="2024-01-01T00:00:00Z", - deleted_at=None, - ) - - -@pytest.fixture -def sample_dataset(): - """Create a sample Dataset for testing.""" - return Dataset( - id="dataset-123", - project_id="project-456", - name="Test Dataset", - description="A test dataset", - user_id="user-789", - created="2024-01-01T00:00:00Z", - deleted_at=None, - ) - - -@pytest.fixture -def sample_experiment(): - """Create a sample Experiment for testing.""" - return Experiment( - id="experiment-123", - project_id="project-456", - name="Test Experiment", - description="A test experiment", - user_id="user-789", - created="2024-01-01T00:00:00Z", - deleted_at=None, - ) - - -@pytest.fixture -def sample_prompt(): - """Create a sample Prompt for testing.""" - return Prompt( - id="prompt-123", - project_id="project-456", - name="Test Prompt", - description="A test prompt", - prompt_data={"prompt": "Hello, world!"}, - user_id="user-789", - created="2024-01-01T00:00:00Z", - deleted_at=None, - ) - - -@pytest.fixture -def sample_span_iframe(): - """Create a sample SpanIFrame for testing.""" - return SpanIFrame( - id="span-iframe-123", - project_id="project-456", - name="Test Span Iframe", - url="https://example.com/iframe", - description="A test span iframe", - post_message=True, - user_id="user-789", - created="2024-01-01T00:00:00Z", - deleted_at=None, - ) - - -@pytest.fixture -def sample_view(): - """Create a sample View for testing.""" - return View( - id="view-123", - project_id="project-456", - name="Test View", - description="A test view", - view_type="project", - object_type="project", - object_id="project-456", - user_id="user-789", - created="2024-01-01T00:00:00Z", - deleted_at=None, - ) - - -@pytest.fixture -def sample_project_tag(): - """Create a sample project tag for testing.""" - project_tag = Mock(spec=ProjectTag) - project_tag.id = "tag-123" - project_tag.name = "Test Tag" - project_tag.project_id = "project-456" - project_tag.user_id = "user-789" - project_tag.created = "2024-01-01T00:00:00Z" - project_tag.description = "A test project tag" - project_tag.color = "#FF0000" - - # Mock the to_dict method to return a proper dictionary - project_tag.to_dict.return_value = { - "id": "tag-123", - "name": "Test Tag", - "project_id": "project-456", - "user_id": "user-789", - "created": "2024-01-01T00:00:00Z", - "description": "A test project tag", - "color": "#FF0000", - } - - return project_tag - - -@pytest.fixture -def sample_role(): - """Create a sample role for testing.""" - role = Mock(spec=Role) - role.id = "role-123" - role.name = "Test Role" - role.org_id = "org-456" - role.user_id = "user-789" - role.created = "2024-01-01T00:00:00Z" - role.description = "A test role" - role.deleted_at = None - role.member_permissions = [ - {"permission": "read", "restrict_object_type": None}, - {"permission": "create", "restrict_object_type": "project"}, - ] - role.member_roles = None - - # Mock the to_dict method to return a proper dictionary - role.to_dict.return_value = { - "id": "role-123", - "name": "Test Role", - "org_id": "org-456", - "user_id": "user-789", - "created": "2024-01-01T00:00:00Z", - "description": "A test role", - "deleted_at": None, - "member_permissions": [ - {"permission": "read", "restrict_object_type": None}, - {"permission": "create", "restrict_object_type": "project"}, - ], - "member_roles": None, - } - - return role - - -@pytest.fixture -def sample_group(): - """Create a sample group for testing.""" - group = Mock(spec=Group) - group.id = "group-123" - group.name = "Test Group" - group.org_id = "org-456" - group.user_id = "user-789" - group.created = "2024-01-01T00:00:00Z" - group.description = "A test group" - group.deleted_at = None - group.member_groups = None - group.member_users = ["user-123", "user-456"] - - # Mock the to_dict method to return a proper dictionary - group.to_dict.return_value = { - "id": "group-123", - "name": "Test Group", - "org_id": "org-456", - "user_id": "user-789", - "created": "2024-01-01T00:00:00Z", - "description": "A test group", - "deleted_at": None, - "member_groups": None, - "member_users": ["user-123", "user-456"], - } - return group - - -@pytest.fixture -def sample_acl(): - """Create a sample ACL for testing.""" - acl = Mock(spec=ACL) - acl.id = "acl-123" - acl.object_type = "project" - acl.object_id = "project-456" - acl.group_id = "group-789" - acl.user_id = None - acl.permission = "read" - acl.role_id = None - acl.restrict_object_type = None - acl.object_org_id = "org-456" - acl.created = "2024-01-01T00:00:00Z" - - # Mock the to_dict method to return a proper dictionary - acl.to_dict.return_value = { - "id": "acl-123", - "object_type": "project", - "object_id": "project-456", - "group_id": "group-789", - "user_id": None, - "permission": "read", - "role_id": None, - "restrict_object_type": None, - "object_org_id": "org-456", - "created": "2024-01-01T00:00:00Z", - } - - return acl - - -@pytest.fixture -def sample_project_score(): - """Create a sample project score for testing.""" - project_score = Mock(spec=ProjectScore) - project_score.id = "score-123" - project_score.name = "Test Score" - project_score.project_id = "project-456" - project_score.user_id = "user-789" - project_score.created = "2024-01-01T00:00:00Z" - project_score.description = "A test project score" - project_score.score_type = "slider" - project_score.categories = None - project_score.config = None - project_score.position = None - - # Mock the to_dict method to return a proper dictionary - project_score.to_dict.return_value = { - "id": "score-123", - "name": "Test Score", - "project_id": "project-456", - "user_id": "user-789", - "created": "2024-01-01T00:00:00Z", - "description": "A test project score", - "score_type": "slider", - "categories": None, - "config": None, - "position": None, - } - - return project_score - - # Test constants that can be reused across tests TEST_PROJECT_ID = "test-project-123" TEST_DEST_PROJECT_ID = "dest-project-456" diff --git a/tests/integration/test_migration_flow.py b/tests/integration/test_migration_flow.py index b533a4a..170b88f 100644 --- a/tests/integration/test_migration_flow.py +++ b/tests/integration/test_migration_flow.py @@ -3,9 +3,11 @@ import tempfile from contextlib import asynccontextmanager from pathlib import Path +from typing import cast from unittest.mock import AsyncMock, Mock, patch import pytest +from pydantic import HttpUrl from braintrust_migrate.config import BraintrustOrgConfig, Config, MigrationConfig from braintrust_migrate.orchestration import MigrationOrchestrator @@ -23,12 +25,14 @@ def config(temp_checkpoint_dir): """Create a test configuration.""" return Config( source=BraintrustOrgConfig( - api_key="source-key", url="https://source.braintrust.dev" + api_key="source-key", + url=cast(HttpUrl, HttpUrl("https://source.braintrust.dev")), ), destination=BraintrustOrgConfig( - api_key="dest-key", url="https://dest.braintrust.dev" + api_key="dest-key", + url=cast(HttpUrl, HttpUrl("https://dest.braintrust.dev")), ), - migration=MigrationConfig(batch_size=10, max_retries=2), + migration=MigrationConfig(batch_size=10, retry_attempts=2), state_dir=temp_checkpoint_dir, resources=["datasets"], # Only test datasets for simplicity ) @@ -71,58 +75,25 @@ async def mock_create_client_pair(source_config, dest_config, migration_config): source_mock = Mock() dest_mock = Mock() - # Create mock client attributes for mock_client, is_source in [(source_mock, True), (dest_mock, False)]: - mock_client_attr = Mock() - mock_projects = Mock() - mock_projects.list = AsyncMock(return_value=[]) - mock_client_attr.projects = mock_projects - mock_client.client = mock_client_attr - - async def mock_with_retry(op_name, coro_func, is_source=is_source): - # Call the lambda function to get the coroutine, then return mock data - try: - # Call the function to get the coroutine (but don't await it) - coro = coro_func() - if hasattr(coro, "__await__"): - await coro # Consume the coroutine - except Exception: - # If it fails, just ignore - we're mocking anyway - pass - - # Return different data based on operation name and client type - if is_source and "list_source_projects" in op_name: - # Return a mock project object with the expected attributes - mock_project = Mock() - mock_project.id = mock_project_data["id"] - mock_project.name = mock_project_data["name"] - mock_project.description = mock_project_data.get("description") - return [mock_project] - elif not is_source and "list_dest_projects" in op_name: - return [] - elif not is_source and "create_project" in op_name: - # Return a mock project object with the expected attributes - mock_project = Mock() - mock_project.id = mock_project_data["id"] - mock_project.name = mock_project_data["name"] - mock_project.description = mock_project_data.get("description") - return mock_project - elif "list_datasets" in op_name: - # raw_request returns dict with objects key - return {"objects": []} - elif "list_experiments" in op_name: - # raw_request returns dict with objects key - return {"objects": []} - elif "list_functions" in op_name: - # raw_request returns dict with objects key - return {"objects": []} - else: - return [] + # Projects are now handled via REST helpers, not the SDK. + mock_client.list_projects = AsyncMock( + return_value=[mock_project_data] if is_source else [] + ) + mock_client.create_project = AsyncMock(return_value=mock_project_data) + + async def mock_with_retry(_op_name, coro_func): + res = coro_func() + if hasattr(res, "__await__"): + return await res + return res mock_client.with_retry = mock_with_retry - # Mock raw_request to return empty objects async def mock_raw_request(method, path, **kwargs): + # All resource list calls return empty for this integration test. + assert method.upper() in {"GET", "POST"} + assert path.startswith("/v1/") return {"objects": []} mock_client.raw_request = mock_raw_request diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index de9d3f5..db3d085 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -4,7 +4,6 @@ import httpx import pytest -from braintrust_api import AsyncBraintrust from braintrust_migrate.client import BraintrustClient, BraintrustConnectionError from braintrust_migrate.config import BraintrustOrgConfig, MigrationConfig @@ -27,27 +26,13 @@ def migration_config(): """Create a test migration configuration.""" return MigrationConfig( batch_size=50, - max_retries=3, + retry_attempts=3, retry_delay=1.0, - timeout=30.0, + max_concurrent=10, + checkpoint_interval=50, ) -@pytest.fixture -def mock_braintrust_client(): - """Create a mock Braintrust client.""" - client = Mock(spec=AsyncBraintrust) - client.projects = Mock() - client.projects.list = AsyncMock() - client.datasets = AsyncMock() - client.prompts = AsyncMock() - client.tools = AsyncMock() - client.functions = AsyncMock() - client.experiments = AsyncMock() - client.close = AsyncMock() - return client - - @pytest.mark.asyncio class TestBraintrustClient: """Test the BraintrustClient wrapper.""" @@ -59,31 +44,29 @@ async def test_initialization(self, org_config, migration_config): assert client.org_config == org_config assert client.migration_config == migration_config assert client.org_name == "test-org" - assert client._client is None assert client._http_client is None - async def test_context_manager( - self, org_config, migration_config, mock_braintrust_client - ): + async def test_context_manager(self, org_config, migration_config): """Test async context manager behavior.""" - with patch("braintrust_migrate.client.AsyncBraintrust") as mock_bt: - mock_bt.return_value = mock_braintrust_client - - async with BraintrustClient( - org_config, migration_config, "test-org" - ) as client: - assert client.client == mock_braintrust_client - - mock_braintrust_client.close.assert_called_once() + with patch.object( + BraintrustClient, + "raw_request", + new_callable=AsyncMock, + return_value={"objects": []}, + ): + async with BraintrustClient(org_config, migration_config, "test-org") as c: + assert isinstance(c._http_client, httpx.AsyncClient) + # After context manager exit, client should be closed/reset. + assert c._http_client is None async def test_health_check_success(self, org_config, migration_config): """Test successful health check.""" - with patch("braintrust_migrate.client.AsyncBraintrust") as mock_bt: - mock_client = Mock(spec=AsyncBraintrust) - mock_client.projects = Mock() - mock_client.projects.list = AsyncMock(return_value=[]) - mock_bt.return_value = mock_client - + with patch.object( + BraintrustClient, + "raw_request", + new_callable=AsyncMock, + return_value={"objects": []}, + ) as mock_raw_request: client = BraintrustClient(org_config, migration_config, "test-org") await client.connect() @@ -92,30 +75,26 @@ async def test_health_check_success(self, org_config, migration_config): assert result["status"] == "healthy" assert result["projects_accessible"] is True # Health check is called once during connect() and once explicitly - assert mock_client.projects.list.call_count == EXPECTED_HEALTH_CHECK_CALLS + assert mock_raw_request.call_count == EXPECTED_HEALTH_CHECK_CALLS async def test_health_check_failure(self, org_config, migration_config): """Test health check failure.""" - with patch("braintrust_migrate.client.AsyncBraintrust") as mock_bt: - mock_client = Mock(spec=AsyncBraintrust) - mock_client.projects = Mock() - mock_client.projects.list = AsyncMock(side_effect=Exception("API Error")) - mock_bt.return_value = mock_client - - client = BraintrustClient(org_config, migration_config, "test-org") - client._client = mock_client # Set directly to bypass connect() - + client = BraintrustClient(org_config, migration_config, "test-org") + client._http_client = Mock(spec=httpx.AsyncClient) + with patch.object( + client, "raw_request", new=AsyncMock(side_effect=Exception("API Error")) + ): with pytest.raises(BraintrustConnectionError): await client.health_check() async def test_http_client_configuration(self, org_config, migration_config): """Test HTTP client is properly configured.""" - with patch("braintrust_migrate.client.AsyncBraintrust") as mock_bt: - mock_client = Mock(spec=AsyncBraintrust) - mock_client.projects = Mock() - mock_client.projects.list = AsyncMock(return_value=[]) - mock_bt.return_value = mock_client - + with patch.object( + BraintrustClient, + "raw_request", + new_callable=AsyncMock, + return_value={"objects": []}, + ): client = BraintrustClient(org_config, migration_config, "test-org") await client.connect() @@ -126,12 +105,12 @@ async def test_http_client_configuration(self, org_config, migration_config): # Note: httpx.AsyncClient doesn't expose limits as a public attribute # The limits are set during construction but not accessible for testing - async def test_client_property_not_connected(self, org_config, migration_config): - """Test that client property raises error when not connected.""" + async def test_raw_request_not_connected(self, org_config, migration_config): + """Test that raw_request raises error when not connected.""" client = BraintrustClient(org_config, migration_config, "test-org") with pytest.raises(BraintrustConnectionError, match="Not connected"): - _ = client.client + await client.raw_request("GET", "/v1/project") async def test_with_retry(self, org_config, migration_config): """Test the with_retry method.""" diff --git a/uv.lock b/uv.lock index 66b3d44..61b7fd9 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" [[package]] @@ -25,29 +25,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916, upload-time = "2025-03-17T00:02:52.713Z" }, ] -[[package]] -name = "braintrust-api" -version = "0.6.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "distro" }, - { name = "httpx" }, - { name = "pydantic" }, - { name = "sniffio" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/2a/37/f79b259b038a8c0554ed99bd061e4e02b46effadae6742b1c25c24112906/braintrust_api-0.6.0.tar.gz", hash = "sha256:8f33639be6bd80063cb63e3be14f24d29c02860abe77319cda6ea06326ae12f8", size = 185903, upload-time = "2024-11-28T20:16:23.567Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5c/42/b3cbdeda7c6843f6804587a06404e8ddde28e9a3b384fc9222c54dbb5215/braintrust_api-0.6.0-py3-none-any.whl", hash = "sha256:30b67dac746805a1fe0134a4311896eb9f4289b963cd18dd51edc2b099569bdf", size = 268967, upload-time = "2024-11-28T20:16:21.51Z" }, -] - [[package]] name = "braintrust-migrate" version = "0.1.0" source = { editable = "." } dependencies = [ - { name = "braintrust-api" }, { name = "httpx" }, { name = "pydantic" }, { name = "python-dotenv" }, @@ -71,7 +53,6 @@ dev = [ [package.metadata] requires-dist = [ - { name = "braintrust-api", specifier = ">=0.6.0" }, { name = "httpx", specifier = ">=0.25.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.7.0" }, { name = "pydantic", specifier = ">=2.5.0" }, @@ -162,15 +143,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/1a/0b9c32220ad694d66062f571cc5cedfa9997b64a591e8a500bb63de1bd40/coverage-7.8.2-py3-none-any.whl", hash = "sha256:726f32ee3713f7359696331a18daf0c3b3a70bb0ae71141b9d3c52be7c595e32", size = 203623, upload-time = "2025-05-23T11:39:53.846Z" }, ] -[[package]] -name = "distro" -version = "1.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722, upload-time = "2023-12-24T09:54:32.31Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, -] - [[package]] name = "h11" version = "0.16.0"