diff --git a/design/read-query-api.md b/design/read-query-api.md new file mode 100644 index 0000000..824e499 --- /dev/null +++ b/design/read-query-api.md @@ -0,0 +1,444 @@ +# Design: Pluto Read/Query API for Neptune Migration + +**Date:** 2025-02-17 +**Status:** Proposal +**Author:** Andrew (Trainy) + +## Context + +A customer migrating from Neptune to Pluto needs programmatic read/query access to runs. Their workflow: + +1. **Pull a run by ID** — after a training job completes, a separate eval pipeline opens the run +2. **Fetch metadata** — read config values (checkpoint paths, hyperparameters) +3. **Run async evaluation** — using the fetched checkpoint +4. **Append results to the same run** — upload eval metrics + plots at the checkpoint's step + +The Pluto Python SDK (`pluto-ml`) is currently **write-only**. However, the Pluto server already exposes a full set of **REST read endpoints** at `/api/runs/*` with API-key authentication. These endpoints power the MCP server and the OpenAPI spec. We just need a Python client that wraps them. 
+ +## Goals + +- Provide Neptune read/query parity sufficient for migration +- Thin wrapper over existing server REST endpoints (no new backend work) +- Return familiar Python types (dicts, lists, pandas DataFrames) +- Support the customer's two-phase workflow: read run → append eval results +- Consistent with existing SDK patterns (Settings, auth, error handling) + +## Non-Goals + +- NQL (Neptune Query Language) — Pluto has its own filtering model via query params +- Local-only reads from SQLite sync DB — server is the source of truth for reads +- Full `run["path/to/key"]` bracket-access API — too much surface area for V1 +- Streaming/real-time subscriptions + +--- + +## Existing Server Read Endpoints + +The Pluto server already has these REST endpoints (API-key auth via `x-api-key` header): + +| Endpoint | Description | +|---|---| +| `GET /api/runs/list` | List runs with search, tag filtering, pagination | +| `GET /api/runs/details/{runId}` | Full run details (config, tags, status, metadata) | +| `GET /api/runs/details/by-display-id/{displayId}` | Run details by display ID (e.g. "MMP-1") | +| `GET /api/runs/metrics` | Time-series metrics with reservoir sampling | +| `GET /api/runs/files` | File metadata with presigned download URLs | +| `GET /api/runs/logs` | Console logs | +| `GET /api/runs/projects` | List all projects | +| `GET /api/runs/metric-names` | Distinct metric names in a project | +| `GET /api/runs/statistics` | Min/max/mean/stddev + anomaly detection | +| `GET /api/runs/compare` | Compare metrics across runs | +| `GET /api/runs/leaderboard` | Rank runs by metric aggregation | +| `GET /api/runs/auth/validate` | Validate API key | + +These are the same endpoints the Pluto MCP server calls. The Python SDK just needs to wrap them. + +--- + +## Proposed API + +### Module: `pluto.query` + +New module at `pluto/query.py`. All functions are standalone (no global state). Users pass project name and authenticate via the same `PLUTO_API_TOKEN` used for logging. 
+ +### Initialization + +```python +import pluto.query as pq + +# Uses PLUTO_API_TOKEN env var and default server URL +runs = pq.list_runs("my-project") + +# Explicit auth and custom server +client = pq.Client( + api_token="plt_...", + host="https://pluto.example.com", # self-hosted +) +runs = client.list_runs("my-project") +``` + +The module-level functions (`pq.list_runs(...)`) are convenience wrappers that create a `Client` from environment/settings. The `Client` class holds auth + base URL and is reusable. + +### Core API + +#### `Client` class + +```python +class Client: + def __init__( + self, + api_token: str | None = None, # Default: PLUTO_API_TOKEN env var + host: str | None = None, # Default: https://pluto.trainy.ai + ): ... +``` + +Internally creates an `httpx.Client` with `x-api-key` header. Reuses the URL resolution logic from `Settings` (derives `url_api` from `host`). + +#### List projects + +```python +client.list_projects() -> list[dict] +``` + +Returns list of projects with `id`, `name`, `runCount`, `createdAt`, `updatedAt`. + +Maps to: `GET /api/runs/projects` + +#### List runs + +```python +client.list_runs( + project: str, + search: str | None = None, # Full-text search on run name + tags: list[str] | None = None, # Filter by tags (AND logic) + limit: int = 50, # Max 200 +) -> list[dict] +``` + +Returns list of run dicts with `id`, `name`, `displayId`, `status`, `tags`, `config`, `createdAt`, `updatedAt`, `url`. + +Maps to: `GET /api/runs/list?projectName=...&search=...&tags=...&limit=...` + +#### Get run details + +```python +client.get_run( + project: str, + run_id: int | str, # Numeric ID or display ID (e.g. "MMP-1") +) -> dict +``` + +Returns full run details: `id`, `name`, `displayId`, `status`, `tags`, `config`, `systemMetadata`, `logNames`, `createdAt`, `updatedAt`, `url`. 
+ +Maps to: `GET /api/runs/details/{runId}` or `GET /api/runs/details/by-display-id/{displayId}` + +#### Fetch metrics + +```python +client.get_metrics( + project: str, + run_id: int, + metric_names: list[str] | None = None, # None = all metrics + limit: int = 10000, +) -> pd.DataFrame # Columns: metric, step, value, time +``` + +Returns a pandas DataFrame with all requested metric series. If `pandas` is not installed, returns a list of dicts instead. + +Maps to: `GET /api/runs/metrics?runId=...&projectName=...&logName=...` + +#### List metric names + +```python +client.get_metric_names( + project: str, + run_id: int | None = None, # None = all metrics in project + search: str | None = None, # Filter by name substring +) -> list[str] +``` + +Maps to: `GET /api/runs/metric-names?projectName=...&runIds=...&search=...` + +#### Fetch files + +```python +client.get_files( + project: str, + run_id: int, + file_name: str | None = None, # Filter by log name +) -> list[dict] # Each has: fileName, fileType, fileSize, step, time, downloadUrl +``` + +Returns file metadata with presigned download URLs. Users can download with `httpx`/`requests`/`urllib`. + +Maps to: `GET /api/runs/files?runId=...&projectName=...&logName=...` + +#### Download file + +```python +client.download_file( + project: str, + run_id: int, + file_name: str, + destination: str | Path = ".", # Directory or file path +) -> Path # Path to downloaded file +``` + +Convenience method: calls `get_files()`, then downloads via presigned URL. 
+ +#### Fetch logs + +```python +client.get_logs( + project: str, + run_id: int, + log_type: str | None = None, # "info", "error", "warning", "debug", "print" + limit: int = 10000, +) -> list[dict] # Each has: message, logType, time, lineNumber, step +``` + +Maps to: `GET /api/runs/logs?runId=...&projectName=...&logType=...&limit=...` + +#### Statistics + +```python +client.get_statistics( + project: str, + run_id: int, + metric_names: list[str] | None = None, +) -> dict # Per-metric: count, min, max, mean, stddev, first, last +``` + +Maps to: `GET /api/runs/statistics?runId=...&projectName=...` + +#### Compare runs + +```python +client.compare_runs( + project: str, + run_ids: list[int], + metric_name: str, +) -> dict # Per-run stats + best run recommendation +``` + +Maps to: `GET /api/runs/compare?runIds=...&projectName=...&logName=...` + +#### Leaderboard + +```python +client.leaderboard( + project: str, + metric_name: str, + aggregation: str = "LAST", # MIN, MAX, AVG, LAST, VARIANCE + direction: str = "ASC", # ASC or DESC + limit: int = 50, +) -> list[dict] # Ranked runs with metric values +``` + +Maps to: `GET /api/runs/leaderboard?projectName=...&logName=...&aggregation=...` + +--- + +### Neptune Migration: Resume Run for Writing + +The customer's workflow requires opening an existing run and appending data. This requires a change to `pluto.init()`, not the query module. + +#### Current behavior + +```python +# Creates a NEW run, or resumes if run_id matches an existing external_id +run = pluto.init(project="my-project", run_id="my-external-id") +``` + +The `run_id` parameter is the **external ID** — a user-provided string for multi-node coordination. The server returns `resumed=true` if a run with that `run_id` already exists. 
+ +#### Proposed: `with_id` parameter + +```python +# Resume an existing run by its server-assigned numeric ID +run = pluto.init(project="my-project", with_id=12345) + +# Resume by display ID +run = pluto.init(project="my-project", with_id="MMP-1") + +# Read-only mode (no sync process, no system monitoring) +run = pluto.init(project="my-project", with_id="MMP-1", mode="read-only") +``` + +**Implementation:** + +1. Add `with_id: int | str | None = None` and `mode: str = "async"` parameters to `pluto.init()` +2. When `with_id` is provided: + - If `mode="read-only"`: return a lightweight `ReadOnlyRun` proxy that wraps `Client.get_run()`, `Client.get_metrics()`, etc. No background workers, no sync process, no system monitoring. + - If `mode="async"` (default): call `/api/runs/create` with the server ID to resume the run for writing. Start sync process and background workers as usual. The `name` parameter is ignored on resume. +3. The server's `/api/runs/create` endpoint already supports resumption — it just needs to also accept the server-assigned run ID (not just external ID) as a lookup key. **This is the only backend change needed.** + +#### `ReadOnlyRun` class + +```python +class ReadOnlyRun: + """Lightweight read-only run proxy. No background workers.""" + + def __init__(self, client: Client, project: str, run_id: int): ... + + @property + def id(self) -> int: ... + + @property + def name(self) -> str: ... + + @property + def config(self) -> dict: ... + + @property + def tags(self) -> list[str]: ... + + @property + def status(self) -> str: ... + + def fetch(self, key: str) -> Any: + """Fetch a config value by key path. e.g. run.fetch('checkpoint_path')""" + ... + + def fetch_metrics(self, metric_names=None, limit=10000) -> pd.DataFrame: + """Fetch metric series as a DataFrame.""" + ... + + def fetch_files(self, file_name=None) -> list[dict]: + """Fetch file metadata with download URLs.""" + ... 
+ + def download(self, file_name: str, destination=".") -> Path: + """Download a file artifact.""" + ... +``` + +--- + +### Neptune Compat Layer Update + +Update `pluto/compat/neptune.py` to support read operations through the compat wrapper: + +```python +# Neptune pattern: +run = neptune.init_run(with_id="PROJ-123", mode="read-only") +checkpoint = run["model/checkpoint_path"].fetch() + +# With updated compat layer, this translates to: +# → pluto.init(with_id="PROJ-123", mode="read-only") +# → run.fetch("model/checkpoint_path") +``` + +This is a lower priority — the customer can use `pluto.query` directly during migration. The compat layer can be updated later for teams that want a drop-in replacement. + +--- + +## Implementation Plan + +### Phase 1: `pluto.query` module (no backend changes) + +1. **`pluto/query.py`** — `Client` class wrapping all `GET /api/runs/*` endpoints +2. **Module-level convenience functions** — `pluto.query.list_runs(...)`, etc. +3. **Export from `pluto/__init__.py`** — `import pluto.query` works +4. **Tests** — unit tests with mocked HTTP, integration tests against staging +5. **Dependencies** — `httpx` (already a dependency), `pandas` optional + +This phase requires **zero backend changes**. All endpoints already exist. + +### Phase 2: Resume run by server ID + +1. **Backend change** — `/api/runs/create` accepts server-assigned run ID for resumption +2. **`pluto.init(with_id=...)`** — add parameter, wire to backend +3. **`ReadOnlyRun` class** — lightweight proxy for read-only mode +4. **Tests** — resume by numeric ID, resume by display ID, read-only mode + +### Phase 3: Neptune compat layer read support (optional) + +1. **`NeptuneRunWrapper`** — support `run["key"].fetch()` bracket access +2. **`init_run(with_id=..., mode="read-only")`** interception +3. 
**`init_project().fetch_runs_table()`** interception + +--- + +## Customer's Workflow with Proposed API + +```python +import pluto +import pluto.query as pq + +# ---- Phase 1: Read from completed training run ---- + +# Get the run (using display ID from dashboard URL) +run_details = pq.get_run("my-project", "MMP-42") + +# Get checkpoint path from config +checkpoint_path = run_details["config"]["checkpoint_path"] +training_step = run_details["config"]["total_steps"] + +# Get the best validation loss +metrics = pq.get_metrics("my-project", run_details["id"], metric_names=["val/loss"]) +best_step = metrics.loc[metrics["value"].idxmin(), "step"] + +# ---- Phase 2: Run async evaluation ---- +# (load model from checkpoint_path, run eval, produce results) +eval_results = run_evaluation(checkpoint_path) + +# ---- Phase 3: Append eval results to the same run ---- + +# Resume run for writing (Phase 2 feature, requires backend change) +run = pluto.init(project="my-project", with_id=run_details["id"]) + +# Log eval metrics at the training step +run.log({ + "eval/accuracy": eval_results["accuracy"], + "eval/f1": eval_results["f1"], +}, step=training_step) + +# Upload eval plots +run.log({ + "eval/heatmap": pluto.Image("heatmap.png"), + "eval/forecast": pluto.Image("forecast.png"), +}, step=training_step) + +run.finish() +``` + +--- + +## Alternatives Considered + +### A. JSON-RPC 2.0 client wrapping MCP server + +Claude (the customer's AI assistant) suggested building a JSON-RPC client to call the MCP server at `pluto-mcp.trainy.ai`. This was rejected because: + +- The MCP server is in alpha with no public API docs +- The MCP server itself wraps the same REST endpoints we'd call directly +- Adding a JSON-RPC layer adds latency and complexity for no benefit +- The REST endpoints have a documented OpenAPI spec + +### B. tRPC client + +The server's tRPC routes have richer filtering (metric-based sorting, cross-DB joins). 
However: + +- tRPC uses session auth (cookie-based), not API keys +- tRPC routes are an internal frontend API, not a stable public contract +- The REST OpenAPI endpoints provide sufficient functionality for V1 +- If needed, we can add tRPC support later for advanced filtering + +### C. GraphQL / custom query language + +Over-engineered for V1. The REST endpoints with query parameters cover the customer's use case. A query language can be added later if demand exists. + +--- + +## Open Questions + +1. **Display ID format** — Is the display ID (e.g., "MMP-1") stable and unique within a project? Can we reliably use it for `with_id`? + +2. **Run resume by server ID** — Does the `/api/runs/create` endpoint need modification to accept a server-assigned numeric ID for resumption, or can we use the existing `run_id` (external ID) path by setting `PLUTO_RUN_ID` to the server ID? + +3. **Pagination for metrics** — The `/api/runs/metrics` endpoint uses reservoir sampling (2000 points per metric). Should `get_metrics()` support fetching the full unsampled series? This may require a new backend endpoint or a `full=true` parameter. + +4. **File download auth** — Are the presigned download URLs from `/api/runs/files` sufficient, or do they expire too quickly for batch download workflows? + +5. **Rate limiting** — Are the REST endpoints rate-limited? Should the `Client` class implement backoff? diff --git a/pluto/__init__.py b/pluto/__init__.py index d691a83..1e6599c 100644 --- a/pluto/__init__.py +++ b/pluto/__init__.py @@ -2,6 +2,7 @@ import subprocess from typing import Any, Callable, List, Optional +from . 
import query from .auth import login, logout from .data import Data, Graph, Histogram, Table from .file import Artifact, Audio, File, Image, Text, Video @@ -35,6 +36,7 @@ 'watch', 'finish', 'setup', + 'query', ) __version__ = '0.0.6' diff --git a/pluto/query.py b/pluto/query.py new file mode 100644 index 0000000..25dd014 --- /dev/null +++ b/pluto/query.py @@ -0,0 +1,714 @@ +""" +Read/query API for Pluto runs. + +Provides programmatic access to runs, metrics, files, and logs stored in Pluto. +Wraps the Pluto server REST API (``/api/runs/*``) with API-key authentication. + +Usage:: + + import pluto.query as pq + + # Uses PLUTO_API_TOKEN env var and default server + runs = pq.list_runs("my-project") + run = pq.get_run("my-project", "MMP-42") + metrics = pq.get_metrics("my-project", run["id"], metric_names=["val/loss"]) + + # Explicit auth / custom server + client = pq.Client(api_token="plt_...", host="https://pluto.example.com") + runs = client.list_runs("my-project") +""" + +import logging +import os +import time +import warnings +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import httpx + +logger = logging.getLogger(f'{__name__.split(".")[0]}') +tag = 'Query' + +_DEFAULT_URL_API = 'https://pluto-api.trainy.ai' +_DEFAULT_TIMEOUT = 30 +_RETRY_MAX = 4 +_RETRY_WAIT_MIN = 0.5 +_RETRY_WAIT_MAX = 4.0 + + +class QueryError(Exception): + """Raised when a query to the Pluto server fails.""" + + def __init__(self, message: str, status_code: Optional[int] = None): + self.status_code = status_code + super().__init__(message) + + +class Client: + """HTTP client for reading data from the Pluto server. + + Args: + api_token: API token for authentication. Defaults to ``PLUTO_API_TOKEN`` + environment variable. + host: Server URL (e.g. ``https://pluto.trainy.ai`` or + ``https://pluto-api.trainy.ai``). When a bare hostname or + ``host:port`` is given (matching the ``pluto.init(host=...)`` + pattern), the API URL is derived as ``http://{host}:3001``. 
+ Defaults to ``PLUTO_URL_API`` env var or the production API URL. + timeout: HTTP request timeout in seconds. + """ + + def __init__( + self, + api_token: Optional[str] = None, + host: Optional[str] = None, + timeout: int = _DEFAULT_TIMEOUT, + ) -> None: + self._api_token = api_token or _resolve_api_token() + if not self._api_token: + raise QueryError( + 'No API token provided. Set PLUTO_API_TOKEN environment variable ' + 'or pass api_token to Client().' + ) + + self._url_api = _resolve_url_api(host) + self._client = httpx.Client( + headers={ + 'x-api-key': self._api_token, + 'User-Agent': 'pluto-query', + }, + timeout=httpx.Timeout(timeout), + ) + + def close(self) -> None: + """Close the underlying HTTP client.""" + self._client.close() + + def __enter__(self) -> 'Client': + return self + + def __exit__(self, *exc: Any) -> None: + self.close() + + # ------------------------------------------------------------------ + # Projects + # ------------------------------------------------------------------ + + def list_projects(self) -> List[Dict[str, Any]]: + """List all projects in the organization. + + Returns: + List of project dicts with keys: ``id``, ``name``, ``runCount``, + ``createdAt``, ``updatedAt``. + """ + return self._get('/api/runs/projects') + + # ------------------------------------------------------------------ + # Runs + # ------------------------------------------------------------------ + + def list_runs( + self, + project: str, + search: Optional[str] = None, + tags: Optional[List[str]] = None, + limit: int = 50, + ) -> List[Dict[str, Any]]: + """List runs in a project. + + Args: + project: Project name. + search: Full-text search on run name. + tags: Filter by tags (AND logic). Only runs matching *all* + specified tags are returned. + limit: Maximum number of runs to return (max 200). + + Returns: + List of run dicts with keys: ``id``, ``name``, ``displayId``, + ``status``, ``tags``, ``config``, ``createdAt``, ``updatedAt``, + ``url``. 
+ """ + params: Dict[str, Any] = {'projectName': project, 'limit': min(limit, 200)} + if search is not None: + params['search'] = search + if tags is not None: + params['tags'] = ','.join(tags) + return self._get('/api/runs/list', params=params) + + def get_run( + self, + project: str, + run_id: Union[int, str], + ) -> Dict[str, Any]: + """Get full details for a single run. + + Args: + project: Project name. + run_id: Numeric server ID (``int``) or display ID string + (e.g. ``"MMP-1"``). + + Returns: + Run dict with keys: ``id``, ``name``, ``displayId``, ``status``, + ``tags``, ``config``, ``systemMetadata``, ``logNames``, + ``createdAt``, ``updatedAt``, ``url``. + """ + if isinstance(run_id, int): + return self._get( + f'/api/runs/details/{run_id}', + params={'projectName': project}, + ) + else: + return self._get( + f'/api/runs/details/by-display-id/{run_id}', + params={'projectName': project}, + ) + + # ------------------------------------------------------------------ + # Metrics + # ------------------------------------------------------------------ + + def get_metric_names( + self, + project: str, + run_ids: Optional[List[int]] = None, + search: Optional[str] = None, + limit: int = 500, + ) -> List[str]: + """List distinct metric names. + + Args: + project: Project name. + run_ids: Restrict to these run IDs. ``None`` returns metrics + across the whole project. + search: Substring filter on metric name. + limit: Maximum number of names (max 500). + + Returns: + List of metric name strings. + """ + params: Dict[str, Any] = {'projectName': project, 'limit': min(limit, 500)} + if run_ids is not None: + params['runIds'] = ','.join(str(r) for r in run_ids) + if search is not None: + params['search'] = search + return self._get('/api/runs/metric-names', params=params) + + def get_metrics( + self, + project: str, + run_id: int, + metric_names: Optional[List[str]] = None, + limit: int = 10000, + ) -> Any: + """Fetch time-series metric data for a run. 
+ + The server returns up to ``limit`` data points per metric, sampled + via reservoir sampling when the full series exceeds the limit. + + When *pandas* is installed the return value is a + :class:`~pandas.DataFrame` with columns ``metric``, ``step``, + ``value``, ``time``. Otherwise a list of dicts is returned. + + Args: + project: Project name. + run_id: Numeric server ID. + metric_names: Metric names to fetch. ``None`` fetches all. + limit: Max data points per metric (max 10 000). + + Returns: + ``pandas.DataFrame`` or ``list[dict]``. + """ + params: Dict[str, Any] = { + 'runId': run_id, + 'projectName': project, + 'limit': min(limit, 10000), + } + + if metric_names is not None and len(metric_names) > 1: + # Endpoint only supports a single logName filter, so fetch + # each metric individually and merge the results. + raw: list = [] + for name in metric_names: + p = dict(params) + p['logName'] = name + raw.extend(self._get('/api/runs/metrics', params=p)) + else: + if metric_names is not None and len(metric_names) == 1: + params['logName'] = metric_names[0] + raw = self._get('/api/runs/metrics', params=params) + + return _to_dataframe(raw) + + # ------------------------------------------------------------------ + # Statistics / comparison + # ------------------------------------------------------------------ + + def get_statistics( + self, + project: str, + run_id: int, + metric_names: Optional[List[str]] = None, + ) -> Any: + """Compute statistics for run metrics. + + Returns per-metric aggregations: ``count``, ``min``, ``max``, + ``mean``, ``stddev``, as well as anomaly detection data. + + Args: + project: Project name. + run_id: Numeric server ID. + metric_names: Restrict to these metrics. + + Returns: + Server response dict. 
+ """ + params: Dict[str, Any] = {'runId': run_id, 'projectName': project} + if metric_names is not None and len(metric_names) == 1: + params['logName'] = metric_names[0] + return self._get('/api/runs/statistics', params=params) + + def compare_runs( + self, + project: str, + run_ids: List[int], + metric_name: str, + ) -> Dict[str, Any]: + """Compare a metric across multiple runs. + + Args: + project: Project name. + run_ids: List of numeric run IDs (max 100). + metric_name: The metric to compare. + + Returns: + Dict with per-run statistics and a ``bestRun`` recommendation. + """ + params: Dict[str, Any] = { + 'runIds': ','.join(str(r) for r in run_ids[:100]), + 'projectName': project, + 'logName': metric_name, + } + return self._get('/api/runs/compare', params=params) + + def leaderboard( + self, + project: str, + metric_name: str, + aggregation: str = 'LAST', + direction: str = 'ASC', + limit: int = 50, + offset: int = 0, + ) -> List[Dict[str, Any]]: + """Rank runs by a metric aggregation. + + Args: + project: Project name. + metric_name: Metric to rank by. + aggregation: One of ``MIN``, ``MAX``, ``AVG``, ``LAST``, + ``VARIANCE``. + direction: ``ASC`` or ``DESC``. + limit: Max results (max 100). + offset: Pagination offset. + + Returns: + List of run dicts with metric values. + """ + params: Dict[str, Any] = { + 'projectName': project, + 'logName': metric_name, + 'aggregation': aggregation, + 'direction': direction, + 'limit': min(limit, 100), + 'offset': offset, + } + return self._get('/api/runs/leaderboard', params=params) + + # ------------------------------------------------------------------ + # Files + # ------------------------------------------------------------------ + + def get_files( + self, + project: str, + run_id: int, + file_name: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Get file metadata with presigned download URLs. + + Args: + project: Project name. + run_id: Numeric server ID. + file_name: Filter by log name / file name. 
+ + Returns: + List of file dicts with keys: ``fileName``, ``fileType``, + ``fileSize``, ``step``, ``time``, ``downloadUrl``. + """ + params: Dict[str, Any] = {'runId': run_id, 'projectName': project} + if file_name is not None: + params['logName'] = file_name + return self._get('/api/runs/files', params=params) + + def download_file( + self, + project: str, + run_id: int, + file_name: str, + destination: Union[str, Path] = '.', + ) -> Path: + """Download a file artifact to local disk. + + Args: + project: Project name. + run_id: Numeric server ID. + file_name: Log name of the file to download. + destination: Directory or file path. + + Returns: + Path to the downloaded file. + + Raises: + QueryError: If no matching file is found. + """ + files = self.get_files(project, run_id, file_name=file_name) + if not files: + raise QueryError(f'No file found matching "{file_name}" for run {run_id}') + + file_info = files[0] + url = file_info.get('downloadUrl') or file_info.get('url') + if not url: + raise QueryError(f'No download URL for file "{file_name}"') + + dest = Path(destination) + if dest.is_dir(): + dest = dest / file_info.get('fileName', file_name) + + resp = httpx.get(url, follow_redirects=True, timeout=120) + resp.raise_for_status() + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(resp.content) + return dest + + # ------------------------------------------------------------------ + # Logs + # ------------------------------------------------------------------ + + def get_logs( + self, + project: str, + run_id: int, + log_type: Optional[str] = None, + limit: int = 10000, + offset: int = 0, + ) -> List[Dict[str, Any]]: + """Fetch console logs for a run. + + Args: + project: Project name. + run_id: Numeric server ID. + log_type: Filter by type: ``"info"``, ``"error"``, ``"warning"``, + ``"debug"``, ``"print"``. + limit: Max lines (max 10 000). + offset: Pagination offset. 
+ + Returns: + List of log dicts with keys: ``message``, ``logType``, ``time``, + ``lineNumber``, ``step``. + """ + params: Dict[str, Any] = { + 'runId': run_id, + 'projectName': project, + 'limit': min(limit, 10000), + 'offset': offset, + } + if log_type is not None: + params['logType'] = log_type + return self._get('/api/runs/logs', params=params) + + # ------------------------------------------------------------------ + # Internal HTTP helpers + # ------------------------------------------------------------------ + + def _get( + self, + path: str, + params: Optional[Dict[str, Any]] = None, + retry: int = 0, + ) -> Any: + """Issue a GET request with retry logic.""" + url = f'{self._url_api}{path}' + + try: + resp = self._client.get(url, params=params) + except ( + httpx.ConnectError, + httpx.ReadTimeout, + httpx.WriteTimeout, + httpx.PoolTimeout, + ) as exc: + return self._retry_or_raise( + path, params, retry, f'{type(exc).__name__}: {exc}' + ) + + if resp.status_code in (200, 201): + return resp.json() + + # Retryable server errors + if resp.status_code >= 500 and retry < _RETRY_MAX: + return self._retry_or_raise( + path, params, retry, f'HTTP {resp.status_code}: {resp.text[:200]}' + ) + + # Client errors — don't retry + raise QueryError( + f'{url} returned {resp.status_code}: {resp.text[:500]}', + status_code=resp.status_code, + ) + + def _retry_or_raise( + self, + path: str, + params: Optional[Dict[str, Any]], + retry: int, + error_info: str, + ) -> Any: + if retry >= _RETRY_MAX: + raise QueryError(f'Failed after {retry} retries: {error_info}') + wait = min(_RETRY_WAIT_MIN * (2 ** (retry + 1)), _RETRY_WAIT_MAX) + logger.debug( + '%s: retry %d/%d for %s: %s', + tag, + retry + 1, + _RETRY_MAX, + path, + error_info, + ) + time.sleep(wait) + return self._get(path, params=params, retry=retry + 1) + + +# ====================================================================== +# Module-level convenience functions +# 
====================================================================== + +_default_client: Optional[Client] = None + + +def _get_client() -> Client: + global _default_client + if _default_client is None: + _default_client = Client() + return _default_client + + +def list_projects() -> List[Dict[str, Any]]: + """List all projects. See :meth:`Client.list_projects`.""" + return _get_client().list_projects() + + +def list_runs( + project: str, + search: Optional[str] = None, + tags: Optional[List[str]] = None, + limit: int = 50, +) -> List[Dict[str, Any]]: + """List runs in a project. See :meth:`Client.list_runs`.""" + return _get_client().list_runs(project, search=search, tags=tags, limit=limit) + + +def get_run(project: str, run_id: Union[int, str]) -> Dict[str, Any]: + """Get run details. See :meth:`Client.get_run`.""" + return _get_client().get_run(project, run_id) + + +def get_metric_names( + project: str, + run_ids: Optional[List[int]] = None, + search: Optional[str] = None, + limit: int = 500, +) -> List[str]: + """List metric names. See :meth:`Client.get_metric_names`.""" + return _get_client().get_metric_names( + project, + run_ids=run_ids, + search=search, + limit=limit, + ) + + +def get_metrics( + project: str, + run_id: int, + metric_names: Optional[List[str]] = None, + limit: int = 10000, +) -> Any: + """Fetch metric data. See :meth:`Client.get_metrics`.""" + return _get_client().get_metrics( + project, + run_id, + metric_names=metric_names, + limit=limit, + ) + + +def get_statistics( + project: str, + run_id: int, + metric_names: Optional[List[str]] = None, +) -> Any: + """Compute metric statistics. See :meth:`Client.get_statistics`.""" + return _get_client().get_statistics(project, run_id, metric_names=metric_names) + + +def compare_runs( + project: str, + run_ids: List[int], + metric_name: str, +) -> Dict[str, Any]: + """Compare runs by metric. 
See :meth:`Client.compare_runs`.""" + return _get_client().compare_runs(project, run_ids, metric_name) + + +def leaderboard( + project: str, + metric_name: str, + aggregation: str = 'LAST', + direction: str = 'ASC', + limit: int = 50, + offset: int = 0, +) -> List[Dict[str, Any]]: + """Rank runs by metric. See :meth:`Client.leaderboard`.""" + return _get_client().leaderboard( + project, + metric_name, + aggregation=aggregation, + direction=direction, + limit=limit, + offset=offset, + ) + + +def get_files( + project: str, + run_id: int, + file_name: Optional[str] = None, +) -> List[Dict[str, Any]]: + """Get file metadata. See :meth:`Client.get_files`.""" + return _get_client().get_files(project, run_id, file_name=file_name) + + +def download_file( + project: str, + run_id: int, + file_name: str, + destination: Union[str, Path] = '.', +) -> Path: + """Download a file. See :meth:`Client.download_file`.""" + return _get_client().download_file(project, run_id, file_name, destination) + + +def get_logs( + project: str, + run_id: int, + log_type: Optional[str] = None, + limit: int = 10000, + offset: int = 0, +) -> List[Dict[str, Any]]: + """Fetch console logs. See :meth:`Client.get_logs`.""" + return _get_client().get_logs( + project, + run_id, + log_type=log_type, + limit=limit, + offset=offset, + ) + + +# ====================================================================== +# Helpers +# ====================================================================== + + +def _resolve_api_token() -> Optional[str]: + """Resolve API token from environment or keyring.""" + token = os.environ.get('PLUTO_API_TOKEN') + if token: + return token + + # Deprecated env var + token = os.environ.get('MLOP_API_TOKEN') + if token: + warnings.warn( + 'MLOP_API_TOKEN is deprecated. 
Use PLUTO_API_TOKEN instead.', + DeprecationWarning, + stacklevel=3, + ) + return token + + # Try keyring (same logic as auth.py login) + try: + import keyring + + try: + assert __import__('sys').platform == 'darwin' + token = keyring.get_password('pluto', 'pluto') + except (keyring.errors.NoKeyringError, AssertionError): + from keyrings.alt.file import PlaintextKeyring + + keyring.set_keyring(PlaintextKeyring()) + token = keyring.get_password('pluto', 'pluto') + return token if token else None + except Exception: + return None + + + def _resolve_url_api(host: Optional[str] = None) -> str: + """Resolve the API base URL.""" + if host is not None: + # Full URL given — decide whether it already points at the API + if host.startswith('http://') or host.startswith('https://'): + # Heuristic: URLs containing "/api/" or ending in ":3001" are assumed + # to already be API URLs and are used verbatim (minus trailing "/"). + if '/api/' in host or host.rstrip('/').endswith((':3001',)): + return host.rstrip('/') + # Best guess: user passed the app URL. We can't reliably derive + # the API URL, so use it as-is and let the server redirect. + return host.rstrip('/') + else: + # Bare host like "10.0.0.1" or "my-host:3001" — same as Settings.update_host + return f'http://{host}:3001' + + # Environment variable + url = os.environ.get('PLUTO_URL_API') + if url: + return url.rstrip('/') + + url = os.environ.get('MLOP_URL_API') + if url: + warnings.warn( + 'MLOP_URL_API is deprecated. 
Use PLUTO_URL_API instead.', + DeprecationWarning, + stacklevel=3, + ) + return url.rstrip('/') + + return _DEFAULT_URL_API + + +def _to_dataframe(data: Any) -> Any: + """Convert raw metric data to a pandas DataFrame if pandas is available.""" + if not data: + try: + import pandas as pd + + return pd.DataFrame(columns=['metric', 'step', 'value', 'time']) + except ImportError: + return [] + + try: + import pandas as pd + + return pd.DataFrame(data) + except ImportError: + return data diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 0000000..88eec28 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,495 @@ +"""Tests for pluto.query module.""" + +import json +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from pluto.query import Client, QueryError, _resolve_api_token, _resolve_url_api + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _clean_env(monkeypatch): + """Remove query-related env vars for test isolation.""" + for key in ( + 'PLUTO_API_TOKEN', + 'MLOP_API_TOKEN', + 'PLUTO_URL_API', + 'MLOP_URL_API', + ): + monkeypatch.delenv(key, raising=False) + + +@pytest.fixture() +def mock_response(): + """Factory for mock httpx.Response objects.""" + + def _make(status_code=200, json_data=None, text=''): + resp = MagicMock(spec=httpx.Response) + resp.status_code = status_code + resp.json.return_value = json_data if json_data is not None else {} + resp.text = text or json.dumps(json_data or {}) + return resp + + return _make + + +@pytest.fixture() +def client(monkeypatch): + """A Client with a mocked httpx.Client.""" + monkeypatch.setenv('PLUTO_API_TOKEN', 'test-token-123') + c = Client() + c._client = MagicMock(spec=httpx.Client) + return c + + +# --------------------------------------------------------------------------- +# Token resolution +# 
--------------------------------------------------------------------------- + + +class TestResolveApiToken: + def test_from_pluto_env(self, monkeypatch): + monkeypatch.setenv('PLUTO_API_TOKEN', 'plt_abc') + assert _resolve_api_token() == 'plt_abc' + + def test_from_deprecated_mlop_env(self, monkeypatch): + monkeypatch.setenv('MLOP_API_TOKEN', 'old_token') + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + token = _resolve_api_token() + assert token == 'old_token' + assert any('MLOP_API_TOKEN' in str(x.message) for x in w) + + def test_pluto_takes_precedence(self, monkeypatch): + monkeypatch.setenv('PLUTO_API_TOKEN', 'new_token') + monkeypatch.setenv('MLOP_API_TOKEN', 'old_token') + assert _resolve_api_token() == 'new_token' + + def test_none_when_nothing_set(self): + # keyring may or may not work, but at minimum no env var + token = _resolve_api_token() + # Token might come from keyring if configured, but in CI it should be None + assert token is None or isinstance(token, str) + + +# --------------------------------------------------------------------------- +# URL resolution +# --------------------------------------------------------------------------- + + +class TestResolveUrlApi: + def test_default(self): + assert _resolve_url_api() == 'https://pluto-api.trainy.ai' + + def test_full_url(self): + assert ( + _resolve_url_api('https://my-api.example.com') + == 'https://my-api.example.com' + ) + + def test_full_url_trailing_slash(self): + assert ( + _resolve_url_api('https://my-api.example.com/') + == 'https://my-api.example.com' + ) + + def test_bare_host(self): + assert _resolve_url_api('10.0.0.1') == 'http://10.0.0.1:3001' + + def test_env_var(self, monkeypatch): + monkeypatch.setenv('PLUTO_URL_API', 'https://env-api.example.com') + assert _resolve_url_api() == 'https://env-api.example.com' + + def test_deprecated_env_var(self, monkeypatch): + monkeypatch.setenv('MLOP_URL_API', 'https://old-api.example.com') 
+ import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + url = _resolve_url_api() + assert url == 'https://old-api.example.com' + assert any('MLOP_URL_API' in str(x.message) for x in w) + + def test_host_param_overrides_env(self, monkeypatch): + monkeypatch.setenv('PLUTO_URL_API', 'https://env-api.example.com') + assert ( + _resolve_url_api('https://param-api.example.com') + == 'https://param-api.example.com' + ) + + +# --------------------------------------------------------------------------- +# Client initialization +# --------------------------------------------------------------------------- + + +class TestClientInit: + def test_requires_token(self): + with pytest.raises(QueryError, match='No API token'): + Client() + + def test_from_env(self, monkeypatch): + monkeypatch.setenv('PLUTO_API_TOKEN', 'plt_abc') + c = Client() + assert c._api_token == 'plt_abc' + c.close() + + def test_explicit_token(self, monkeypatch): + c = Client(api_token='explicit_token') + assert c._api_token == 'explicit_token' + c.close() + + def test_context_manager(self, monkeypatch): + monkeypatch.setenv('PLUTO_API_TOKEN', 'plt_abc') + with Client() as c: + assert c._api_token == 'plt_abc' + + def test_custom_host(self, monkeypatch): + monkeypatch.setenv('PLUTO_API_TOKEN', 'plt_abc') + c = Client(host='10.0.0.1') + assert c._url_api == 'http://10.0.0.1:3001' + c.close() + + +# --------------------------------------------------------------------------- +# list_projects +# --------------------------------------------------------------------------- + + +class TestListProjects: + def test_success(self, client, mock_response): + data = [ + {'id': 1, 'name': 'proj-a', 'runCount': 5}, + {'id': 2, 'name': 'proj-b', 'runCount': 3}, + ] + client._client.get.return_value = mock_response(200, data) + result = client.list_projects() + assert result == data + client._client.get.assert_called_once() + call_args = client._client.get.call_args + assert 
'/api/runs/projects' in call_args[0][0] + + +# --------------------------------------------------------------------------- +# list_runs +# --------------------------------------------------------------------------- + + +class TestListRuns: + def test_basic(self, client, mock_response): + data = [{'id': 1, 'name': 'run-1', 'displayId': 'MMP-1'}] + client._client.get.return_value = mock_response(200, data) + result = client.list_runs('my-project') + assert result == data + call_args = client._client.get.call_args + assert call_args[1]['params']['projectName'] == 'my-project' + assert call_args[1]['params']['limit'] == 50 + + def test_with_search_and_tags(self, client, mock_response): + client._client.get.return_value = mock_response(200, []) + client.list_runs('proj', search='experiment', tags=['v2', 'prod']) + call_args = client._client.get.call_args + assert call_args[1]['params']['search'] == 'experiment' + assert call_args[1]['params']['tags'] == 'v2,prod' + + def test_limit_capped_at_200(self, client, mock_response): + client._client.get.return_value = mock_response(200, []) + client.list_runs('proj', limit=999) + call_args = client._client.get.call_args + assert call_args[1]['params']['limit'] == 200 + + +# --------------------------------------------------------------------------- +# get_run +# --------------------------------------------------------------------------- + + +class TestGetRun: + def test_by_numeric_id(self, client, mock_response): + data = {'id': 42, 'name': 'run-42', 'config': {'lr': 0.001}} + client._client.get.return_value = mock_response(200, data) + result = client.get_run('proj', 42) + assert result == data + call_url = client._client.get.call_args[0][0] + assert '/api/runs/details/42' in call_url + + def test_by_display_id(self, client, mock_response): + data = {'id': 42, 'displayId': 'MMP-1', 'config': {}} + client._client.get.return_value = mock_response(200, data) + result = client.get_run('proj', 'MMP-1') + assert result == data + 
call_url = client._client.get.call_args[0][0] + assert '/api/runs/details/by-display-id/MMP-1' in call_url + + +# --------------------------------------------------------------------------- +# get_metric_names +# --------------------------------------------------------------------------- + + +class TestGetMetricNames: + def test_basic(self, client, mock_response): + data = ['loss', 'accuracy', 'val/loss'] + client._client.get.return_value = mock_response(200, data) + result = client.get_metric_names('proj') + assert result == data + + def test_with_run_ids(self, client, mock_response): + client._client.get.return_value = mock_response(200, []) + client.get_metric_names('proj', run_ids=[1, 2, 3]) + params = client._client.get.call_args[1]['params'] + assert params['runIds'] == '1,2,3' + + +# --------------------------------------------------------------------------- +# get_metrics +# --------------------------------------------------------------------------- + + +class TestGetMetrics: + def test_single_metric(self, client, mock_response): + data = [ + {'metric': 'loss', 'step': 0, 'value': 1.0, 'time': '2025-01-01'}, + {'metric': 'loss', 'step': 1, 'value': 0.5, 'time': '2025-01-01'}, + ] + client._client.get.return_value = mock_response(200, data) + result = client.get_metrics('proj', 42, metric_names=['loss']) + params = client._client.get.call_args[1]['params'] + assert params['logName'] == 'loss' + # Should return a DataFrame if pandas available + try: + import pandas as pd + + assert isinstance(result, pd.DataFrame) + assert len(result) == 2 + except ImportError: + assert isinstance(result, list) + assert len(result) == 2 + + def test_multiple_metrics(self, client, mock_response): + data_loss = [{'metric': 'loss', 'step': 0, 'value': 1.0}] + data_acc = [{'metric': 'acc', 'step': 0, 'value': 0.8}] + client._client.get.side_effect = [ + mock_response(200, data_loss), + mock_response(200, data_acc), + ] + client.get_metrics('proj', 42, metric_names=['loss', 
'acc']) + assert client._client.get.call_count == 2 + + def test_all_metrics(self, client, mock_response): + data = [{'metric': 'loss', 'step': 0, 'value': 1.0}] + client._client.get.return_value = mock_response(200, data) + client.get_metrics('proj', 42) + params = client._client.get.call_args[1]['params'] + assert 'logName' not in params + + def test_empty_returns_empty_dataframe(self, client, mock_response): + client._client.get.return_value = mock_response(200, []) + result = client.get_metrics('proj', 42) + try: + import pandas as pd + + assert isinstance(result, pd.DataFrame) + assert len(result) == 0 + except ImportError: + assert result == [] + + +# --------------------------------------------------------------------------- +# get_statistics +# --------------------------------------------------------------------------- + + +class TestGetStatistics: + def test_basic(self, client, mock_response): + data = {'loss': {'min': 0.1, 'max': 1.0, 'mean': 0.5}} + client._client.get.return_value = mock_response(200, data) + result = client.get_statistics('proj', 42) + assert result == data + + +# --------------------------------------------------------------------------- +# compare_runs +# --------------------------------------------------------------------------- + + +class TestCompareRuns: + def test_basic(self, client, mock_response): + data = {'runs': [], 'bestRun': None} + client._client.get.return_value = mock_response(200, data) + client.compare_runs('proj', [1, 2, 3], 'loss') + params = client._client.get.call_args[1]['params'] + assert params['runIds'] == '1,2,3' + assert params['logName'] == 'loss' + + +# --------------------------------------------------------------------------- +# leaderboard +# --------------------------------------------------------------------------- + + +class TestLeaderboard: + def test_defaults(self, client, mock_response): + client._client.get.return_value = mock_response(200, []) + client.leaderboard('proj', 'val/loss') + params = 
client._client.get.call_args[1]['params'] + assert params['aggregation'] == 'LAST' + assert params['direction'] == 'ASC' + assert params['limit'] == 50 + + def test_custom_params(self, client, mock_response): + client._client.get.return_value = mock_response(200, []) + client.leaderboard('proj', 'acc', aggregation='MAX', direction='DESC', limit=10) + params = client._client.get.call_args[1]['params'] + assert params['aggregation'] == 'MAX' + assert params['direction'] == 'DESC' + assert params['limit'] == 10 + + +# --------------------------------------------------------------------------- +# get_files +# --------------------------------------------------------------------------- + + +class TestGetFiles: + def test_basic(self, client, mock_response): + data = [{'fileName': 'model.pt', 'downloadUrl': 'https://s3/model.pt'}] + client._client.get.return_value = mock_response(200, data) + result = client.get_files('proj', 42) + assert result == data + + def test_with_filter(self, client, mock_response): + client._client.get.return_value = mock_response(200, []) + client.get_files('proj', 42, file_name='checkpoint') + params = client._client.get.call_args[1]['params'] + assert params['logName'] == 'checkpoint' + + +# --------------------------------------------------------------------------- +# download_file +# --------------------------------------------------------------------------- + + +class TestDownloadFile: + def test_download(self, client, mock_response, tmp_path): + file_data = [ + { + 'fileName': 'model.pt', + 'downloadUrl': 'https://s3/model.pt', + 'fileSize': 100, + } + ] + client._client.get.return_value = mock_response(200, file_data) + + dl_response = MagicMock() + dl_response.content = b'model-bytes' + dl_response.raise_for_status = MagicMock() + + with patch('pluto.query.httpx.get', return_value=dl_response): + path = client.download_file('proj', 42, 'model.pt', destination=tmp_path) + + assert path == tmp_path / 'model.pt' + assert path.read_bytes() == 
b'model-bytes' + + def test_no_matching_file(self, client, mock_response): + client._client.get.return_value = mock_response(200, []) + with pytest.raises(QueryError, match='No file found'): + client.download_file('proj', 42, 'missing.pt') + + +# --------------------------------------------------------------------------- +# get_logs +# --------------------------------------------------------------------------- + + +class TestGetLogs: + def test_basic(self, client, mock_response): + data = [{'message': 'hello', 'logType': 'info', 'lineNumber': 1}] + client._client.get.return_value = mock_response(200, data) + result = client.get_logs('proj', 42) + assert result == data + + def test_with_type_filter(self, client, mock_response): + client._client.get.return_value = mock_response(200, []) + client.get_logs('proj', 42, log_type='error') + params = client._client.get.call_args[1]['params'] + assert params['logType'] == 'error' + + +# --------------------------------------------------------------------------- +# Error handling & retries +# --------------------------------------------------------------------------- + + +class TestErrorHandling: + def test_client_error_raises(self, client, mock_response): + client._client.get.return_value = mock_response(404, text='Not Found') + with pytest.raises(QueryError, match='404'): + client.list_projects() + + def test_server_error_retries_and_raises(self, client, mock_response): + client._client.get.return_value = mock_response( + 500, + text='Internal Server Error', + ) + with pytest.raises(QueryError, match='500'): + client.list_projects() + # Should have retried: 1 initial + 4 retries = 5 calls + assert client._client.get.call_count == 5 + + def test_connection_error_retries(self, client): + client._client.get.side_effect = httpx.ConnectError('connection refused') + with pytest.raises(QueryError, match='Failed after'): + client.list_projects() + assert client._client.get.call_count == 5 + + +# 
--------------------------------------------------------------------------- +# Module-level convenience functions +# --------------------------------------------------------------------------- + + +class TestModuleFunctions: + def test_list_runs_creates_default_client(self, monkeypatch, mock_response): + monkeypatch.setenv('PLUTO_API_TOKEN', 'test-token') + + import pluto.query as pq + + # Reset default client + pq._default_client = None + + mock_client_instance = MagicMock(spec=Client) + mock_client_instance.list_runs.return_value = [{'id': 1}] + + with patch.object(pq, 'Client', return_value=mock_client_instance): + result = pq.list_runs('my-project') + + assert result == [{'id': 1}] + mock_client_instance.list_runs.assert_called_once_with( + 'my-project', search=None, tags=None, limit=50 + ) + + # Clean up + pq._default_client = None + + def test_import_as_submodule(self): + import pluto.query as pq + + assert hasattr(pq, 'Client') + assert hasattr(pq, 'list_runs') + assert hasattr(pq, 'get_run') + assert hasattr(pq, 'get_metrics') + + def test_accessible_from_pluto(self): + import pluto + + assert hasattr(pluto, 'query') + assert hasattr(pluto.query, 'Client')