Conversation

klei22 (Collaborator) commented Feb 8, 2026

This pull request introduces support for sequential multi-stage experiment runs, improves flexibility for checkpoint naming and management, and adds infrastructure for recurrent/latent-chaining training variants. Key changes include enhancements to experiment configuration and orchestration, new argument options for checkpoint handling, and a new module for recurrent block variants.

Experiment orchestration and configuration:

  • Added support for sequential_runs in experiment YAML files, allowing experiments to be defined as a sequence of training stages (e.g., base training, fine-tuning, recurrent, and MeZO stages) with explicit checkpoint handoff between stages. (demos/sequential_run_experiments_demo.yaml, optimization_and_search/run_experiments.py)

  • Improved the experiment runner to handle per-stage script selection, argument passing, checkpoint chaining, and output directory management for sequential runs; a hedged sketch of the chaining appears after this list. (optimization_and_search/run_experiments.py)
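
The snippet below is a rough sketch of how such a stage list could be turned into chained commands. Only the output-checkpoint flag names come from this PR; the stage keys, the --input_ckpt flag, and the helper itself are illustrative assumptions, not the exact schema used by run_experiments.py or the demo YAML.

# Map each stage script to its checkpoint-output flag (flag names are from this PR).
OUTPUT_FLAG = {
    "train.py": "--output_ckpt",
    "train_mezo.py": "--mezo_output_ckpt",
    "train_recurrent.py": "--recurrent_output_ckpt",
}

# Hypothetical stage list mirroring the demo's intent: base training, then a
# recurrent stage, then a MeZO stage, each handing its checkpoint to the next.
stages = [
    {"script": "train.py", "out_ckpt": "base.pt"},
    {"script": "train_recurrent.py", "out_ckpt": "recurrent.pt"},
    {"script": "train_mezo.py", "out_ckpt": "mezo.pt"},
]

def build_stage_commands(stages, out_dir):
    """Sketch of checkpoint chaining: each command consumes the previous stage's checkpoint."""
    commands, prev_ckpt = [], None
    for stage in stages:
        cmd = ["python3", stage["script"], "--out_dir", out_dir,
               OUTPUT_FLAG[stage["script"]], stage["out_ckpt"]]
        if prev_ckpt is not None:
            # "--input_ckpt" is borrowed from the review discussion below; the real flag may differ.
            cmd += ["--input_ckpt", prev_ckpt]
        commands.append(cmd)
        prev_ckpt = stage["out_ckpt"]
    return commands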

Checkpoint and argument handling:

  • Added new CLI arguments for specifying output checkpoint filenames for each training script (output_ckpt, mezo_output_ckpt, recurrent_output_ckpt), and updated the training scripts to use these names when saving checkpoints, as sketched below. (train_args.py, train.py, train_mezo.py)
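
For illustration, a minimal argparse sketch of the new flags and how a training script can derive its save path from them. Only the flag names come from this PR; the defaults and the --out_dir flag name are assumptions.

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--out_dir", type=str, default="out")  # output directory flag name assumed
parser.add_argument("--output_ckpt", type=str, default="ckpt.pt")  # train.py; default assumed
parser.add_argument("--mezo_output_ckpt", type=str, default="mezo_ckpt.pt")  # train_mezo.py; default assumed
parser.add_argument("--recurrent_output_ckpt", type=str, default="recurrent_ckpt.pt")  # train_recurrent.py; default assumed
args = parser.parse_args([])

# Instead of a hard-coded "ckpt.pt", train.py can now save to the configured name.
best_ckpt_path = os.path.join(args.out_dir, args.output_ckpt)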

Recurrent/latent-chaining infrastructure:

  • Added a new module recurrent_block_variants.py implementing the latent chaining recurrent block loss and a configuration class, enabling flexible experimentation with recurrent training strategies. (recurrent_variations/recurrent_block_variants.py)
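
The module's exact API isn't shown in this PR description, but the general latent-chaining idea can be sketched as follows; the class and function names, the config fields, and the per-step weighting are illustrative assumptions, not the module's actual implementation.

from dataclasses import dataclass

import torch

@dataclass
class LatentChainConfig:
    # Illustrative configuration fields only; the real config class may differ.
    num_latent_steps: int = 4
    weight_start: float = 0.5
    weight_end: float = 1.0

def latent_chaining_loss(block, hidden, targets, loss_fn, cfg: LatentChainConfig):
    """Feed the block's latent output back into itself and combine per-step losses.

    Hedged sketch of the general technique; the module's actual loss may weight,
    detach, or aggregate the steps differently.
    """
    losses = []
    state = hidden
    for _ in range(cfg.num_latent_steps):
        state = block(state)                  # re-enter the block with its own latent output
        losses.append(loss_fn(state, targets))
    # Linearly ramp per-step weights from weight_start to weight_end, then normalize.
    weights = torch.linspace(cfg.weight_start, cfg.weight_end, len(losses))
    weights = weights / weights.sum()
    return sum(w * l for w, l in zip(weights, losses))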

CLI and argument group improvements:

  • Grouped new recurrent and MeZO options under dedicated argument groups in the CLI for better organization and discoverability. (train_args.py)
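
A minimal sketch of the grouping, assuming standard argparse argument groups; the group titles and defaults are guesses.

import argparse

parser = argparse.ArgumentParser()
mezo_group = parser.add_argument_group("mezo")            # group title assumed
mezo_group.add_argument("--mezo_output_ckpt", type=str, default="mezo_ckpt.pt")
recurrent_group = parser.add_argument_group("recurrent")  # group title assumed
recurrent_group.add_argument("--recurrent_output_ckpt", type=str, default="recurrent_ckpt.pt")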

These changes collectively enable more complex, reproducible experiment pipelines, more flexible checkpointing, and new research directions with recurrent block variants.

Copilot AI left a comment

Pull request overview

This PR adds support for sequential, multi-stage experiment runs (checkpoint-chained training stages) and extends checkpoint naming/configuration to better support recurrent/latent-chaining and MeZO training variants.

Changes:

  • Added sequential_runs support to the experiment runner to orchestrate multi-stage scripts with checkpoint handoff.
  • Introduced configurable checkpoint output filenames (output_ckpt, mezo_output_ckpt, recurrent_output_ckpt) and wired them into training scripts.
  • Modularized latent-chaining recurrent loss into a new recurrent_block_variants module and refactored train_recurrent.py accordingly.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.

Summary per file:

  • optimization_and_search/run_experiments.py: Adds sequential multi-stage orchestration (sequential_runs) and script selection/checkpoint chaining.
  • demos/sequential_run_experiments_demo.yaml: Demonstrates a multi-stage sequential run pipeline (train → finetune → recurrent → MeZO).
  • train_args.py: Adds new CLI args for checkpoint output filenames and groups MeZO/recurrent options.
  • train.py: Uses --output_ckpt for checkpoint saving instead of hard-coded ckpt.pt.
  • train_mezo.py: Saves checkpoints using --mezo_output_ckpt.
  • train_recurrent.py: Refactors recurrent training into modular components and adds recurrent checkpoint filename + variant selection.
  • recurrent_variations/recurrent_block_variants.py: New module implementing latent-chaining recurrent block loss + config.

Comment on lines +334 to +342
if args.always_save_checkpoint:
    save_checkpoint(
        model=model,
        ckpt_model_args=ckpt_model_args,
        ckpt_path=best_ckpt_path,
        best_val_loss=state.best_val_loss,
        global_step=state.global_step,
        tag="latest",
    )

Copilot AI Feb 8, 2026

When always_save_checkpoint is enabled, save_checkpoint() is called with ckpt_path=best_ckpt_path, which overwrites the best checkpoint file with a “latest” checkpoint. This makes it impossible to reliably keep the best checkpoint. Use a separate filename/path for the “latest” checkpoint (or only write “latest” when it’s also best).
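
One way to act on this, keeping the surrounding names from the snippet above; latest_ckpt_path and its filename scheme are hypothetical, and args.out_dir / args.output_ckpt are assumed to exist in this script.

if args.always_save_checkpoint:
    # Write the rolling "latest" checkpoint to its own file so the best checkpoint
    # at best_ckpt_path is never overwritten. The filename scheme is an assumption.
    latest_ckpt_path = os.path.join(args.out_dir, "latest_" + args.output_ckpt)
    save_checkpoint(
        model=model,
        ckpt_model_args=ckpt_model_args,
        ckpt_path=latest_ckpt_path,
        best_val_loss=state.best_val_loss,
        global_step=state.global_step,
        tag="latest",
    )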

ptr = 0
data_view = data
total_tokens = len(data_view) - 1
save_enabled = args.always_save_checkpoint or not args.never_save_checkpoint

Copilot AI Feb 8, 2026

As written, save_enabled ignores never_save_checkpoint whenever always_save_checkpoint is true, so checkpoints can still be written even when the user asked for no saving. In train.py, never_save_checkpoint always prevents saving regardless of always_save_checkpoint; aligning this behavior avoids surprising writes in “no-save” mode.

Suggested change
save_enabled = args.always_save_checkpoint or not args.never_save_checkpoint
save_enabled = not args.never_save_checkpoint

parser.add_argument("--weight_end", type=float, default=1.0)
parser.add_argument(
"--reset_optim",
action="store_true",

Copilot AI Feb 8, 2026

--reset_optim is defined as store_true here, so the command builder in run_experiments.py can emit --no-reset_optim (for false) and cause an unknown-arg failure. Since the rest of the codebase uses BooleanOptionalAction for these flags, consider switching --reset_optim to argparse.BooleanOptionalAction with an explicit default to keep CLI behavior consistent and compatible with sequential runs.

Suggested change
action="store_true",
action=argparse.BooleanOptionalAction,
default=False,

Comment on lines +638 to +642
def _normalize_sequential_runs(sequential_runs: object) -> list[dict]:
    if sequential_runs is None:
        return []
    if isinstance(sequential_runs, dict):
        return [sequential_runs]

Copilot AI Feb 8, 2026

sequential_runs adds new multi-stage execution behavior, but there is no automated coverage for it. The existing tests/test_run_experiments.sh covers other config features; adding a small YAML under tests/run_optimization_tests/ that exercises sequential_runs (resume + input_ckpt chaining) would help prevent regressions.
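
Alongside the YAML-level test, the two behaviors visible in the snippet above could be pinned down with a couple of pytest-style checks; the import path is inferred from the file location and may need adjusting.

from optimization_and_search.run_experiments import _normalize_sequential_runs

def test_none_becomes_empty_list():
    assert _normalize_sequential_runs(None) == []

def test_single_stage_dict_is_wrapped():
    stage = {"script": "train.py", "output_ckpt": "base.pt"}
    assert _normalize_sequential_runs(stage) == [stage]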

Comment on lines 793 to 797
# Read metrics (use existing or nan on failure)
try:
    metrics = read_metrics(str(combo['out_dir']))
except Exception:
    metrics = {k: float("nan") for k in METRIC_KEYS}

Copilot AI Feb 8, 2026

run_experiment() still derives metrics only from best_val_loss_and_iter.txt (written by train.py). With sequential runs, stages like train_mezo.py / train_recurrent.py don’t write this file, so successful runs can be logged with NaN metrics. Consider standardizing metric output across scripts, emitting a clear warning, or selecting metrics from an earlier stage that produces them.
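
As a stopgap along the lines of the warning option, the existing fallback could at least say why NaNs are being logged; the warnings usage and message below are a suggestion, not the repository's current behavior.

import warnings

# Read metrics (use existing or nan on failure), but surface the fallback.
try:
    metrics = read_metrics(str(combo['out_dir']))
except Exception as exc:
    warnings.warn(
        f"Could not read metrics from {combo['out_dir']} ({exc}); logging NaNs. "
        "Stages such as train_mezo.py do not write best_val_loss_and_iter.txt."
    )
    metrics = {k: float("nan") for k in METRIC_KEYS}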

Comment on lines +386 to +388
best_val_loss = ckpt["best_val_loss"].item()
print("best_val_loss", best_val_loss)
best_val_loss = 5.00 # TODO: allow configurable start threshold

Copilot AI Feb 8, 2026

best_val_loss is loaded from the resume checkpoint but then unconditionally overwritten with 5.00. This breaks true “best so far” tracking (a new best checkpoint may never be saved if validation loss stays above that value, or may be saved too eagerly if the checkpoint's real best was below it) and makes resumed recurrent runs non-reproducible. Prefer initializing best_val_loss from the checkpoint (or inf), and if a custom starting threshold is desired, expose it as an explicit CLI flag instead of hard-coding it.

Suggested change
best_val_loss = ckpt["best_val_loss"].item()
print("best_val_loss", best_val_loss)
best_val_loss = 5.00 # TODO: allow configurable start threshold
# Initialize best_val_loss from checkpoint if available; otherwise use +inf
best_val_loss = float("inf")
if "best_val_loss" in ckpt and ckpt["best_val_loss"] is not None:
best_val_loss = ckpt["best_val_loss"].item()
print("best_val_loss", best_val_loss)
