diff --git a/README.md b/README.md
index f47711fc..81cfe8c7 100644
--- a/README.md
+++ b/README.md
@@ -66,6 +66,32 @@ uv run python simplexity/run_experiment.py --multirun
 ```
 
 The `ModelPersister` class is responsible for saving and loading model checkpoints. The `LocalPersister` class saves checkpoints to the local file system, while the `S3Persister` class saves checkpoints to an S3 bucket.
 
+### Reusing Config Sections from MLflow Runs
+
+`managed_run` can bootstrap a new configuration by pulling sections of the configs logged by previous MLflow runs before Hydra instantiates anything. Add a `load_configs` block to your config (see `examples/configs/demo_config.yaml`):
+
+```yaml
+load_configs:
+  - tracking_uri: databricks  # optional; defaults to current URI
+    experiment_name: /Shared/previous_exp  # or provide experiment_id
+    experiment_id: "123456"  # optional safeguard when both are set
+    run_name: best_model_run  # or provide run_id
+    run_id: "0123456789abcdef"
+    artifact_path: config.yaml  # optional; defaults to config.yaml
+    configs:
+      predictive_model: historical.predictive_model
+      generative_process.instance: reused.generative_process
+```
+
+For each entry we:
+
+- Create a dedicated `MlflowClient` using the supplied tracking URI.
+- Download the specified artifact (`config.yaml` by default) from the referenced run.
+- Copy the listed source keys into the current config at the provided destination paths, merging with any existing values so local overrides still apply.
+- Accept `experiment_id`/`run_id` instead of names. When both a name and an id are supplied, we verify they reference the same MLflow object to avoid accidental mix-ups.
+
+This makes it easy to pin specific components (models, generative processes, etc.) from a prior run while still editing other sections locally.
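+
+As a quick check on merge direction: when a destination section already exists locally, the loaded section is merged underneath it, so local values win on conflict. A minimal sketch of that precedence using plain OmegaConf (the section names here are illustrative):
+
+```python
+from omegaconf import OmegaConf
+
+loaded = OmegaConf.create({"predictive_model": {"foo": "old", "bar": 1}})
+local = OmegaConf.create({"predictive_model": {"foo": "new"}})
+
+# OmegaConf.merge gives later arguments precedence, so the local override
+# survives while keys missing locally are filled in from the loaded run.
+merged = OmegaConf.merge(loaded, local)
+assert merged.predictive_model.foo == "new"
+assert merged.predictive_model.bar == 1
+```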
+
 #### Using S3 Storage
 
 The `S3Persister` can be configured using an `.ini` file, which should have the following structure:
diff --git a/examples/configs/demo_config.yaml b/examples/configs/demo_config.yaml
index 084781e2..f9beaa58 100644
--- a/examples/configs/demo_config.yaml
+++ b/examples/configs/demo_config.yaml
@@ -15,3 +15,21 @@ tags:
   retention: temp
   example_tag_1: value1
   example_tag_2: value2
+
+load_configs:
+  - tracking_uri: databricks
+    experiment_name: /Shared/managed_run_demo_20251106_045326
+    # experiment_id: 123456789012345678  # optional alternative to experiment_name
+    run_name: managed_run_demo_20251106_045326
+    # run_id: 0123456789abcdef0123456789abcdef  # optional alternative to run_name
+    artifact_path: config.yaml
+    configs:
+      predictive_model: old_model_1
+      generative_process: old_generative_process_1
+  - tracking_uri: databricks
+    experiment_name: /Shared/managed_run_demo_20251105_003639
+    run_name: managed_run_demo_20251105_003639
+    artifact_path: config.yaml
+    configs:
+      predictive_model: old_model_2
+      generative_process: old_generative_process_2
diff --git a/simplexity/run_management/run_management.py b/simplexity/run_management/run_management.py
index acdb11c0..5a2ca94c 100644
--- a/simplexity/run_management/run_management.py
+++ b/simplexity/run_management/run_management.py
@@ -9,6 +9,7 @@ import os
 import random
 import subprocess
+import tempfile
 import warnings
 from collections.abc import Callable, Iterator
 from contextlib import contextmanager, nullcontext
@@ -18,7 +19,8 @@ import jax
 import jax.numpy as jnp
 import mlflow
-from omegaconf import DictConfig, OmegaConf
+from mlflow.exceptions import MlflowException
+from omegaconf import DictConfig, OmegaConf, open_dict
 from torch.nn import Module as PytorchModel
 
 from simplexity.generative_processes.generative_process import GenerativeProcess
@@ -96,6 +98,135 @@ def _suppress_pydantic_field_attribute_warning() -> Iterator[None]:
     yield
 
 
+def _load_config(cfg: DictConfig, load_config: DictConfig) -> None:
+    """Merge config sections from a previous MLflow run's logged config artifact into the current config."""
+    if not load_config:
+        SIMPLEXITY_LOGGER.warning("[config] load_config entry is empty, skipping")
+        return
+
+    tracking_uri: str | None = load_config.get("tracking_uri")
+    experiment_name: str | None = load_config.get("experiment_name")
+    experiment_id: str | None = load_config.get("experiment_id")
+    run_name: str | None = load_config.get("run_name")
+    run_id: str | None = load_config.get("run_id")
+    configs_to_load: DictConfig | None = load_config.get("configs")
+    artifact_path: str = load_config.get("artifact_path", "config.yaml")
+
+    if not configs_to_load:
+        run_identifier = run_name or run_id or ""
+        SIMPLEXITY_LOGGER.warning(
+            f"[config] no configs specified for load_config run '{run_identifier}', nothing to merge"
+        )
+        return
+
+    client = mlflow.MlflowClient(tracking_uri=tracking_uri)
+
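+    # Resolve the experiment: an explicit id takes precedence and, when a name
+    # is also given, must agree with it; otherwise fall back to a name lookup.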
+    if experiment_id:
+        experiment = client.get_experiment(experiment_id)
+        if experiment is None:
+            raise ValueError(
+                f"Experiment with id '{experiment_id}' not found for load_config "
+                f"run '{run_name or run_id or ''}'"
+            )
+        if experiment_name:
+            if experiment.name != experiment_name:
+                raise ValueError(
+                    f"Experiment id '{experiment_id}' refers to '{experiment.name}', which does not match "
+                    f"provided experiment_name '{experiment_name}'"
+                )
+        else:
+            experiment_name = experiment.name
+    elif experiment_name:
+        experiment = client.get_experiment_by_name(experiment_name)
+        if experiment is None:
+            raise ValueError(
+                f"Experiment '{experiment_name}' not found for load_config run '{run_name or run_id or ''}'"
+            )
+        experiment_id = experiment.experiment_id
+    else:
+        raise ValueError("load_config requires experiment_name or experiment_id")
+    assert experiment_id is not None
+    assert experiment_name is not None
+
+    if run_id:
+        try:
+            run = client.get_run(run_id)
+        except MlflowException as e:  # pragma: no cover - mlflow always raises for missing runs
+            raise ValueError(
+                f"Run with id '{run_id}' not found in experiment '{experiment_name}' for load_config entry"
+            ) from e
+        if run.info.experiment_id != experiment_id:
+            raise ValueError(
+                f"Run id '{run_id}' belongs to experiment id '{run.info.experiment_id}', expected '{experiment_id}'"
+            )
+        if run_name and run.info.run_name != run_name:
+            raise ValueError(
+                f"Run id '{run_id}' refers to run '{run.info.run_name}', which does not match "
+                f"provided run_name '{run_name}'"
+            )
+    elif run_name:
+        runs = client.search_runs(
+            experiment_ids=[experiment_id],
+            filter_string=f"attributes.run_name = '{run_name}'",
+            max_results=1,
+        )
+        if not runs:
+            raise ValueError(
+                f"Run with name '{run_name}' not found in experiment '{experiment_name}' for load_config entry"
+            )
+        run = runs[0]
+    else:
+        raise ValueError("load_config requires run_name or run_id")
+
+    run_id = run.info.run_id
+    run_name = run.info.run_name
+    assert run_id is not None
+    assert run_name is not None
+
+    SIMPLEXITY_LOGGER.info(
+        f"[config] loading artifact '{artifact_path}' from run '{run_name}' ({run_id}) "
+        f"in experiment '{experiment_name}' ({experiment_id})"
+    )
+    with tempfile.TemporaryDirectory() as temp_dir:
+        artifact_local_path = client.download_artifacts(run_id, artifact_path, temp_dir)
+        source_cfg = OmegaConf.load(artifact_local_path)
+
+    configs_mapping: dict[str, str] = OmegaConf.to_container(configs_to_load, resolve=True)  # type: ignore[arg-type]
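+    # Copy each selected section to its destination path. When the destination
+    # already exists, merge so that values already present in the local config
+    # take precedence over the loaded ones.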
+    with open_dict(cfg):
+        for source_key, destination_key in configs_mapping.items():
+            if not isinstance(source_key, str) or not source_key:
+                raise ValueError("load_config configs keys must be non-empty strings")
+            if not isinstance(destination_key, str) or not destination_key:
+                raise ValueError("load_config configs values must be non-empty strings")
+
+            selected_config = OmegaConf.select(source_cfg, source_key, throw_on_missing=False)
+            if selected_config is None:
+                raise KeyError(f"Config key '{source_key}' not found in run '{run_name}' artifact '{artifact_path}'")
+
+            cloned_config = OmegaConf.create(OmegaConf.to_container(selected_config, resolve=False))
+            existing_destination = OmegaConf.select(cfg, destination_key, throw_on_missing=False)
+            if existing_destination is None:
+                SIMPLEXITY_LOGGER.info(
+                    f"[config] adding config '{source_key}' from run '{run_name}' to '{destination_key}'"
+                )
+                OmegaConf.update(cfg, destination_key, cloned_config, force_add=True)
+            else:
+                SIMPLEXITY_LOGGER.info(
+                    f"[config] merging config '{source_key}' from run '{run_name}' into '{destination_key}'"
+                )
+                merged_config = OmegaConf.merge(cloned_config, existing_destination)
+                OmegaConf.update(cfg, destination_key, merged_config, force_add=True)
+
+
+def _get_config(args: tuple[Any, ...], kwargs: dict[str, Any]) -> DictConfig:
+    """Get the config from the arguments."""
+    cfg = get_config(args, kwargs)
+    load_configs: list[DictConfig] = cfg.get("load_configs", [])
+    for load_config in load_configs:
+        _load_config(cfg, load_config)
+    return cfg
+
+
 def _setup_environment() -> None:
     """Setup the environment."""
     for key, value in DEFAULT_ENVIRONMENT_VARIABLES.items():
@@ -512,7 +643,7 @@ def managed_run(strict: bool = True, verbose: bool = False) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
     def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
         def wrapper(*args: Any, **kwargs: Any) -> Any:
             try:
-                cfg = get_config(args, kwargs)
+                cfg = _get_config(args, kwargs)
                 validate_base_config(cfg)
                 with _setup_mlflow(cfg):
                     components = _setup(cfg, strict=strict, verbose=verbose)
diff --git a/tests/run_management/test_run_management.py b/tests/run_management/test_run_management.py
new file mode 100644
index 00000000..1b205d19
--- /dev/null
+++ b/tests/run_management/test_run_management.py
@@ -0,0 +1,183 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import mlflow
+import pytest
+from omegaconf import DictConfig, OmegaConf
+
+from simplexity.run_management import run_management
+
+
+def _create_run_with_config(
+    tmp_path: Path,
+    source_cfg: DictConfig,
+    *,
+    experiment_name: str,
+    run_name: str,
+) -> tuple[str, str, str]:
+    tracking_dir = tmp_path / "mlruns"
+    tracking_uri = tracking_dir.as_posix()
+    previous_tracking_uri = mlflow.get_tracking_uri()
+    run_id: str | None = None
+    experiment_id: str | None = None
+    try:
+        mlflow.set_tracking_uri(tracking_uri)
+        mlflow.set_experiment(experiment_name)
+        experiment = mlflow.get_experiment_by_name(experiment_name)
+        assert experiment is not None
+        experiment_id = experiment.experiment_id
+        artifact_file = tmp_path / "config.yaml"
+        OmegaConf.save(source_cfg, artifact_file)
+        with mlflow.start_run(run_name=run_name) as run:
+            mlflow.log_artifact(str(artifact_file))
+            run_id = run.info.run_id
+    finally:
+        if previous_tracking_uri is not None:
+            mlflow.set_tracking_uri(previous_tracking_uri)
+    assert experiment_id is not None and run_id is not None
+    return tracking_uri, experiment_id, run_id
+
+
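+# The tests below drive _load_config directly against a file-based MLflow
+# tracking store rooted in tmp_path, so no external tracking server is needed.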
"predictive_model.bar") == 1 + + +def test_load_config_supports_ids(tmp_path: Path) -> None: + source_cfg = OmegaConf.create({"predictive_model": {"foo": "id"}}) + experiment_name = "demo" + run_name = "reuse" + tracking_uri, experiment_id, run_id = _create_run_with_config( + tmp_path, + source_cfg, + experiment_name=experiment_name, + run_name=run_name, + ) + + cfg = OmegaConf.create({}) + load_cfg = DictConfig( + { + "tracking_uri": tracking_uri, + "experiment_id": experiment_id, + "run_id": run_id, + "configs": { + "predictive_model": "copied", + }, + } + ) + + run_management._load_config(cfg, load_cfg) + + assert OmegaConf.select(cfg, "copied.foo") == "id" + + +def test_load_config_experiment_name_id_mismatch(tmp_path: Path) -> None: + source_cfg = OmegaConf.create({"predictive_model": {"foo": "mismatch"}}) + experiment_name = "demo" + run_name = "reuse" + tracking_uri, experiment_id, run_id = _create_run_with_config( + tmp_path, + source_cfg, + experiment_name=experiment_name, + run_name=run_name, + ) + + cfg = OmegaConf.create({}) + load_cfg = DictConfig( + { + "tracking_uri": tracking_uri, + "experiment_name": "different", + "experiment_id": experiment_id, + "run_name": run_name, + "configs": {"predictive_model": "copied"}, + } + ) + with pytest.raises(ValueError, match="does not match provided experiment_name"): + run_management._load_config(cfg, load_cfg) + + +def test_load_config_run_name_id_mismatch(tmp_path: Path) -> None: + source_cfg = OmegaConf.create({"predictive_model": {"foo": "mismatch"}}) + experiment_name = "demo" + run_name = "reuse" + tracking_uri, experiment_id, run_id = _create_run_with_config( + tmp_path, + source_cfg, + experiment_name=experiment_name, + run_name=run_name, + ) + + cfg = OmegaConf.create({}) + load_cfg = DictConfig( + { + "tracking_uri": tracking_uri, + "experiment_id": experiment_id, + "run_id": run_id, + "run_name": "different", + "configs": {"predictive_model": "copied"}, + } + ) + with pytest.raises(ValueError, match="does not match provided run_name"): + run_management._load_config(cfg, load_cfg)