diff --git a/.github/workflows/braintrust-container.yaml b/.github/workflows/braintrust-container.yaml new file mode 100644 index 0000000000..50c223e297 --- /dev/null +++ b/.github/workflows/braintrust-container.yaml @@ -0,0 +1,137 @@ +name: Braintrust Container + +on: + push: + branches: + - master + paths: + - 'scripts/scheduled/braintrust_*.py' + - 'requirements-braintrust.txt' + - 'build/braintrust/**' + - '.github/workflows/braintrust-container.yaml' + tags: + - 'v*' + pull_request: + paths: + - 'scripts/scheduled/braintrust_*.py' + - 'requirements-braintrust.txt' + - 'build/braintrust/**' + - '.github/workflows/braintrust-container.yaml' + workflow_dispatch: + +concurrency: + group: braintrust-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-dev: + if: ${{ github.event_name == 'pull_request' || (github.event_name == 'push' && !startsWith(github.ref, 'refs/tags/')) || github.event_name == 'workflow_dispatch' }} + name: "Braintrust Image Build (Dev)" + permissions: + contents: 'read' + id-token: 'write' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - id: auth + name: Authenticate to Google Cloud + uses: google-github-actions/auth@v2 + with: + token_format: 'access_token' + workload_identity_provider: 'projects/${{ secrets.DEV_GKE_PROJECT_ID}}/locations/global/workloadIdentityPools/github/providers/github' + service_account: '${{ secrets.DEV_GKE_SA }}' + - name: Login to GAR + uses: docker/login-action@v3 + with: + registry: us-east1-docker.pkg.dev + username: oauth2accesstoken + password: '${{ steps.auth.outputs.access_token }}' + - name: Get branch name + id: branch-raw + uses: tj-actions/branch-names@v5.1 + - name: Format branch name + id: branch-name + run: >- + echo "current_branch="$(echo ${{ steps.branch-raw.outputs.current_branch }} + | awk '{print tolower($0)}' + | sed 's|.*/\([^/]*\)/.*|\1|; t; s|.*|\0|' + | sed 's/[^a-z0-9\.\-]//g') + >> $GITHUB_OUTPUT + - name: Get current date + id: date + run: echo "date=$(date +'%Y%m%d%H%M')" >> $GITHUB_OUTPUT + - name: Generate image metadata + id: meta + uses: docker/metadata-action@v3 + with: + images: | + us-east1-docker.pkg.dev/${{ secrets.DEV_PROJECT }}/containers/sefaria-braintrust-${{ steps.branch-name.outputs.current_branch }} + tags: | + type=ref,event=branch + type=sha,enable=true,priority=100,prefix=sha-,suffix=-${{ steps.date.outputs.date }},format=short + type=sha + flavor: | + latest=true + - name: Build and push + uses: docker/build-push-action@v6 + with: + context: . 
+ push: true + file: ./build/braintrust/Dockerfile + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + + build-prod: + if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') }} + name: "Braintrust Image Build (Prod)" + permissions: + contents: 'read' + id-token: 'write' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - id: auth + name: Authenticate to Google Cloud + uses: google-github-actions/auth@v2 + with: + token_format: 'access_token' + workload_identity_provider: 'projects/${{ secrets.PROD_GKE_PROJECT_ID }}/locations/global/workloadIdentityPools/github/providers/github' + service_account: '${{ secrets.PROD_GKE_SA }}' + - name: Login to GAR + uses: docker/login-action@v3 + with: + registry: us-east1-docker.pkg.dev + username: oauth2accesstoken + password: '${{ steps.auth.outputs.access_token }}' + - name: Get current date + id: date + run: echo "date=$(date +'%Y%m%d%H%M')" >> $GITHUB_OUTPUT + - name: Generate image metadata + id: meta + uses: docker/metadata-action@v3 + with: + images: | + us-east1-docker.pkg.dev/${{ secrets.PROD_GKE_PROJECT }}/containers/${{ secrets.IMAGE_NAME }}-braintrust + tags: | + type=ref,event=tag + type=sha,enable=true,priority=100,prefix=sha-,suffix=-${{ steps.date.outputs.date }},format=short + type=sha + type=semver,pattern={{raw}} + flavor: | + latest=true + - name: Build and push + uses: docker/build-push-action@v6 + with: + context: . + push: true + file: ./build/braintrust/Dockerfile + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} diff --git a/build/braintrust/Dockerfile b/build/braintrust/Dockerfile new file mode 100644 index 0000000000..7feafb7f75 --- /dev/null +++ b/build/braintrust/Dockerfile @@ -0,0 +1,20 @@ +FROM python:3.11-slim + +LABEL org.opencontainers.image.source="https://github.com/Sefaria/Sefaria-Project" +LABEL org.opencontainers.image.description="Braintrust automation scripts for log backup and dataset tagging" + +COPY requirements-braintrust.txt /tmp/requirements.txt +RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt + +RUN useradd --create-home --shell /bin/bash --uid 1000 braintrust + +WORKDIR /app + +RUN mkdir -p /app/scripts /app/shared && chown -R braintrust:braintrust /app + +COPY scripts/scheduled/braintrust_backup_logs.py /app/scripts/braintrust_backup_logs.py +COPY scripts/scheduled/braintrust_tag_and_push.py /app/scripts/braintrust_tag_and_push.py + +USER braintrust + +ENTRYPOINT ["python"] diff --git a/build/ci/production-helm-deploy.sh b/build/ci/production-helm-deploy.sh index c34329fb2e..02ca750375 100755 --- a/build/ci/production-helm-deploy.sh +++ b/build/ci/production-helm-deploy.sh @@ -6,6 +6,7 @@ export WEB_IMAGE="us-east1-docker.pkg.dev/$PROJECT_ID/containers/$IMAGE_NAME-web export NODE_IMAGE="us-east1-docker.pkg.dev/$PROJECT_ID/containers/$IMAGE_NAME-node" export ASSET_IMAGE="us-east1-docker.pkg.dev/$PROJECT_ID/containers/$IMAGE_NAME-asset" export LINKER_IMAGE="us-east1-docker.pkg.dev/$PROJECT_ID/containers/$IMAGE_NAME-linker" +export BRAINTRUST_IMAGE="us-east1-docker.pkg.dev/$PROJECT_ID/containers/$IMAGE_NAME-braintrust" export TAG="$GIT_COMMIT" yq e -i '.web.containerImage.imageRegistry = strenv(WEB_IMAGE)' $1 @@ -18,6 +19,8 @@ yq e -i '.linker.containerImage.tag = strenv(TAG)' $1 yq e -i '.nodejs.containerImage.tag = strenv(TAG)' $1 yq e -i 
'.nginx.containerImage.tag = strenv(TAG)' $1 yq e -i '.monitor.containerImage.tag = strenv(TAG)' $1 +yq e -i '.cronJobs.braintrust.image.repository = strenv(BRAINTRUST_IMAGE)' $1 +yq e -i '.cronJobs.braintrust.image.tag = strenv(TAG)' $1 helm repo add sefaria-project https://sefaria.github.io/Sefaria-Project helm upgrade -i production sefaria-project/sefaria --version $CHART_VERSION --namespace $NAMESPACE -f $1 --debug --timeout=30m0s diff --git a/build/ci/production-values.yaml b/build/ci/production-values.yaml index b3439d5bf6..d676b389e5 100644 --- a/build/ci/production-values.yaml +++ b/build/ci/production-values.yaml @@ -193,6 +193,21 @@ cronJobs: enabled: true weeklyEmailNotifications: enabled: true + braintrust: + enabled: true + image: + repository: + tag: + backupLogs: + enabled: true + schedule: "0 1 * * 0" + serviceAccount: braintrust-backup-logs + bucket: braintrust-logs + prefix: "logs/" + tagAndPush: + enabled: true + schedule: "0 2 * * *" + serviceAccount: braintrust-tag-push secrets: localSettings: ref: local-settings-secrets-production @@ -200,6 +215,10 @@ secrets: ref: backup-manager-secret-production slackWebhook: ref: slack-webhook-production + braintrust: + ref: braintrust-secret-production + anthropic: + ref: anthropic-api-key-production instrumentation: enabled: false otelEndpoint: "http://otel-collector-collector.monitoring:4317" diff --git a/helm-chart/sefaria/templates/cronjob/braintrust-backup-logs.yaml b/helm-chart/sefaria/templates/cronjob/braintrust-backup-logs.yaml new file mode 100644 index 0000000000..b101ca7659 --- /dev/null +++ b/helm-chart/sefaria/templates/cronjob/braintrust-backup-logs.yaml @@ -0,0 +1,92 @@ +{{- if and .Values.cronJobs.braintrust.enabled .Values.cronJobs.braintrust.backupLogs.enabled }} +--- +apiVersion: batch/v1 +kind: CronJob +metadata: + name: {{ .Values.deployEnv }}-braintrust-backup-logs + labels: + {{- include "sefaria.labels" . 
| nindent 4 }} +spec: + schedule: "{{ .Values.cronJobs.braintrust.backupLogs.schedule }}" + concurrencyPolicy: Forbid + jobTemplate: + spec: + backoffLimit: 1 + template: + spec: + serviceAccount: {{ .Values.cronJobs.braintrust.backupLogs.serviceAccount }} + initContainers: + # Init container: Query Braintrust logs and create CSV + - name: braintrust-log-exporter + image: "{{ .Values.cronJobs.braintrust.image.repository }}:{{ .Values.cronJobs.braintrust.image.tag }}" + env: + - name: BRAINTRUST_API_KEY + valueFrom: + secretKeyRef: + name: {{ .Values.secrets.braintrust.ref }} + key: api-key + - name: BRAINTRUST_PROJECT_ID + valueFrom: + secretKeyRef: + name: {{ .Values.secrets.braintrust.ref }} + key: project-id + volumeMounts: + - mountPath: /tmp + name: shared-volume + command: ["python"] + args: ["/app/scripts/braintrust_backup_logs.py"] + resources: + requests: + memory: "256Mi" + cpu: "250m" + limits: + memory: "500Mi" + cpu: "1000m" + containers: + # Main container: Upload CSV to GCS bucket + - name: braintrust-log-uploader + image: google/cloud-sdk + volumeMounts: + - mountPath: /tmp + name: shared-volume + env: + - name: BUCKET + value: {{ .Values.cronJobs.braintrust.backupLogs.bucket }} + - name: PREFIX + value: {{ .Values.cronJobs.braintrust.backupLogs.prefix }} + command: ["bash"] + args: + - "-c" + - | + set -e + + # Find the most recent CSV file + CSV_FILE=$(ls -t /tmp/logs_backup_*.csv 2>/dev/null | head -1) + + if [ -z "$CSV_FILE" ]; then + echo "No CSV file found in /tmp" + exit 0 + fi + + FILENAME=$(basename "$CSV_FILE") + DESTINATION="gs://${BUCKET}/${PREFIX}${FILENAME}" + + echo "Uploading $CSV_FILE to $DESTINATION" + gsutil cp "$CSV_FILE" "$DESTINATION" + echo "Upload complete" + + # Cleanup + rm -f "$CSV_FILE" + resources: + requests: + memory: "256Mi" + cpu: "100m" + limits: + memory: "500Mi" + restartPolicy: OnFailure + volumes: + - name: shared-volume + emptyDir: {} + successfulJobsHistoryLimit: 1 + failedJobsHistoryLimit: 2 +{{- end }} diff --git a/helm-chart/sefaria/templates/cronjob/braintrust-tag-and-push.yaml b/helm-chart/sefaria/templates/cronjob/braintrust-tag-and-push.yaml new file mode 100644 index 0000000000..f2b303d2ca --- /dev/null +++ b/helm-chart/sefaria/templates/cronjob/braintrust-tag-and-push.yaml @@ -0,0 +1,64 @@ +{{- if and .Values.cronJobs.braintrust.enabled .Values.cronJobs.braintrust.tagAndPush.enabled }} +--- +apiVersion: batch/v1 +kind: CronJob +metadata: + name: {{ .Values.deployEnv }}-braintrust-tag-and-push + labels: + {{- include "sefaria.labels" . 
| nindent 4 }} +spec: + schedule: "{{ .Values.cronJobs.braintrust.tagAndPush.schedule }}" + concurrencyPolicy: Forbid + jobTemplate: + spec: + backoffLimit: 1 + template: + spec: + serviceAccount: {{ .Values.cronJobs.braintrust.tagAndPush.serviceAccount }} + securityContext: + fsGroup: 1000 + containers: + - name: braintrust-tag-and-push + image: "{{ .Values.cronJobs.braintrust.image.repository }}:{{ .Values.cronJobs.braintrust.image.tag }}" + env: + - name: BRAINTRUST_API_KEY + valueFrom: + secretKeyRef: + name: {{ .Values.secrets.braintrust.ref }} + key: api-key + - name: BRAINTRUST_PROJECT_ID + valueFrom: + secretKeyRef: + name: {{ .Values.secrets.braintrust.ref }} + key: project-id + - name: ANTHROPIC_API_KEY + valueFrom: + secretKeyRef: + name: {{ .Values.secrets.anthropic.ref }} + key: api-key + - name: BRAINTRUST_SHARED_STORAGE + value: "/shared/braintrust" + volumeMounts: + - mountPath: /shared/braintrust + name: shared-storage + command: ["python"] + args: ["/app/scripts/braintrust_tag_and_push.py", "all"] + resources: + limits: + memory: "3Gi" + cpu: "2000m" + requests: + memory: "1Gi" + cpu: "500m" + restartPolicy: OnFailure + volumes: + - name: shared-storage + {{- if .Values.cronJobs.braintrust.tagAndPush.usePvc }} + persistentVolumeClaim: + claimName: {{ .Values.cronJobs.braintrust.tagAndPush.pvcName }} + {{- else }} + emptyDir: {} + {{- end }} + successfulJobsHistoryLimit: 1 + failedJobsHistoryLimit: 2 +{{- end }} diff --git a/helm-chart/sefaria/values.yaml b/helm-chart/sefaria/values.yaml index dd2fcd9af8..28d1c12286 100644 --- a/helm-chart/sefaria/values.yaml +++ b/helm-chart/sefaria/values.yaml @@ -386,6 +386,17 @@ secrets: # should be commented out and vice-versa. ref: elastic-admin # data: + braintrust: + # Braintrust API credentials (api-key, project-id) + ref: braintrust-secret + # data: + # api-key: "" + # project-id: "" + anthropic: + # Anthropic API key for Claude tagging + ref: anthropic-api-key + # data: + # api-key: "" # Settings for various cronjobs cronJobs: @@ -418,6 +429,23 @@ cronJobs: enabled: false syncMongoProductionData: enabled: false + braintrust: + enabled: false + image: + repository: us-east1-docker.pkg.dev/production-deployment/containers/sefaria-braintrust + tag: latest + backupLogs: + enabled: false + schedule: "0 1 * * 0" + serviceAccount: braintrust-backup-logs + bucket: braintrust-logs + prefix: "logs/" + tagAndPush: + enabled: false + schedule: "0 2 * * *" + serviceAccount: braintrust-tag-push + usePvc: false + pvcName: braintrust-shared-storage localSettings: DEBUG: true diff --git a/requirements-braintrust.txt b/requirements-braintrust.txt new file mode 100644 index 0000000000..35e1a86195 --- /dev/null +++ b/requirements-braintrust.txt @@ -0,0 +1,6 @@ +braintrust==0.15.0 +langchain-anthropic==0.3.22 +requests>=2.31.0 +structlog>=23.2.0 +google-cloud-logging>=3.5.0 +google-cloud-storage>=2.10.0 diff --git a/scripts/scheduled/braintrust_backup_logs.py b/scripts/scheduled/braintrust_backup_logs.py new file mode 100644 index 0000000000..288ca0a72d --- /dev/null +++ b/scripts/scheduled/braintrust_backup_logs.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +""" +Task 1: Backup logs from Braintrust last 7 days to CSV in GCS bucket "braintrust-logs". + +This script is called from the init container of the braintrust-backup-logs cronjob. +The CSV is created in /tmp and will be uploaded by the main container. + +Run weekly (e.g., Sundays). 
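+
+Requires BRAINTRUST_API_KEY and BRAINTRUST_PROJECT_ID in the environment; the
+output file is named /tmp/logs_backup_<YYYY-MM-DD>.csv, which the uploader
+container picks up via the glob logs_backup_*.csv.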
+""" +import sys +import os +import csv +import re +from datetime import datetime, timedelta, timezone + +import structlog +import requests + +logger = structlog.get_logger(__name__) + + +def get_braintrust_api_key(): + """Get Braintrust API key from environment.""" + api_key = os.getenv("BRAINTRUST_API_KEY") + if not api_key: + raise RuntimeError("BRAINTRUST_API_KEY environment variable is required") + return api_key + + +def query_braintrust_logs(days=7): + """ + Query logs from Braintrust using BTQL API. + + Args: + days: Number of days back to retrieve + + Returns: + List of log dicts + """ + logger.info("querying_braintrust_logs", days=days) + + api_key = get_braintrust_api_key() + project_id = os.getenv("BRAINTRUST_PROJECT_ID", "") + + if not project_id: + raise RuntimeError("BRAINTRUST_PROJECT_ID environment variable is required") + + # Validate project_id format (UUID) to prevent BTQL injection + uuid_pattern = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE) + if not uuid_pattern.match(project_id): + raise RuntimeError(f"BRAINTRUST_PROJECT_ID must be a valid UUID, got: {project_id!r}") + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + # Calculate date range + days_ago = (datetime.now(timezone.utc) - timedelta(days=days)).isoformat() + + # SQL query to get logs from last N days + query = f""" +SELECT * +FROM project_logs('{project_id}', shape => 'traces') +WHERE created >= '{days_ago}' +""" + + try: + response = requests.post( + "https://api.braintrust.dev/btql", + headers=headers, + json={"query": query, "fmt": "json"}, + timeout=60 + ) + response.raise_for_status() + + data = response.json() + logs = data.get("results", []) + + logger.info("braintrust_logs_fetched", count=len(logs)) + return logs + + except requests.exceptions.RequestException as e: + logger.error("query_braintrust_logs_failed", error=str(e), exc_info=True) + raise + + +def logs_to_csv(logs, filepath): + """ + Convert log entries to CSV format. 
+ + Args: + logs: List of log entry dicts + filepath: Path to write CSV file to + """ + if not logs: + logger.warning("no_logs_to_export") + return False + + logger.info("converting_to_csv", count=len(logs), filepath=filepath) + + # Get all unique keys from logs to use as CSV headers + fieldnames = set() + for log in logs: + if isinstance(log, dict): + fieldnames.update(log.keys()) + fieldnames = sorted(list(fieldnames)) + + try: + with open(filepath, 'w', newline='', encoding='utf-8') as f: + writer = csv.DictWriter(f, fieldnames=fieldnames, restval='') + writer.writeheader() + for log in logs: + if isinstance(log, dict): + # Flatten nested dicts to strings + flat_log = {} + for key, val in log.items(): + if isinstance(val, (dict, list)): + flat_log[key] = str(val) + else: + flat_log[key] = val + writer.writerow(flat_log) + + logger.info("csv_created", filepath=filepath, rows=len(logs)) + return True + + except Exception as e: + logger.error("csv_creation_failed", error=str(e), exc_info=True) + raise + + +def main(): + logger.info("starting_braintrust_backup_logs") + + try: + # Step 1: Query Braintrust logs from last 7 days + logs = query_braintrust_logs(days=7) + + if not logs: + logger.warning("no_logs_retrieved") + sys.exit(0) # Don't fail if no logs + + # Step 2: Create CSV file in /tmp (will be uploaded by main container) + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d") + csv_filename = f"/tmp/logs_backup_{timestamp}.csv" + + if logs_to_csv(logs, csv_filename): + logger.info("completed_braintrust_backup_logs", csv_file=csv_filename) + else: + logger.warning("no_csv_created") + sys.exit(0) + + except Exception as e: + logger.error("braintrust_backup_logs_failed", error=str(e), exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/scheduled/braintrust_tag_and_push.py b/scripts/scheduled/braintrust_tag_and_push.py new file mode 100644 index 0000000000..9713e3549d --- /dev/null +++ b/scripts/scheduled/braintrust_tag_and_push.py @@ -0,0 +1,769 @@ +# -*- coding: utf-8 -*- +""" +Braintrust automation: Tag ALL logs with Claude, then push to relevant datasets. + +CORRECT FLOW: +1. Retrieve ALL tags from Braintrust +2. Filter tags: keep only those with "dataset-tagging" in their DESCRIPTION +3. Retrieve ALL logs from last 24 hours (NO filtering) +4. Tag ALL logs using Claude, constrained to use only the filtered tags +5. Save tagged logs to shared storage +6. Retrieve ALL datasets from Braintrust +7. Filter datasets: keep only those with [[relevant_tags: ["a","b"]]] in their DESCRIPTION +8. Match logs to datasets based on relevant_tags and insert (with deduplication) + +Run daily at 2 AM. 
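+
+For example, a dataset whose description contains (tag names illustrative):
+    [[relevant_tags: ["tag1", "tag2"]]]
+receives every log that Claude tagged with "tag1" or "tag2" during step 4.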
+""" +import sys +import os +import json +import re +import time +from datetime import datetime, timedelta, timezone + +import structlog +import requests +from langchain_anthropic import ChatAnthropic +import braintrust + +logger = structlog.get_logger(__name__) + +# Constant filter for dataset tagging tags +DATASET_TAGGING_FILTER = "dataset-tagging" + +# Configurable limits for Claude tagging +MAX_LOGS_PER_RUN = int(os.getenv("BRAINTRUST_MAX_LOGS", "500")) +TAGGING_DELAY_SECONDS = float(os.getenv("BRAINTRUST_TAGGING_DELAY", "0.5")) +MAX_TAGS_IN_PROMPT = int(os.getenv("BRAINTRUST_MAX_TAGS_IN_PROMPT", "50")) + +# Shared storage path (from environment variable) +SHARED_STORAGE_PATH = os.getenv("BRAINTRUST_SHARED_STORAGE", "/shared/braintrust") +TAGGED_LOGS_FILE = os.path.join(SHARED_STORAGE_PATH, "tagged_logs.jsonl") + + +def get_braintrust_api_key(): + """Get Braintrust API key from environment.""" + api_key = os.getenv("BRAINTRUST_API_KEY") + if not api_key: + raise RuntimeError("BRAINTRUST_API_KEY environment variable is required") + return api_key + + +def get_anthropic_api_key(): + """Get Anthropic API key from environment.""" + api_key = os.getenv("ANTHROPIC_API_KEY") + if not api_key: + raise RuntimeError("ANTHROPIC_API_KEY environment variable is required") + return api_key + + +def fetch_and_filter_tags(): + """ + Step 1: Fetch ALL tags, then filter for "dataset-tagging". + + IMPORTANT: We look at tag.description and keep only tags whose description + contains "dataset-tagging". These filtered tag names are the ones Claude + will be allowed to assign to logs. + + Returns: + List of tag names (strings) whose description contains "dataset-tagging" + """ + logger.info("fetching_and_filtering_tags") + + api_key = get_braintrust_api_key() + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + try: + response = requests.get( + "https://api.braintrust.dev/v1/project_tag", + headers=headers, + timeout=30 + ) + response.raise_for_status() + + data = response.json() + all_tags = data.get("objects", []) + + # Filter: Check if tag.description contains "dataset-tagging" + # Return: The tag names (these will be used as available tags for Claude) + filtered_tags = [ + tag["name"] for tag in all_tags + if tag.get("description", "") and DATASET_TAGGING_FILTER.lower() in tag.get("description", "").lower() + ] + + logger.info( + "tags_filtered", + total_tags=len(all_tags), + filtered_tags_count=len(filtered_tags), + filtered_tag_names=filtered_tags + ) + + return filtered_tags + + except requests.exceptions.RequestException as e: + logger.error("fetch_tags_failed", error=str(e), exc_info=True) + raise + + +def query_all_logs(hours=24): + """ + Step 2: Query ALL logs from the last N hours (NO filtering by tags). + + We get all logs because we will tag them all with Claude. 
+ + Args: + hours: Number of hours back to retrieve + + Returns: + List of log dicts + """ + logger.info("querying_all_logs", hours=hours) + + api_key = get_braintrust_api_key() + project_id = os.getenv("BRAINTRUST_PROJECT_ID", "") + + if not project_id: + raise RuntimeError("BRAINTRUST_PROJECT_ID environment variable is required") + + # Validate project_id format (UUID) to prevent BTQL injection + uuid_pattern = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE) + if not uuid_pattern.match(project_id): + raise RuntimeError(f"BRAINTRUST_PROJECT_ID must be a valid UUID, got: {project_id!r}") + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + # Calculate time range + hours_ago = (datetime.now(timezone.utc) - timedelta(hours=hours)).isoformat() + + # SQL query to get ALL logs from last N hours (NO tag filter) + query = f""" +SELECT * +FROM project_logs('{project_id}', shape => 'traces') +WHERE created >= '{hours_ago}' +""" + + try: + response = requests.post( + "https://api.braintrust.dev/btql", + headers=headers, + json={"query": query, "fmt": "json"}, + timeout=60 + ) + response.raise_for_status() + + data = response.json() + logs = data.get("results", []) + + logger.info("all_logs_fetched", count=len(logs)) + return logs + + except requests.exceptions.RequestException as e: + logger.error("query_logs_failed", error=str(e), exc_info=True) + raise + + +def get_claude_client(): + """Initialize Claude client from environment.""" + api_key = get_anthropic_api_key() + + return ChatAnthropic( + model="claude-3-5-haiku-20241022", + temperature=0, + max_tokens=256, + api_key=api_key + ) + + +def tag_log_with_claude(client, log_entry, available_tags): + """ + Step 3: Use Claude to assign relevant tags from available_tags to a log. + + Args: + client: ChatAnthropic client + log_entry: Dict with log data + available_tags: List of valid tag names to choose from (filtered to "dataset-tagging" tags) + + Returns: + List of relevant tags (strings) + """ + message = str(log_entry.get("input", log_entry.get("message", "")))[:500] + output = str(log_entry.get("output", ""))[:500] + log_id = log_entry.get("id", "") + + prompt_tags = available_tags[:MAX_TAGS_IN_PROMPT] + if len(available_tags) > MAX_TAGS_IN_PROMPT: + logger.warning("tags_truncated_for_prompt", total=len(available_tags), used=MAX_TAGS_IN_PROMPT) + tags_str = ", ".join(prompt_tags) + + prompt = f"""Analyze this log entry and assign relevant tags that categorize it. +Select from ONLY these available tags: {tags_str} + +You may select 1-3 tags. If none of the available tags are appropriate, return an empty array. 
+ +Log ID: {log_id} +Input: {message} +Output: {output} + +Return ONLY a JSON array of tags, like: ["tag1", "tag2"] or [] +""" + + try: + response = client.invoke(prompt) + response_text = response.content.strip() + + # Parse JSON response + try: + tags = json.loads(response_text) + except json.JSONDecodeError as e: + logger.warning("invalid_claude_json", error=str(e), log_id=log_id) + return [] + + if isinstance(tags, list): + # Validate that returned tags are in available_tags + return [str(t).strip() for t in tags if str(t).strip() in available_tags] + else: + logger.warning("invalid_claude_response_type", response_type=type(tags).__name__, log_id=log_id) + return [] + + except Exception as e: + logger.error("claude_tagging_error", error=str(e), log_id=log_id) + return [] + + +def tag_all_logs(logs, available_tags): + """ + Step 4: Tag ALL logs using Claude with the filtered available tags. + + Args: + logs: List of ALL log entries + available_tags: List of valid tag names (filtered to "dataset-tagging" tags) + + Returns: + List of logs with 'relevant_tags' field added + """ + if not logs: + logger.warning("no_logs_to_tag") + return [] + + if len(logs) > MAX_LOGS_PER_RUN: + logger.warning("logs_capped", total=len(logs), cap=MAX_LOGS_PER_RUN) + logs = logs[:MAX_LOGS_PER_RUN] + + logger.info("tagging_all_logs", total_logs=len(logs), available_tags_count=len(available_tags)) + client = get_claude_client() + tagged_logs = [] + + for idx, log in enumerate(logs): + tags = tag_log_with_claude(client, log, available_tags) + log["relevant_tags"] = tags + tagged_logs.append(log) + + if (idx + 1) % 10 == 0: + logger.info("tagging_progress", processed=idx + 1, total=len(logs)) + + if TAGGING_DELAY_SECONDS > 0: + time.sleep(TAGGING_DELAY_SECONDS) + + logger.info("completed_tagging", total_logs=len(tagged_logs)) + return tagged_logs + + +def save_tagged_logs(tagged_logs): + """ + Step 5: Save tagged logs to shared storage. + + Args: + tagged_logs: List of tagged log dicts + """ + if not tagged_logs: + logger.warning("no_tagged_logs_to_save") + return + + os.makedirs(SHARED_STORAGE_PATH, exist_ok=True) + + try: + with open(TAGGED_LOGS_FILE, 'w', encoding='utf-8') as f: + for log in tagged_logs: + f.write(json.dumps(log) + '\n') + + logger.info("saved_tagged_logs", file=TAGGED_LOGS_FILE, count=len(tagged_logs)) + + except Exception as e: + logger.error("save_failed", error=str(e), exc_info=True) + raise + + +def init_step(): + """ + Init step: Steps 1-5: Filter tags, query all logs, tag them, save. 
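+
+    Tagged logs are written as JSONL to TAGGED_LOGS_FILE under
+    BRAINTRUST_SHARED_STORAGE so the push step can pick them up.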
+ """ + logger.info("starting_init_step") + + try: + # Step 1: Filter tags to those with "dataset-tagging" in description + available_tags = fetch_and_filter_tags() + + if not available_tags: + logger.warning("no_dataset_tagging_tags_found") + return + + # Step 2: Query ALL logs from last 24 hours (no filtering) + logs = query_all_logs(hours=24) + + if not logs: + logger.warning("no_logs_retrieved") + return + + # Remove duplicates by log ID + unique_logs = {} + logs_without_id = 0 + for log in logs: + log_id = log.get("id") + if log_id: + unique_logs[log_id] = log + else: + logs_without_id += 1 + if logs_without_id > 0: + logger.warning("logs_missing_id", count=logs_without_id) + logs = list(unique_logs.values()) + + # Step 3-4: Tag all logs with Claude using filtered tags + tagged_logs = tag_all_logs(logs, available_tags) + + # Step 5: Save tagged logs to shared storage + save_tagged_logs(tagged_logs) + + logger.info("completed_init_step", total_tagged=len(tagged_logs)) + + except Exception as e: + logger.error("init_step_failed", error=str(e), exc_info=True) + raise + + +def load_tagged_logs(): + """ + Load tagged logs from shared storage. + + Returns: + List of tagged log dicts + """ + if not os.path.exists(TAGGED_LOGS_FILE): + logger.warning("no_tagged_logs_file", file=TAGGED_LOGS_FILE) + return [] + + logger.info("loading_tagged_logs", file=TAGGED_LOGS_FILE) + logs = [] + + try: + skipped_lines = 0 + with open(TAGGED_LOGS_FILE, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + logs.append(json.loads(line)) + except json.JSONDecodeError as e: + skipped_lines += 1 + logger.warning("skipping_corrupted_jsonl_line", line_num=line_num, error=str(e)) + + if skipped_lines > 0: + logger.warning("jsonl_lines_skipped", skipped=skipped_lines, loaded=len(logs)) + + logger.info("loaded_tagged_logs", count=len(logs)) + return logs + + except Exception as e: + logger.error("load_tagged_logs_failed", error=str(e), exc_info=True) + raise + + +def fetch_all_datasets(): + """ + Step 6: Fetch ALL datasets from Braintrust API. + + Returns: + List of dataset dicts with metadata + """ + logger.info("fetching_all_datasets") + + api_key = get_braintrust_api_key() + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + try: + response = requests.get( + "https://api.braintrust.dev/v1/dataset", + headers=headers, + timeout=30 + ) + response.raise_for_status() + + data = response.json() + datasets = data.get("objects", []) + + logger.info("all_datasets_fetched", count=len(datasets)) + return datasets + + except requests.exceptions.RequestException as e: + logger.error("fetch_datasets_failed", error=str(e), exc_info=True) + raise + + +def extract_relevant_tags_from_description(description): + """ + Extract relevant_tags from dataset description. + + IMPORTANT: Datasets specify which tags they accept using this pattern in their description: + [[relevant_tags: ["tag1", "tag2", "tag3"]]] + + This method parses that pattern and returns the tags. 
+ + Args: + description: Dataset description string + + Returns: + Set of tag names, or empty set if pattern not found + """ + if not description: + return set() + + try: + # Find pattern [[relevant_tags: ["tag1", "tag2"]]] + # More robust pattern: match up to first ]] to avoid nested bracket issues + pattern = r'\[\[relevant_tags:\s*\[([^\]]*)\]\s*\]\]' + match = re.search(pattern, description) + + if not match: + return set() + + # Extract tags from JSON array + tags_str = "[" + match.group(1) + "]" + try: + tags = json.loads(tags_str) + except json.JSONDecodeError as e: + logger.warning("extract_tags_json_error", description=description[:100], error=str(e)) + return set() + + if not isinstance(tags, list): + logger.warning("extract_tags_not_list", description=description[:100], tag_type=type(tags).__name__) + return set() + + return set(str(tag).strip() for tag in tags if tag) + + except Exception as e: + logger.warning("extract_tags_error", description=description[:100], error=str(e)) + return set() + + +def filter_datasets_by_relevant_tags(datasets): + """ + Step 6b: Filter datasets that have [[relevant_tags: [...]]] in their DESCRIPTION. + + Returns: + Dict mapping dataset_id -> {dataset_obj, relevant_tags} + """ + logger.info("filtering_datasets_by_relevant_tags") + + filtered = {} + + for dataset in datasets: + # Look at dataset.description and extract the [[relevant_tags: [...]]] pattern + relevant_tags = extract_relevant_tags_from_description(dataset.get("description", "")) + + if relevant_tags: + filtered[dataset["id"]] = { + "dataset": dataset, + "relevant_tags": relevant_tags + } + + logger.info("datasets_filtered", total=len(datasets), filtered_count=len(filtered)) + return filtered + + +def optimize_matching_order(logs_count, datasets_count): + """ + Determine which collection to iterate first when matching logs to datasets. + + Iterates the smaller collection as the outer loop to reduce total comparisons. + If fewer logs, iterate logs and find matching datasets. + If fewer datasets, iterate datasets and find matching logs. + + Args: + logs_count: Number of logs + datasets_count: Number of datasets + + Returns: + "logs_first" or "datasets_first" + """ + if logs_count <= datasets_count: + return "logs_first" + else: + return "datasets_first" + + +def match_logs_to_datasets(logs, filtered_datasets): + """ + Match logs to datasets based on relevant_tags. + + Optimized matching: iterate through smaller set first. 
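+
+    A dataset matches a log when their tag sets intersect, so a single log may
+    be pushed to several datasets.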
+ + Args: + logs: List of tagged logs + filtered_datasets: Dict of dataset_id -> {dataset_obj, relevant_tags} + + Returns: + Dict mapping dataset_id -> list of logs to insert + """ + logger.info("matching_logs_to_datasets", logs_count=len(logs), datasets_count=len(filtered_datasets)) + + # Choose optimal iteration order + strategy = optimize_matching_order(len(logs), len(filtered_datasets)) + + matches = {ds_id: [] for ds_id in filtered_datasets.keys()} + + if strategy == "logs_first": + # Iterate logs, find matching datasets + logger.info("using_logs_first_strategy") + for log in logs: + log_tags = set(log.get("relevant_tags", [])) + + for ds_id, ds_info in filtered_datasets.items(): + dataset_tags = ds_info["relevant_tags"] + + # Check if any log tags match dataset tags (set intersection) + if log_tags & dataset_tags: + matches[ds_id].append(log) + + else: + # Iterate datasets, find matching logs (more efficient) + logger.info("using_datasets_first_strategy") + log_tag_map = {} # Map: tag -> list of logs with that tag + for log in logs: + for tag in log.get("relevant_tags", []): + if tag not in log_tag_map: + log_tag_map[tag] = [] + log_tag_map[tag].append(log) + + for ds_id, ds_info in filtered_datasets.items(): + dataset_tags = ds_info["relevant_tags"] + + # Collect all logs that match any dataset tag (deduplicate by object identity) + seen = set() + matching_logs = [] + for tag in dataset_tags: + if tag in log_tag_map: + for log in log_tag_map[tag]: + if id(log) not in seen: + seen.add(id(log)) + matching_logs.append(log) + + matches[ds_id] = matching_logs + + logger.info("matching_complete", total_matches=sum(len(v) for v in matches.values())) + return matches + + +def get_existing_log_ids_in_dataset(dataset): + """ + Query Braintrust dataset to get IDs of logs already inserted. + + Args: + dataset: Braintrust dataset instance + + Returns: + Set of log IDs that already exist in the dataset + """ + try: + existing_ids = set() + + for row in dataset: + log_id = row.get("id") or (row.get("input", {}).get("id") if isinstance(row.get("input"), dict) else None) + if log_id: + existing_ids.add(str(log_id)) + + return existing_ids + + except Exception as e: + logger.warning("query_dataset_logs_error", error=str(e)) + return set() + + +def push_logs_to_dataset(dataset, logs): + """ + Push logs to a single dataset, deduplicating against existing records. 
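+
+    Logs whose id is already present in the dataset, or that have no id at all,
+    are counted as skipped rather than inserted.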
+ + Args: + dataset: Braintrust dataset instance + logs: List of logs to insert + + Returns: + (inserted_count, skipped_count) + + Raises: + RuntimeError: If insertion failures exceed 10% of logs + """ + if not logs: + return 0, 0 + + inserted_count = 0 + skipped_count = 0 + failed_count = 0 + failed_ids = [] + + existing_ids = get_existing_log_ids_in_dataset(dataset) + + for log in logs: + log_id = str(log.get("id", "")) + + if not log_id: + logger.warning("skipping_log_without_id", log_keys=list(log.keys())[:5]) + skipped_count += 1 + continue + + if log_id in existing_ids: + skipped_count += 1 + continue + + try: + dataset.insert( + input=log, + expected=None, + metadata={ + "relevant_tags": log.get("relevant_tags", []), + "timestamp": log.get("created", ""), + } + ) + inserted_count += 1 + + except Exception as e: + failed_count += 1 + failed_ids.append(log_id) + logger.error("insert_log_failed", log_id=log_id, error=str(e)) + + # Check for excessive failures + total_attempted = inserted_count + failed_count + if total_attempted > 0: + failure_rate = failed_count / total_attempted + if failure_rate > 0.1: # More than 10% failure rate + raise RuntimeError( + f"Insertion failure rate ({failure_rate:.1%}) exceeds threshold. " + f"Failed logs: {failed_ids[:10]}" + ) + + return inserted_count, skipped_count + + +def push_step(): + """ + Push step: Steps 6-8: Load logs, fetch datasets, match, and insert. + """ + logger.info("starting_push_step") + + try: + # Load tagged logs from shared storage + logs = load_tagged_logs() + + if not logs: + logger.warning("no_tagged_logs_to_push") + return + + # Step 6: Fetch all datasets + datasets = fetch_all_datasets() + + # Step 6b: Filter datasets by [[relevant_tags: [...]]] pattern in description + filtered_datasets = filter_datasets_by_relevant_tags(datasets) + + if not filtered_datasets: + logger.warning("no_datasets_with_relevant_tags") + return + + # Step 7: Match logs to datasets based on relevant_tags + matches = match_logs_to_datasets(logs, filtered_datasets) + + # Step 8: Insert logs to each dataset (with deduplication) + total_inserted = 0 + total_skipped = 0 + failed_datasets = [] + + for ds_id, ds_info in filtered_datasets.items(): + if ds_id not in matches or not matches[ds_id]: + continue + + dataset_obj = ds_info["dataset"] + logs_for_dataset = matches[ds_id] + + logger.info("pushing_to_dataset", dataset_id=ds_id, dataset_name=dataset_obj.get("name"), logs_count=len(logs_for_dataset)) + + try: + # Validate required dataset identifiers + project_name = dataset_obj.get("project_name") + dataset_name = dataset_obj.get("name") + if not project_name or not str(project_name).strip(): + raise ValueError(f"Missing or empty project_name for dataset_id={ds_id}") + if not dataset_name or not str(dataset_name).strip(): + raise ValueError(f"Missing or empty dataset name for dataset_id={ds_id}") + + # Initialize Braintrust dataset + dataset = braintrust.init_dataset( + project=str(project_name).strip(), + name=str(dataset_name).strip() + ) + + inserted, skipped = push_logs_to_dataset(dataset, logs_for_dataset) + total_inserted += inserted + total_skipped += skipped + + logger.info("dataset_push_complete", dataset_id=ds_id, inserted=inserted, skipped=skipped) + + except Exception as e: + failed_datasets.append(ds_id) + logger.error("dataset_push_failed", dataset_id=ds_id, error=str(e)) + + logger.info("completed_push_step", total_inserted=total_inserted, total_skipped=total_skipped) + + if failed_datasets: + raise RuntimeError(f"Push failed for 
{len(failed_datasets)} dataset(s): {failed_datasets}") + + except Exception as e: + logger.error("push_step_failed", error=str(e), exc_info=True) + raise + + +def main(): + if len(sys.argv) < 2: + print("Usage: braintrust_tag_and_push.py [init|push|all]") + print(" init - Init step: filter tags, query all logs, tag them, save") + print(" push - Push step: load logs, fetch datasets, match, and insert") + print(" all - Run both steps sequentially") + sys.exit(1) + + command = sys.argv[1].lower() + + try: + if command == "init": + init_step() + elif command == "push": + push_step() + elif command == "all": + init_step() + push_step() + else: + print(f"Unknown command: {command}") + sys.exit(1) + + except Exception as e: + logger.error("execution_failed", error=str(e), exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + main()