Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,40 @@ prefix = your_s3_prefix
```

[AWS configuration and credential files](https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-files.html) can be used for authentication and settings. Authentication credentials should be specified in `~/.aws/credentials`. Settings like `region`, `output`, `endpoint_url` should be specified in `~/.aws/config`. Multiple different profiles can be defined and the specific profile to use can be specified in the `aws` section of the `.ini` file.

### Loading From MLflow

Simplexity provides a high‑level loader to reconstruct models and read run data from MLflow.

Quick start:

```python
from simplexity.loaders import ExperimentLoader

# Use your MLflow run ID and tracking URI (e.g., Databricks)
loader = ExperimentLoader.from_mlflow(
run_id="<RUN_ID>",
tracking_uri="databricks", # or None to rely on env MLFLOW_TRACKING_URI
)

# Load saved Hydra config and inspect
cfg = loader.load_config()
print("Model target:", cfg.predictive_model.instance._target_)

# Discover checkpoints and load the latest model
print("Available checkpoints:", loader.list_checkpoints())
model = loader.load_model(step="latest")

# Fetch metrics as a tidy pandas DataFrame
df = loader.load_metrics(pattern="validation/*") # glob filter optional
print(df.head())
```

Notes:

- PyTorch models: if your run used a PyTorch model (e.g., `transformer_lens.HookedTransformer`), ensure its package is installed in your environment. The loader first tries to instantiate the model as a JAX-based `PredictiveModel`; if that fails, it falls back to instantiating a `torch.nn.Module` and calls `model.eval()` on it.
- Persistence: the loader reconstructs the persister from the saved config. If the run has no persistence or no checkpoints, `load_model()` raises an informative error.
- S3 credentials: if your persister is built with `S3Persister.from_config`, you can point it at a different `.ini` file via `ExperimentLoader.from_mlflow(..., config_path="/path/to/config.ini")`.
- Metrics filtering uses glob syntax (e.g., `"validation/*"`).

See `notebooks/experiment_loader_demo.ipynb` for a runnable example.
396 changes: 396 additions & 0 deletions notebooks/experiment_loader_demo.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"pandas",
"penzai",
"plotly",
"transformer-lens>=2.16.1",
"treescope",
]

Expand Down
4 changes: 4 additions & 0 deletions simplexity/loaders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""High-level loaders for reconstructing models and reading run data."""

from .experiment_loader import ExperimentLoader

__all__ = ["ExperimentLoader"]

162 changes: 162 additions & 0 deletions simplexity/loaders/experiment_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

from omegaconf import DictConfig, OmegaConf

from simplexity.logging.mlflow_reader import MLflowRunReader
from simplexity.logging.run_reader import RunReader
from simplexity.persistence.model_persister import ModelPersister
from simplexity.predictive_models.predictive_model import PredictiveModel
from simplexity.utils.hydra import typed_instantiate


@dataclass
class ExperimentLoader:
"""High-level loader for reconstructing models and reading run data."""

reader: RunReader
config_path: str | None = None
_cached_config: DictConfig | None = None

# --- Constructors ---
@classmethod
def from_mlflow(
    cls,
    run_id: str,
    tracking_uri: str | None = None,
    config_path: str | None = None,
) -> ExperimentLoader:
    """Create a loader that reads from a single MLflow run.

    Args:
        run_id: ID of the MLflow run to read.
        tracking_uri: Forwarded unchanged to ``MLflowRunReader``; may be ``None``.
        config_path: Optional path stored on the loader. When set, it replaces the
            ``config_filename`` entry of the persister config (if that entry exists).

    Returns:
        A loader whose ``reader`` is an ``MLflowRunReader`` bound to ``run_id``.
    """
    reader = MLflowRunReader(run_id=run_id, tracking_uri=tracking_uri)
    return cls(reader=reader, config_path=config_path)

# --- Accessors ---
def load_config(self) -> DictConfig:
    """Return the run's saved config, asking the reader for it at most once.

    The first call delegates to ``self.reader.get_config()`` and stores the result in
    ``self._cached_config``; later calls return that stored object.
    """
    cached = self._cached_config
    if cached is None:
        cached = self.reader.get_config()
        self._cached_config = cached
    return cached

def load_metrics(self, pattern: str | None = None):
    """Return the run's metrics as produced by the underlying reader.

    Args:
        pattern: Optional filter forwarded unchanged to ``self.reader.get_metrics``.

    Returns:
        Whatever ``self.reader.get_metrics(pattern=pattern)`` returns.
    """
    return self.reader.get_metrics(pattern=pattern)

# --- Helper methods ---
def _resolve_device(self, device: str) -> str:
    """Map the ``"auto"`` device setting to a concrete PyTorch device name.

    Any value other than ``"auto"`` is returned unchanged. For ``"auto"`` the
    preference order is CUDA, then MPS, then CPU; CPU is also the result when
    importing torch raises ``ImportError``.
    """
    if device == "auto":
        resolved = "cpu"
        try:
            import torch

            if torch.cuda.is_available():
                resolved = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                resolved = "mps"  # Apple Silicon
        except ImportError:
            resolved = "cpu"
        return resolved
    return device

# --- Model reconstruction ---
def _instantiate_model_and_persister(self) -> tuple[object, ModelPersister | None, DictConfig]:
    """Build the model and (optionally) the persister described by the run config.

    The model is first instantiated as a ``PredictiveModel``; if that fails for any
    reason, torch is imported and the same config is instantiated as a
    ``torch.nn.Module``, which is then switched to eval mode.

    Returns:
        ``(model, persister, cfg)`` where ``persister`` is ``None`` when the config
        has no (or an empty) ``persistence`` section.

    Raises:
        RuntimeError: If the model config cannot be prepared, the model cannot be
            instantiated by either path, or the persister cannot be instantiated.
            Each failure is reported once, with the most specific message.
    """
    cfg = self.load_config()

    # Step 1: convert the model node to a plain container and resolve device="auto".
    try:
        model_config = OmegaConf.to_container(cfg.predictive_model.instance, resolve=False)  # type: ignore[arg-type]
        if isinstance(model_config, dict):
            cfg_section = model_config.get("cfg")
            if isinstance(cfg_section, dict) and cfg_section.get("device") == "auto":
                # Build new dicts rather than editing the converted containers in place.
                cfg_section = {**cfg_section, "device": self._resolve_device("auto")}
                model_config = {**model_config, "cfg": cfg_section}
    except Exception as e:
        raise RuntimeError(
            "Failed to instantiate predictive model from run config.\n"
            "Ensure the model's Python package is installed (e.g., `transformer_lens`).\n"
            f"Underlying error: {e}"
        ) from e

    # Step 2: instantiate the model. This is deliberately NOT inside the try above, so
    # the targeted errors raised below are not re-wrapped in a second generic message.
    try:
        # First, try JAX-style PredictiveModel
        model = typed_instantiate(model_config, PredictiveModel)
    except Exception:
        # Fallback: try PyTorch if available
        try:
            import torch  # type: ignore
        except Exception as torch_import_err:  # pragma: no cover - import guard
            raise RuntimeError(
                "Failed to instantiate predictive model from run config.\n"
                "Tried PredictiveModel (JAX) and PyTorch fallback, but PyTorch is not available.\n"
                "Install torch support (uv sync --extra pytorch) or ensure the model package is installed.\n"
                f"Underlying error: {torch_import_err}"
            ) from torch_import_err

        try:
            model = typed_instantiate(model_config, torch.nn.Module)  # type: ignore[attr-defined]
            # Put model into eval mode by default for inference
            if hasattr(model, "eval"):
                model.eval()
        except Exception as e2:
            raise RuntimeError(
                "Failed to instantiate predictive model from run config.\n"
                "Ensure the model's Python package is installed (e.g., transformer_lens for HookedTransformer).\n"
                f"Underlying error: {e2}"
            ) from e2

    # Step 3: instantiate the persister, if the run was configured with one.
    # ``get`` treats an absent ``persistence`` key the same as an empty/null one.
    persistence_cfg = cfg.get("persistence")
    if not persistence_cfg:
        return model, None, cfg

    try:
        persister_config = OmegaConf.to_container(persistence_cfg.instance, resolve=False)  # type: ignore[arg-type]
        # Override config_filename if custom config_path is provided
        if isinstance(persister_config, dict) and self.config_path and "config_filename" in persister_config:
            persister_config = {**persister_config, "config_filename": self.config_path}
        persister = typed_instantiate(persister_config, ModelPersister)
    except Exception as e:
        raise RuntimeError(
            "Failed to instantiate persister from run config.\n"
            "If using S3, ensure credentials/config are available (e.g., config.ini or env).\n"
            f"Underlying error: {e}"
        ) from e
    return model, persister, cfg

def _instantiate_persister_only(self) -> ModelPersister | None:
    """Best-effort construction of the run's persister, without building the model.

    Returns:
        The persister, or ``None`` when the config has no (or an empty)
        ``persistence`` section or when construction fails for any reason
        (e.g. missing credentials). Failures are logged as warnings, not raised.
    """
    import logging

    # ``get`` treats an absent ``persistence`` key the same as an empty/null one.
    persistence_cfg = self.load_config().get("persistence")
    if not persistence_cfg:
        return None
    try:
        persister_config = OmegaConf.to_container(persistence_cfg.instance, resolve=False)  # type: ignore[arg-type]
        # Override config_filename if custom config_path is provided
        if isinstance(persister_config, dict) and self.config_path and "config_filename" in persister_config:
            persister_config = {**persister_config, "config_filename": self.config_path}
        return typed_instantiate(persister_config, ModelPersister)
    except Exception as err:
        # Best-effort: stay silent towards the caller, but leave a trace so an empty
        # checkpoint list can be told apart from a persister that could not be built.
        logging.getLogger(__name__).warning("Could not construct persister from run config: %s", err)
        return None

def list_checkpoints(self) -> list[int]:
    """Return the checkpoint steps reported by the run's persister, or ``[]`` if none is available."""
    persister = self._instantiate_persister_only()
    if not persister:
        return []
    return persister.list_checkpoints()

def latest_checkpoint(self) -> int | None:
    """Return the persister's latest checkpoint step, or ``None`` if no persister is available."""
    persister = self._instantiate_persister_only()
    if not persister:
        return None
    return persister.latest_checkpoint()

def load_model(self, step: int | Literal["latest"] = "latest") -> object:
    """Instantiate the run's model and load checkpoint weights into it.

    Args:
        step: Checkpoint step to load, or ``"latest"`` to use the persister's
            latest checkpoint.

    Returns:
        The result of ``persister.load_weights(model, step)``.

    Raises:
        RuntimeError: If the run has no persister, has no checkpoints (for
            ``"latest"``), or the requested step does not exist.
    """
    model, persister, _ = self._instantiate_model_and_persister()
    if not persister:
        raise RuntimeError("No persistence configuration found in run config; cannot load checkpoints.")

    if step == "latest":
        target_step = persister.latest_checkpoint()
        if target_step is None:
            raise RuntimeError("No checkpoints found for this run.")
    elif persister.checkpoint_exists(step):
        target_step = step
    else:
        raise RuntimeError(f"Requested checkpoint step {step} does not exist.")

    return persister.load_weights(model, target_step)
103 changes: 103 additions & 0 deletions simplexity/logging/mlflow_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

import fnmatch
import tempfile
from pathlib import Path
from typing import Iterable

import mlflow
import pandas as pd
from mlflow.entities import Metric
from omegaconf import DictConfig, OmegaConf

from simplexity.logging.run_reader import RunReader


class MLflowRunReader(RunReader):
"""Read experiment data for a single MLflow run."""

def __init__(self, run_id: str, tracking_uri: str | None = None) -> None:
    """Bind the reader to one run and create a scratch directory for downloads."""
    self._run_id = run_id
    # A dedicated client keeps the tracking URI local to this reader instead of
    # going through the global mlflow.set_tracking_uri.
    client = mlflow.MlflowClient(tracking_uri=tracking_uri)
    self._client = client
    # Downloads land here unless a destination is given; removed again in __del__.
    scratch = tempfile.TemporaryDirectory()
    self._temp_dir = scratch

# --- Basic run info helpers ---
def _get_run(self):
    """Fetch this reader's run from the MLflow client."""
    client, run_id = self._client, self._run_id
    return client.get_run(run_id)

# --- Config/params/tags ---
def get_config(self) -> DictConfig:
    """Download and load the run's ``config.yaml`` artifact.

    The file is first requested at the artifact root. If that attempt raises for
    any reason, the artifact tree is searched and the first file named
    ``config.yaml`` is used instead.

    Raises:
        RuntimeError: If no artifact named ``config.yaml`` is found by the search.
    """
    target_dir = Path(self._temp_dir.name) / "config"
    target_dir.mkdir(parents=True, exist_ok=True)
    download_to = str(target_dir)

    try:
        root_copy = self._client.download_artifacts(self._run_id, "config.yaml", download_to)
        return OmegaConf.load(root_copy)
    except Exception:
        # Fallback: look for a nested config.yaml anywhere in the artifact tree.
        match = next((p for p in self.list_artifacts() if Path(p).name == "config.yaml"), None)
        if match is None:
            raise RuntimeError(
                "Could not find config.yaml in MLflow artifacts for this run."
            )
        nested_copy = self._client.download_artifacts(self._run_id, match, download_to)
        return OmegaConf.load(nested_copy)

def get_params(self) -> dict[str, str]:
    """Return a plain ``dict`` copy of the parameters logged on the run."""
    # Copy so callers get a regular dict independent of the run object.
    return dict(self._get_run().data.params)

def get_tags(self) -> dict[str, str]:
    """Return a plain ``dict`` copy of the tags set on the run."""
    return dict(self._get_run().data.tags)

# --- Metrics ---
def get_metrics(self, pattern: str | None = None) -> pd.DataFrame:
    """Return the metric history of the run as a long-format DataFrame.

    Args:
        pattern: Optional glob (e.g. ``"validation/*"``) applied to metric names.
            Matching is case-sensitive on every platform. ``None`` or an empty
            string selects all metrics.

    Returns:
        A DataFrame with columns ``metric``, ``step``, ``value``, ``timestamp``,
        one row per history entry, sorted by ``metric`` then ``step``. The frame
        is empty (same columns) when nothing matches.
    """
    columns = ["metric", "step", "value", "timestamp"]
    metric_keys = list(self._get_run().data.metrics.keys())
    if pattern:
        # fnmatchcase, not fnmatch: metric names are not file paths, so matching must
        # not go through os.path.normcase (which folds case and separators on Windows).
        metric_keys = [k for k in metric_keys if fnmatch.fnmatchcase(k, pattern)]

    records: list[tuple[str, int, float, int]] = []
    for key in metric_keys:
        history = self._client.get_metric_history(self._run_id, key)
        records.extend((key, int(m.step), float(m.value), int(m.timestamp)) for m in history)

    if not records:
        return pd.DataFrame(columns=columns)
    df = pd.DataFrame.from_records(records, columns=columns)
    return df.sort_values(["metric", "step"])

# --- Artifacts ---
def list_artifacts(self, path: str | None = None) -> list[str]:
    """List artifact relative paths (recursively).

    Args:
        path: Directory (relative to the run's artifact root) to start from;
            ``None`` or ``""`` starts at the root.

    Returns:
        Paths of all files found, relative to the artifact root, in depth-first
        listing order. Directories themselves are not included.
    """
    results: list[str] = []

    def _walk(rel: str) -> None:
        for entry in self._client.list_artifacts(self._run_id, rel):
            # MLflow reports FileInfo.path relative to the run's artifact root, so it
            # already contains ``rel``; prefixing it again would give "dir/dir/file".
            if entry.is_dir:
                _walk(entry.path)
            else:
                results.append(entry.path)

    _walk(path or "")
    return results

def download_artifact(self, path: str, dst: str | Path | None = None) -> Path:
    """Download one artifact and return the local path reported by the client.

    Args:
        path: Artifact path to download.
        dst: Destination directory; defaults to an ``artifacts`` folder inside the
            reader's temporary directory. Created if it does not exist.
    """
    if dst is None:
        dst_dir = Path(self._temp_dir.name) / "artifacts"
    else:
        dst_dir = Path(dst)
    dst_dir.mkdir(parents=True, exist_ok=True)
    return Path(self._client.download_artifacts(self._run_id, path, str(dst_dir)))

# --- Cleanup ---
def __del__(self) -> None:  # pragma: no cover - best-effort cleanup
    """Remove the temporary download directory when the reader is collected."""
    try:
        scratch = self._temp_dir
        scratch.cleanup()
    except Exception:
        # Cleanup is best-effort: any failure here (including a missing
        # ``_temp_dir`` attribute) is deliberately ignored.
        return
37 changes: 37 additions & 0 deletions simplexity/logging/run_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Protocol

import pandas as pd
from omegaconf import DictConfig


class RunReader(Protocol):
    """Structural interface for reading one experiment run from a tracking backend."""

    def get_config(self) -> DictConfig:
        """Return the config saved with the run."""
        ...

    def get_params(self) -> dict[str, str]:
        """Return the run's parameters as a string-to-string dict."""
        ...

    def get_tags(self) -> dict[str, str]:
        """Return the run's tags as a string-to-string dict."""
        ...

    def get_metrics(self, pattern: str | None = None) -> pd.DataFrame:
        """Return metrics in long format with columns ``metric``, ``step``, ``value``, ``timestamp``."""
        ...

    def list_artifacts(self, path: str | None = None) -> list[str]:
        """Return the relative paths of the artifacts stored for the run."""
        ...

    def download_artifact(self, path: str, dst: str | Path | None = None) -> Path:
        """Download the artifact at ``path`` into ``dst`` and return the local file path."""
        ...

27 changes: 27 additions & 0 deletions simplexity/persistence/local_equinox_persister.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,30 @@ def load_weights(self, model: PredictiveModel, step: int = 0) -> PredictiveModel

def _get_path(self, step: int) -> Path:
    """Return the checkpoint file path for ``step``: ``<directory>/<step>/<filename>``."""
    step_dir = self.directory / str(step)
    return step_dir / self.filename

# --- Checkpoint discovery ---
def list_checkpoints(self) -> list[int]:
    """Return the sorted steps for which a checkpoint file exists on disk.

    A step counts when ``<directory>/<step>/<filename>`` exists and the directory
    name is the canonical decimal form of the step. Returns ``[]`` when the base
    directory does not exist.
    """
    if not self.directory.exists():
        return []

    steps: list[int] = []
    for child in self.directory.iterdir():
        if not child.is_dir():
            continue
        try:
            step = int(child.name)
        except ValueError:
            continue
        # int() also accepts names such as "007", "+7" or "1_0". Checkpoints are
        # addressed as str(step), so such directories are not loadable as that step
        # and must not be reported.
        if str(step) == child.name and (child / self.filename).exists():
            steps.append(step)
    return sorted(steps)

def latest_checkpoint(self) -> int | None:
    """Return the last entry of ``list_checkpoints()``, or ``None`` when there are none."""
    steps = self.list_checkpoints()
    if not steps:
        return None
    return steps[-1]

def checkpoint_exists(self, step: int) -> bool:
    """Return whether ``<directory>/<step>/<filename>`` exists on disk."""
    candidate = self.directory.joinpath(str(step), self.filename)
    return candidate.exists()

def uri_for_step(self, step: int) -> str:
    """Return a ``file://`` URI for the checkpoint file of ``step``.

    The path from ``_get_path`` is made absolute (without resolving symlinks) and
    percent-encoded, so the result is a valid URI even when the persister was
    configured with a relative directory or the path contains spaces.
    """
    # Plain f"file://{path}" would turn a relative path's first component into the
    # URI host; as_uri() requires, and correctly encodes, an absolute path.
    return self._get_path(step).absolute().as_uri()
Loading
Loading