Changes from all commits (56 commits)
78b94f9
Refactor configuration structure by centralizing structured configs i…
ealt Nov 5, 2025
39caf50
Delete whitespace
ealt Nov 5, 2025
d919b3c
Refactor configuration management by consolidating structured configs…
ealt Nov 5, 2025
9453a59
Remove configs
ealt Nov 5, 2025
154f63d
Improve validation
ealt Nov 5, 2025
fcfa8c9
Enhance model configuration validation for HookedTransformerConfig to…
ealt Nov 5, 2025
f4d19d1
Refactor validation functions to accept DictConfig instead of specifi…
ealt Nov 6, 2025
c1639ba
Enhance configuration filtering by integrating validation functions f…
ealt Nov 6, 2025
1a73adf
Enhance validation logic in structured_configs.py to include dynamic …
ealt Nov 6, 2025
55719cc
Refactor MLFlowConfig to allow optional tracking and registry URIs, e…
ealt Nov 6, 2025
ac9dd9d
Create config validation error
ealt Nov 6, 2025
d77104e
Refactor config validation to use custom ConfigValidationError for im…
ealt Nov 6, 2025
9677e99
Remove validation config
ealt Nov 6, 2025
3343c42
Reorder structured configs to match order of instantiation
ealt Nov 6, 2025
11c5506
Refactor validation functions in structured_configs.py to consistentl…
ealt Nov 6, 2025
1b79597
Refactor model evaluation and training functions to use DictConfig fo…
ealt Nov 6, 2025
0f99a73
Add unit tests for logging configuration validation in structured_con…
ealt Nov 6, 2025
3c01d90
Refactor validation functions in structured_configs.py to use private…
ealt Nov 6, 2025
e79baf5
Update downgrade_unity_catalog handling in MLflow setup to ensure def…
ealt Nov 6, 2025
355fd54
Update downgrade_unity_catalog type in MLFlowConfig to allow None, en…
ealt Nov 6, 2025
6da204c
Add unit tests for MLFlow configuration validation in test_structured…
ealt Nov 6, 2025
e384b50
Add unit tests for generative process configuration validation in tes…
ealt Nov 6, 2025
77eb709
Enhance generative process configuration validation in structured_con…
ealt Nov 6, 2025
3c7a72f
Refactor validation logic in structured_configs.py to utilize dedicat…
ealt Nov 7, 2025
d2cd3b3
Enhance validation functions in structured_configs.py to allow None v…
ealt Nov 7, 2025
5b8f4ab
Refactor name validation in structured_configs.py to use a unified fu…
ealt Nov 7, 2025
239bc2a
Add support for reusing configuration sections from previous MLflow runs
ealt Nov 7, 2025
054c0b3
Refactor MLflow client usage in tests to streamline configuration loa…
ealt Nov 7, 2025
b31b459
Refactor and enhance validation tests in test_structured_configs.py
ealt Nov 7, 2025
5cc9ef8
Add sequence_len and batch_size to GenerativeProcessConfig with valid…
ealt Nov 7, 2025
1f51d47
Update mess3.yaml and run_management.py to incorporate batch_size and…
ealt Nov 7, 2025
f5346eb
Update configuration files and refactor run_management.py for improve…
ealt Nov 7, 2025
774c5bc
Remove dependence on training config from demo
ealt Nov 7, 2025
cdc9206
Remove training config
ealt Nov 7, 2025
1dc9bce
Add base configuration validation and tests
ealt Nov 7, 2025
ccf1f49
Refactor HookedTransformerConfigConfig and validation logic
ealt Nov 7, 2025
a4c323e
Refactor run_management.py to improve configuration handling
ealt Nov 7, 2025
57a47fb
Merge remote-tracking branch 'origin/main' into strucutred-configs
ealt Nov 7, 2025
ba5f869
Remove obsolete imports
ealt Nov 7, 2025
8f18887
Fix base config tests
ealt Nov 7, 2025
b0af385
Merge remote-tracking branch 'origin/strucutred-configs' into load_co…
ealt Nov 7, 2025
abc2e48
Test instance config
ealt Nov 7, 2025
d66f133
Enhance HookedTransformerConfig and add comprehensive tests
ealt Nov 9, 2025
75f707c
Update validation error messages for empty _target_ in configuration …
ealt Nov 9, 2025
3e6a373
Add tests to test filter with validation function
ealt Nov 10, 2025
28e6706
Add additional mlflow validation tests
ealt Nov 10, 2025
8c06c14
Add tests for validating gen process configs with missing bos/eos tokens
ealt Nov 10, 2025
5e8ddee
Add more tests
ealt Nov 10, 2025
1d9a82c
Fix debug message
ealt Nov 10, 2025
8dc8873
Rename ModelConfig to PredictiveModelConfig
ealt Nov 10, 2025
023678d
Split up configs test module
ealt Nov 10, 2025
dc9036b
Apply suggested fixes
ealt Nov 10, 2025
d148146
Merge remote-tracking branch 'origin/strucutred-configs' into load_co…
ealt Nov 10, 2025
905b5ec
Merge remote-tracking branch 'origin/main' into load_config
ealt Nov 11, 2025
3740966
Enhance load_config functionality to support both experiment and run …
ealt Nov 11, 2025
d8983ea
Fix bugs
ealt Nov 11, 2025
26 changes: 26 additions & 0 deletions README.md
@@ -66,6 +66,32 @@ uv run python simplexity/run_experiment.py --multirun

The `ModelPersister` class is responsible for saving and loading model checkpoints. The `LocalPersister` class saves checkpoints to the local file system, while the `S3Persister` class saves checkpoints to an S3 bucket.

### Reusing Config Sections from MLflow Runs

`managed_run` can bootstrap a new configuration by pulling config sections from previous MLflow runs before Hydra instantiates anything. Add a `load_configs` block to your config (see `examples/configs/demo_config.yaml`):

```yaml
load_configs:
  - tracking_uri: databricks # optional; defaults to current URI
    experiment_name: /Shared/previous_exp # or provide experiment_id
    experiment_id: "123456" # optional safeguard when both are set
    run_name: best_model_run # or provide run_id
    run_id: "0123456789abcdef"
    artifact_path: config.yaml # optional; defaults to config.yaml
    configs:
      predictive_model: historical.predictive_model
      generative_process.instance: reused.generative_process
```

For each entry, `managed_run` will:

- Create a dedicated `MlflowClient` using the supplied tracking URI.
- Download the specified artifact (`config.yaml` by default) from the referenced run.
- Copy each listed section into the current config: the mapping key is the source path in the downloaded config and the value is the destination path in the current one. Loaded values are merged with any existing values, so local overrides still apply (see the sketch below).
- Accept `experiment_id`/`run_id` in place of names. When both a name and an id are supplied, they are verified to reference the same MLflow objects to avoid accidental mix-ups.

This makes it easy to pin specific components (models, generative processes, etc.) from a prior run while still editing other sections locally.
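
The merge direction matters: values loaded from the previous run act as the base, and values already present in the current config win on conflict. Below is a minimal OmegaConf sketch of that behavior with hypothetical section contents (not taken from the repo):

```python
from omegaconf import OmegaConf

# Hypothetical section pulled from a previous run's config.yaml artifact.
loaded = OmegaConf.create({"n_layers": 4, "d_model": 128})

# The same section as it already exists in the current (local) config.
local = OmegaConf.create({"d_model": 256})

# Loaded values come first, so existing local values override them on conflict.
merged = OmegaConf.merge(loaded, local)
print(OmegaConf.to_yaml(merged))
# n_layers: 4
# d_model: 256
```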

#### Using S3 Storage

The `S3Persister` can be configured using an `.ini` file, which should have the following structure:
18 changes: 18 additions & 0 deletions examples/configs/demo_config.yaml
@@ -15,3 +15,21 @@ tags:
  retention: temp
  example_tag_1: value1
  example_tag_2: value2

load_configs:
  - tracking_uri: databricks
    experiment_name: /Shared/managed_run_demo_20251106_045326
    # experiment_id: 123456789012345678 # optional alternative to experiment_name
    run_name: managed_run_demo_20251106_045326
    # run_id: 0123456789abcdef0123456789abcdef # optional alternative to run_name
    artifact_path: config.yaml
    configs:
      predictive_model: old_model_1
      generative_process: old_generative_process_1
  - tracking_uri: databricks
    experiment_name: /Shared/managed_run_demo_20251105_003639
    run_name: managed_run_demo_20251105_003639
    artifact_path: config.yaml
    configs:
      predictive_model: old_model_2
      generative_process: old_generative_process_2
135 changes: 133 additions & 2 deletions simplexity/run_management/run_management.py
@@ -9,6 +9,7 @@
import os
import random
import subprocess
import tempfile
import warnings
from collections.abc import Callable, Iterator
from contextlib import contextmanager, nullcontext
@@ -18,7 +19,8 @@
import jax
import jax.numpy as jnp
import mlflow
from omegaconf import DictConfig, OmegaConf
from mlflow.exceptions import MlflowException
from omegaconf import DictConfig, OmegaConf, open_dict
from torch.nn import Module as PytorchModel

from simplexity.generative_processes.generative_process import GenerativeProcess
@@ -96,6 +98,135 @@ def _suppress_pydantic_field_attribute_warning() -> Iterator[None]:
        yield


def _load_config(cfg: DictConfig, load_config: DictConfig) -> None:
    """Load the config."""
    if not load_config:
        SIMPLEXITY_LOGGER.warning("[config] load_config entry is empty, skipping")
        return

    tracking_uri: str | None = load_config.get("tracking_uri")
    experiment_name: str | None = load_config.get("experiment_name")
    experiment_id: str | None = load_config.get("experiment_id")
    run_name: str | None = load_config.get("run_name")
    run_id: str | None = load_config.get("run_id")
    configs_to_load: DictConfig | None = load_config.get("configs")
    artifact_path: str = load_config.get("artifact_path", "config.yaml")

    if not configs_to_load:
        run_identifier = run_name or run_id or "<unknown>"
        SIMPLEXITY_LOGGER.warning(
            f"[config] no configs specified for load_config run '{run_identifier}', nothing to merge"
        )
        return

    client = mlflow.MlflowClient(tracking_uri=tracking_uri)

    if experiment_id:
        experiment = client.get_experiment(experiment_id)
        if experiment is None:
            raise ValueError(
                f"Experiment with id '{experiment_id}' not found for load_config "
                f"run '{run_name or run_id or '<unknown>'}'"
            )
        if experiment_name:
            if experiment.name != experiment_name:
                raise ValueError(
                    f"Experiment id '{experiment_id}' refers to '{experiment.name}', which does not match "
                    f"provided experiment_name '{experiment_name}'"
                )
        else:
            experiment_name = experiment.name
    elif experiment_name:
        experiment = client.get_experiment_by_name(experiment_name)
        if experiment is None:
            raise ValueError(
                f"Experiment '{experiment_name}' not found for load_config run '{run_name or run_id or '<unknown>'}'"
            )
        experiment_id = experiment.experiment_id
    else:
        raise ValueError("load_config requires experiment_name or experiment_id")
    assert experiment_id is not None
    assert experiment_name is not None

    if run_id:
        try:
            run = client.get_run(run_id)
        except MlflowException as e: # pragma: no cover - mlflow always raises for missing runs
            raise ValueError(
                f"Run with id '{run_id}' not found in experiment '{experiment_name}' for load_config entry"
            ) from e
        if run.info.experiment_id != experiment_id:
            raise ValueError(
                f"Run id '{run_id}' belongs to experiment id '{run.info.experiment_id}', expected '{experiment_id}'"
            )
        if run_name and run.info.run_name != run_name:
            raise ValueError(
                f"Run id '{run_id}' refers to run '{run.info.run_name}', which does not match "
                f"provided run_name '{run_name}'"
            )
    elif run_name:
        runs = client.search_runs(
            experiment_ids=[experiment_id],
            filter_string=f"attributes.run_name = '{run_name}'",
            max_results=1,
        )
        if not runs:
            raise ValueError(
                f"Run with name '{run_name}' not found in experiment '{experiment_name}' for load_config entry"
            )
        run = runs[0]
    else:
        raise ValueError("load_config requires run_name or run_id")

    run_id = run.info.run_id
    run_name = run.info.run_name
    assert run_id is not None
    assert run_name is not None

    SIMPLEXITY_LOGGER.info(
        f"[config] loading artifact '{artifact_path}' from run '{run_name}' ({run_id}) "
        f"in experiment '{experiment_name}' ({experiment_id})"
    )
    with tempfile.TemporaryDirectory() as temp_dir:
        artifact_local_path = client.download_artifacts(run_id, artifact_path, temp_dir)
        source_cfg = OmegaConf.load(artifact_local_path)

    configs_mapping: dict[str, str] = OmegaConf.to_container(configs_to_load, resolve=True) # type: ignore[arg-type]
    with open_dict(cfg):
        for source_key, destination_key in configs_mapping.items():
            if not isinstance(source_key, str) or not source_key:
                raise ValueError("load_config configs keys must be non-empty strings")
            if not isinstance(destination_key, str) or not destination_key:
                raise ValueError("load_config configs values must be non-empty strings")

            selected_config = OmegaConf.select(source_cfg, source_key, throw_on_missing=False)
            if selected_config is None:
                raise KeyError(f"Config key '{source_key}' not found in run '{run_name}' artifact '{artifact_path}'")

            cloned_config = OmegaConf.create(OmegaConf.to_container(selected_config, resolve=False))
            existing_destination = OmegaConf.select(cfg, destination_key, throw_on_missing=False)
            if existing_destination is None:
                SIMPLEXITY_LOGGER.info(
                    f"[config] adding config '{source_key}' from run '{run_name}' to '{destination_key}'"
                )
                OmegaConf.update(cfg, destination_key, cloned_config, force_add=True)
            else:
                SIMPLEXITY_LOGGER.info(
                    f"[config] merging config '{source_key}' from run '{run_name}' into '{destination_key}'"
                )
                merged_config = OmegaConf.merge(cloned_config, existing_destination)
                OmegaConf.update(cfg, destination_key, merged_config, force_add=True)


def _get_config(args: tuple[Any, ...], kwargs: dict[str, Any]) -> DictConfig:
    """Get the config from the arguments."""
    cfg = get_config(args, kwargs)
    load_configs: list[DictConfig] = cfg.get("load_configs", [])
    for load_config in load_configs:
        _load_config(cfg, load_config)
    return cfg


def _setup_environment() -> None:
"""Setup the environment."""
for key, value in DEFAULT_ENVIRONMENT_VARIABLES.items():
@@ -512,7 +643,7 @@ def managed_run(strict: bool = True, verbose: bool = False) -> Callable[[Callabl
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            try:
                cfg = get_config(args, kwargs)
                cfg = _get_config(args, kwargs)
                validate_base_config(cfg)
                with _setup_mlflow(cfg):
                    components = _setup(cfg, strict=strict, verbose=verbose)