Add sequential run #743
base: master
Conversation
Pull request overview
This PR adds support for sequential, multi-stage experiment runs (checkpoint-chained training stages) and extends checkpoint naming/configuration to better support recurrent/latent-chaining and MeZO training variants.
Changes:
- Added `sequential_runs` support to the experiment runner to orchestrate multi-stage scripts with checkpoint handoff.
- Introduced configurable checkpoint output filenames (`output_ckpt`, `mezo_output_ckpt`, `recurrent_output_ckpt`) and wired them into the training scripts.
- Modularized the latent-chaining recurrent loss into a new `recurrent_block_variants` module and refactored `train_recurrent.py` accordingly.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `optimization_and_search/run_experiments.py` | Adds sequential multi-stage orchestration (`sequential_runs`) and script selection/checkpoint chaining. |
| `demos/sequential_run_experiments_demo.yaml` | Demonstrates a multi-stage sequential run pipeline (train → finetune → recurrent → MeZO). |
| `train_args.py` | Adds new CLI args for checkpoint output filenames and groups MeZO/recurrent options. |
| `train.py` | Uses `--output_ckpt` for checkpoint saving instead of hard-coded `ckpt.pt`. |
| `train_mezo.py` | Saves checkpoints using `--mezo_output_ckpt`. |
| `train_recurrent.py` | Refactors recurrent training into modular components and adds recurrent checkpoint filename + variant selection. |
| `recurrent_variations/recurrent_block_variants.py` | New module implementing latent-chaining recurrent block loss + config. |
```python
if args.always_save_checkpoint:
    save_checkpoint(
        model=model,
        ckpt_model_args=ckpt_model_args,
        ckpt_path=best_ckpt_path,
        best_val_loss=state.best_val_loss,
        global_step=state.global_step,
        tag="latest",
    )
```
Copilot AI · Feb 8, 2026
When always_save_checkpoint is enabled, save_checkpoint() is called with ckpt_path=best_ckpt_path, which overwrites the best checkpoint file with a “latest” checkpoint. This makes it impossible to reliably keep the best checkpoint. Use a separate filename/path for the “latest” checkpoint (or only write “latest” when it’s also best).
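One way to keep the two files distinct is to derive separate paths up front. This is a minimal sketch assuming an `out_dir` and default filenames; the actual path handling in the PR may differ.

```python
import os

def ckpt_paths(out_dir, best_name="best_ckpt.pt", latest_name="latest_ckpt.pt"):
    """Derive separate paths so a 'latest' save never clobbers 'best'.
    Filenames are illustrative defaults, not the PR's actual names."""
    return os.path.join(out_dir, best_name), os.path.join(out_dir, latest_name)

best_ckpt_path, latest_ckpt_path = ckpt_paths("out")
# save_checkpoint(..., ckpt_path=latest_ckpt_path, tag="latest")  # always-save path
# save_checkpoint(..., ckpt_path=best_ckpt_path, tag="best")      # only on new best
```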
```python
ptr = 0
data_view = data
total_tokens = len(data_view) - 1
save_enabled = args.always_save_checkpoint or not args.never_save_checkpoint
```
Copilot AI · Feb 8, 2026
Checkpoint saving currently ignores never_save_checkpoint when always_save_checkpoint is true (and save_enabled also allows saving when never_save_checkpoint is true). In train.py, never_save_checkpoint always prevents saving regardless of always_save_checkpoint; aligning this behavior avoids surprising writes in “no-save” mode.
Suggested change:
```diff
-save_enabled = args.always_save_checkpoint or not args.never_save_checkpoint
+save_enabled = not args.never_save_checkpoint
```
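The difference between the two expressions is easy to see as a truth table; they disagree only when both flags are set, which is exactly the "no-save" case the comment describes:

```python
from itertools import product

def original(always_save, never_save):
    # Allows a save whenever always_save is set, even in "no-save" mode.
    return always_save or not never_save

def proposed(always_save, never_save):
    # never_save always wins, matching train.py's behavior.
    return not never_save

for always_save, never_save in product([False, True], repeat=2):
    print(always_save, never_save, original(always_save, never_save), proposed(always_save, never_save))
# The two differ only at always_save=True, never_save=True, where the
# original still writes a checkpoint despite --never_save_checkpoint.
```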
```python
parser.add_argument("--weight_end", type=float, default=1.0)
parser.add_argument(
    "--reset_optim",
    action="store_true",
```
Copilot AI · Feb 8, 2026
--reset_optim is defined as store_true here, so the command builder in run_experiments.py can emit --no-reset_optim (for false) and cause an unknown-arg failure. Since the rest of the codebase uses BooleanOptionalAction for these flags, consider switching --reset_optim to argparse.BooleanOptionalAction with an explicit default to keep CLI behavior consistent and compatible with sequential runs.
Suggested change:
```diff
-    action="store_true",
+    action=argparse.BooleanOptionalAction,
+    default=False,
```
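With `argparse.BooleanOptionalAction` (available since Python 3.9), argparse auto-generates the negated `--no-` form, so a command builder can emit either spelling without causing an unknown-argument error:

```python
import argparse

parser = argparse.ArgumentParser()
# BooleanOptionalAction registers both --reset_optim and --no-reset_optim.
parser.add_argument(
    "--reset_optim",
    action=argparse.BooleanOptionalAction,
    default=False,
)

print(parser.parse_args([]).reset_optim)                    # False (default)
print(parser.parse_args(["--reset_optim"]).reset_optim)     # True
print(parser.parse_args(["--no-reset_optim"]).reset_optim)  # False
```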
```python
def _normalize_sequential_runs(sequential_runs: object) -> list[dict]:
    if sequential_runs is None:
        return []
    if isinstance(sequential_runs, dict):
        return [sequential_runs]
```
Copilot AI · Feb 8, 2026
sequential_runs adds new multi-stage execution behavior, but there is no automated coverage for it. The existing tests/test_run_experiments.sh covers other config features; adding a small YAML under tests/run_optimization_tests/ that exercises sequential_runs (resume + input_ckpt chaining) would help prevent regressions.
```python
# Read metrics (use existing or nan on failure)
try:
    metrics = read_metrics(str(combo['out_dir']))
except Exception:
    metrics = {k: float("nan") for k in METRIC_KEYS}
```
Copilot AI · Feb 8, 2026
run_experiment() still derives metrics only from best_val_loss_and_iter.txt (written by train.py). With sequential runs, stages like train_mezo.py / train_recurrent.py don’t write this file, so successful runs can be logged with NaN metrics. Consider standardizing metric output across scripts or emitting a clear warning/selecting metrics from an earlier stage that produces them.
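One way to standardize this is a shared writer/reader pair that every stage calls on exit. The file name (`metrics.json`) and metric keys below are assumptions for the sketch, not the repository's actual format:

```python
import json
import os

METRIC_KEYS = ["best_val_loss", "best_iter"]  # assumed keys

def write_metrics(out_dir, metrics):
    """Every training stage writes the same metrics file before exiting."""
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f)

def read_metrics(out_dir):
    """Read the shared metrics file; fall back to NaN on any failure."""
    path = os.path.join(out_dir, "metrics.json")
    try:
        with open(path) as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return {k: float("nan") for k in METRIC_KEYS}
```

With a shared format, a MeZO or recurrent stage produces real metrics instead of silently falling into the NaN fallback.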
```python
best_val_loss = ckpt["best_val_loss"].item()
print("best_val_loss", best_val_loss)
best_val_loss = 5.00  # TODO: allow configurable start threshold
```
Copilot AI · Feb 8, 2026
best_val_loss is loaded from the resume checkpoint but then unconditionally overwritten to 5.00. This breaks true “best so far” tracking (and can prevent saving if losses stay above/below that threshold) and makes resumed recurrent runs non-reproducible. Prefer initializing best_val_loss from the checkpoint (or inf), and if a custom threshold is desired, expose it as an explicit CLI flag instead of hard-coding.
Suggested change:
```diff
-best_val_loss = ckpt["best_val_loss"].item()
-print("best_val_loss", best_val_loss)
-best_val_loss = 5.00  # TODO: allow configurable start threshold
+# Initialize best_val_loss from checkpoint if available; otherwise use +inf
+best_val_loss = float("inf")
+if "best_val_loss" in ckpt and ckpt["best_val_loss"] is not None:
+    best_val_loss = ckpt["best_val_loss"].item()
+print("best_val_loss", best_val_loss)
```
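The initialization logic can be exercised in isolation with plain dicts standing in for the torch checkpoint (the tensor `.item()` call is replaced by `float()` here to keep the sketch dependency-free):

```python
def init_best_val_loss(ckpt):
    """Prefer the checkpoint's recorded best; fall back to +inf.
    Plain-dict stand-in for a torch checkpoint, for illustration only."""
    value = ckpt.get("best_val_loss")
    return float(value) if value is not None else float("inf")

print(init_best_val_loss({"best_val_loss": 2.37}))  # resumes true best-so-far
print(init_best_val_loss({}))                       # fresh run starts at +inf
```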
This pull request introduces support for sequential multi-stage experiment runs, improves flexibility for checkpoint naming and management, and adds infrastructure for recurrent/latent-chaining training variants. Key changes include enhancements to experiment configuration and orchestration, new argument options for checkpoint handling, and a new module for recurrent block variants.

Experiment orchestration and configuration:
- Added support for `sequential_runs` in experiment YAML files, allowing experiments to be defined as a sequence of training stages (e.g., base training, fine-tuning, recurrent, and MeZO stages) with explicit checkpoint handoff between stages. (`demos/sequential_run_experiments_demo.yaml`, `optimization_and_search/run_experiments.py`) [1] [2] [3] [4] [5] [6] [7]
- Improved the experiment runner to handle per-stage script selection, argument passing, checkpoint chaining, and output directory management for sequential runs. (`optimization_and_search/run_experiments.py`) [1] [2] [3] [4]

Checkpoint and argument handling:
- Introduced configurable checkpoint output filenames (`output_ckpt`, `mezo_output_ckpt`, `recurrent_output_ckpt`), and updated the training scripts to use these names when saving checkpoints. (`train_args.py`, `train.py`, `train_mezo.py`) [1] [2] [3] [4] [5] [6] [7]

Recurrent/latent-chaining infrastructure:
- Added a new module `recurrent_block_variants.py` implementing the latent-chaining recurrent block loss and a configuration class, enabling flexible experimentation with recurrent training strategies. (`recurrent_variations/recurrent_block_variants.py`)

CLI and argument group improvements:
- Grouped MeZO and recurrent-specific options into dedicated argument groups. (`train_args.py`) [1] [2]

These changes collectively enable more complex, reproducible experiment pipelines, more flexible checkpointing, and new research directions with recurrent block variants.