Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

> **⚠️ WARNING: Large-scale migrations (especially logs/experiments) can be extremely expensive and operationally risky. This tool includes streaming + resumable migration for high-volume event streams, but TB-scale migrations have not been fully soak-tested in production-like conditions. Use with caution and test on a subset first.**

A Python CLI & library for migrating Braintrust organizations with maximum fidelity, leveraging the official `braintrust-api-py` SDK.
A Python CLI & library for migrating Braintrust organizations with maximum fidelity, using direct HTTP requests (via `httpx`) against the Braintrust REST API.

## Overview

Expand Down
126 changes: 86 additions & 40 deletions braintrust_migrate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import httpx
import structlog
from braintrust_api import AsyncBraintrust

from braintrust_migrate.config import BraintrustOrgConfig, MigrationConfig

Expand Down Expand Up @@ -41,7 +40,7 @@ class BraintrustAPIError(BraintrustClientError):


class BraintrustClient:
"""Thin wrapper around braintrust-api-py AsyncClient with additional features.
"""Thin wrapper around httpx.AsyncClient with additional features.

Provides:
- Health checks and connectivity validation
Expand All @@ -66,7 +65,6 @@ def __init__(
self.org_config = org_config
self.migration_config = migration_config
self.org_name = org_name
self._client: AsyncBraintrust | None = None
self._http_client: httpx.AsyncClient | None = None
self._logger = logger.bind(org=org_name, url=str(org_config.url))
self._org_id: str | None = None
Expand All @@ -86,25 +84,18 @@ async def connect(self) -> None:
Raises:
BraintrustConnectionError: If connection fails.
"""
if self._client is not None:
if self._http_client is not None:
return

try:
self._logger.info("Connecting to Braintrust API")

# Create HTTP client for auxiliary requests
# Create HTTP client for requests
self._http_client = httpx.AsyncClient(
timeout=httpx.Timeout(30.0),
limits=httpx.Limits(max_connections=20, max_keepalive_connections=5),
)

# Create Braintrust API client
self._client = AsyncBraintrust(
api_key=self.org_config.api_key,
base_url=str(self.org_config.url),
http_client=self._http_client,
)

# Perform health check
await self.health_check()

Expand All @@ -119,14 +110,6 @@ async def connect(self) -> None:

async def close(self) -> None:
"""Close the connection to Braintrust API."""
if self._client is not None:
try:
await self._client.close()
except Exception as e:
self._logger.warning("Error closing Braintrust client", error=str(e))
finally:
self._client = None

if self._http_client is not None:
try:
await self._http_client.aclose()
Expand All @@ -137,6 +120,83 @@ async def close(self) -> None:

self._logger.info("Closed connection to Braintrust API")

async def list_projects(
self,
*,
limit: int | None = None,
project_name: str | None = None,
org_name: str | None = None,
page_size: int = 100,
) -> list[dict[str, Any]]:
"""List projects in this organization.

Uses GET /v1/project and paginates via `starting_after`.

Args:
limit: Optional max number of projects to return (client-side cap).
project_name: Optional exact-name filter (server-side).
org_name: Optional org name filter (server-side).
page_size: Page size to request when paginating.
"""
effective_page_size = page_size
if limit is not None:
effective_page_size = max(1, min(page_size, limit))

projects: list[dict[str, Any]] = []
starting_after: str | None = None

while True:
params: dict[str, Any] = {"limit": effective_page_size}
if starting_after is not None:
params["starting_after"] = starting_after
if project_name is not None:
params["project_name"] = project_name
if org_name is not None:
params["org_name"] = org_name

resp = await self.raw_request("GET", "/v1/project", params=params)
if not isinstance(resp, dict):
raise BraintrustAPIError(f"Unexpected project list response: {type(resp)}")
objs = resp.get("objects")
if not isinstance(objs, list):
raise BraintrustAPIError(
f"Unexpected project list response shape: {resp!r}"
)

batch: list[dict[str, Any]] = []
for obj in objs:
if isinstance(obj, dict):
batch.append(obj)

projects.extend(batch)

if limit is not None and len(projects) >= limit:
return projects[:limit]

# No more pages.
if len(objs) < effective_page_size or not objs:
return projects

# Continue with pagination cursor.
last = batch[-1] if batch else None
last_id = last.get("id") if isinstance(last, dict) else None
if not isinstance(last_id, str) or not last_id:
return projects
starting_after = last_id

async def create_project(
self, *, name: str, description: str | None = None
) -> dict[str, Any]:
"""Create (or return existing) project by name via POST /v1/project."""
payload: dict[str, Any] = {"name": name}
if description:
payload["description"] = description

resp = await self.raw_request("POST", "/v1/project", json=payload)
if not isinstance(resp, dict):
raise BraintrustAPIError(f"Unexpected create project response: {type(resp)}")
return resp

async def raw_request(
self,
method: str,
Expand All @@ -148,9 +208,9 @@ async def raw_request(
) -> Any:
"""Perform a raw HTTP request against the Braintrust API.

This is useful for endpoints that are not ergonomically exposed by the
generated `braintrust-api` client, or when we need tight control over
request/response behavior (e.g. cursor-pagination for large logs).
This is useful when we need tight control over request/response behavior
(e.g. cursor-pagination for large logs) or want to avoid additional SDK
dependencies.

Args:
method: HTTP method (GET/POST/etc).
Expand Down Expand Up @@ -269,20 +329,6 @@ async def get_org_id(self) -> str:
f"{last_err}"
)

@property
def client(self) -> AsyncBraintrust:
"""Get the underlying Braintrust API client.

Returns:
The AsyncBraintrust client instance.

Raises:
BraintrustConnectionError: If not connected.
"""
if self._client is None:
raise BraintrustConnectionError(f"Not connected to {self.org_name}")
return self._client

async def health_check(self) -> dict[str, Any]:
"""Perform health check against the Braintrust API.

Expand All @@ -292,12 +338,12 @@ async def health_check(self) -> dict[str, Any]:
Raises:
BraintrustConnectionError: If health check fails.
"""
if self._client is None:
if self._http_client is None:
raise BraintrustConnectionError(f"Not connected to {self.org_name}")

try:
# Try to list projects as a health check
await self._client.projects.list(limit=1)
# List projects as a lightweight health check.
await self.list_projects(limit=1)

health_data = {
"status": "healthy",
Expand Down
100 changes: 49 additions & 51 deletions braintrust_migrate/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Any, ClassVar, TypeVar, cast

import structlog
from braintrust_api.types import Project

from braintrust_migrate.client import BraintrustClient, create_client_pair
from braintrust_migrate.config import Config
Expand Down Expand Up @@ -304,40 +303,45 @@ async def _discover_projects(
"""
self._logger.info("Discovering projects")

# List projects from source
source_projects = await source_client.with_retry(
"list_source_projects", lambda: source_client.client.projects.list()
# List projects from source (REST: GET /v1/project)
raw_projects = cast(
list[dict[str, Any]],
await source_client.with_retry(
"list_source_projects", lambda: source_client.list_projects()
),
)

projects = []

if source_projects is None:
projects = []
# Convert to list if it's an async iterator
elif hasattr(source_projects, "__aiter__"):
async for project in source_projects:
projects.append(project)
else:
projects = list(source_projects)
# Ensure we have basic required fields.
projects: list[dict[str, Any]] = []
for p in raw_projects:
pid = p.get("id")
name = p.get("name")
if isinstance(pid, str) and pid and isinstance(name, str) and name:
projects.append(p)
else:
self._logger.warning(
"Skipping malformed project record",
project_id=pid,
project_name=name,
)

# Filter projects if project_names is specified
if self.config.project_names:
filtered_projects = []
project_names_set = set(self.config.project_names)

for project in projects:
if project.name in project_names_set:
if project.get("name") in project_names_set:
filtered_projects.append(project)

# Log which projects were found and which were not
found_names = {project.name for project in filtered_projects}
found_names = {cast(str, project.get("name")) for project in filtered_projects}
missing_names = project_names_set - found_names

if missing_names:
self._logger.warning(
"Some specified projects were not found in source organization",
missing_projects=list(missing_names),
available_projects=[p.name for p in projects],
available_projects=[p.get("name") for p in projects],
)

self._logger.info(
Expand All @@ -356,18 +360,18 @@ async def _discover_projects(
dest_project_id = await self._ensure_project_exists(project, dest_client)
project_mappings.append(
{
"source_id": project.id,
"source_id": cast(str, project.get("id")),
"dest_id": dest_project_id,
"name": project.name,
"description": getattr(project, "description", None),
"name": cast(str, project.get("name")),
"description": project.get("description"),
}
)

return project_mappings

async def _ensure_project_exists(
self,
source_project: Project,
source_project: dict[str, Any],
dest_client: BraintrustClient,
) -> str:
"""Ensure a project exists in the destination organization.
Expand All @@ -381,56 +385,50 @@ async def _ensure_project_exists(
"""
try:
# Check if project already exists
dest_projects = await dest_client.with_retry(
"list_dest_projects", lambda: dest_client.client.projects.list()
dest_projects = cast(
list[dict[str, Any]],
await dest_client.with_retry(
"list_dest_projects", lambda: dest_client.list_projects()
),
)

# Convert to list and check if project exists
existing_project = None
if dest_projects is None:
existing_project = None
elif hasattr(dest_projects, "__aiter__"):
async for dest_project in dest_projects:
if dest_project.name == source_project.name:
existing_project = dest_project
break
else:
for dest_project in dest_projects:
if dest_project.name == source_project.name:
existing_project = dest_project
break
existing_project: dict[str, Any] | None = None
for dest_project in dest_projects:
if dest_project.get("name") == source_project.get("name"):
existing_project = dest_project
break

if existing_project:
self._logger.debug(
"Project already exists in destination",
project_name=source_project.name,
dest_id=existing_project.id,
project_name=source_project.get("name"),
dest_id=existing_project.get("id"),
)
return existing_project.id
return cast(str, existing_project.get("id"))

# Create project in destination
create_params = {"name": source_project.name}
description = cast(str | None, getattr(source_project, "description", None))
create_params = {"name": source_project.get("name")}
description = cast(str | None, source_project.get("description"))
if description:
create_params["description"] = description

new_project = cast(
Any,
dict[str, Any],
await dest_client.with_retry(
"create_project",
# braintrust-api client is dynamically generated; use Any to avoid type noise
lambda: cast(Any, dest_client.client.projects).create(
**create_params
lambda: dest_client.create_project(
name=cast(str, create_params["name"]),
description=cast(str | None, create_params.get("description")),
),
),
)

new_project_id = cast(str, new_project.id)
new_project_id = cast(str, new_project.get("id"))

self._logger.info(
"Created project in destination",
project_name=source_project.name,
source_id=source_project.id,
project_name=source_project.get("name"),
source_id=source_project.get("id"),
dest_id=new_project_id,
)

Expand All @@ -439,7 +437,7 @@ async def _ensure_project_exists(
except Exception as e:
self._logger.error(
"Failed to ensure project exists",
project_name=source_project.name,
project_name=source_project.get("name"),
error=str(e),
)
raise
Expand Down
Loading