@Luodian (Collaborator) commented on Jan 10, 2026

Summary

  • Adds per-rank optimizer state saving/loading to the checkpoint utilities
  • Fixes loss spikes and training instability when resuming from checkpoints

Problem

The current checkpoint system saves:

  • backbone.pt - model weights
  • scheduler.pt - LR scheduler state
  • PFC module weights per rank

But it does NOT save optimizer.state_dict(), which contains (see the snippet after this list):

  • exp_avg (1st moment / momentum)
  • exp_avg_sq (2nd moment / RMS accumulator)
  • step count for bias correction
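
For reference, this is roughly what AdamW tracks per parameter and what is lost on a cold restart. A minimal standalone snippet (not code from this PR) that inspects the state dict:

```python
import torch

# Illustrative only: build a tiny model, take one AdamW step, and look at
# what the optimizer state dict actually contains.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

model(torch.randn(4, 8)).sum().backward()
optimizer.step()

state = optimizer.state_dict()
# state["param_groups"] holds hyperparameters (lr, betas, weight_decay, ...).
# state["state"][i] holds the per-parameter buffers for parameter i.
print(state["state"][0].keys())  # dict_keys(['step', 'exp_avg', 'exp_avg_sq'])
```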

When resuming training without optimizer state, AdamW starts fresh, causing:

  • Transient loss spikes (especially severe if resuming mid-training with decayed LR)
  • Training instability lasting roughly 1k-10k steps, since with beta2=0.999 the second-moment average needs on the order of 1/(1 - beta2) ≈ 1,000 steps to re-accumulate
  • Wasted compute to recover the optimization trajectory

Solution

  • Add optimizer parameter to save_checkpoint() and load_checkpoint() (sketched after this list)
  • Save optimizer state per-rank since it includes moments for local PFC shards
  • Backward compatible: old checkpoints without optimizer_*.pt will warn and continue
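
A minimal sketch of the intended shape of the change, assuming the signatures and directory layout below. The filename pattern optimizer_{rank:03d}.pt and the backward-compatibility warning come from this PR; the parameter lists are illustrative and the real code in training/checkpoint_utils.py may differ (per-rank PFC weight handling is omitted):

```python
import os
import torch
import torch.distributed as dist

def save_checkpoint(ckpt_dir, backbone, scheduler, optimizer=None):
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        torch.save(backbone.state_dict(), os.path.join(ckpt_dir, "backbone.pt"))
        torch.save(scheduler.state_dict(), os.path.join(ckpt_dir, "scheduler.pt"))
    if optimizer is not None:
        # Each rank owns the AdamW moments for its local PFC shard,
        # so optimizer state is written per rank.
        torch.save(optimizer.state_dict(),
                   os.path.join(ckpt_dir, f"optimizer_{rank:03d}.pt"))

def load_checkpoint(ckpt_dir, backbone, scheduler, optimizer=None):
    rank = dist.get_rank() if dist.is_initialized() else 0
    backbone.load_state_dict(torch.load(os.path.join(ckpt_dir, "backbone.pt")))
    scheduler.load_state_dict(torch.load(os.path.join(ckpt_dir, "scheduler.pt")))
    if optimizer is not None:
        opt_path = os.path.join(ckpt_dir, f"optimizer_{rank:03d}.pt")
        if os.path.exists(opt_path):
            optimizer.load_state_dict(torch.load(opt_path))
        else:
            # Backward compatibility: old checkpoints have no per-rank
            # optimizer files; warn and let AdamW start fresh.
            print(f"[rank {rank}] {opt_path} not found; optimizer state not restored")
```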

Files Changed

  • training/checkpoint_utils.py - Core save/load logic
  • training/train.py - Pass optimizer to checkpoint functions

Testing

  • Existing training should work unchanged (optimizer param defaults to None)
  • New checkpoints will include optimizer_{rank:03d}.pt files
  • Resuming from new checkpoints restores the full optimizer state (see the round-trip check below)
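
A standalone round-trip check (not part of this PR's test suite; the temp-file path is illustrative) showing that saving and reloading optimizer.state_dict() reproduces the AdamW moments:

```python
import os
import tempfile

import torch

# Take one AdamW step, save the state dict, load it into a fresh optimizer,
# and confirm the first-moment buffer survives the round trip.
model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 4)).sum().backward()
opt.step()

path = os.path.join(tempfile.gettempdir(), "optimizer_000.pt")
torch.save(opt.state_dict(), path)

fresh_opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
fresh_opt.load_state_dict(torch.load(path))

assert torch.equal(opt.state_dict()["state"][0]["exp_avg"],
                   fresh_opt.state_dict()["state"][0]["exp_avg"])
print("optimizer state restored")
```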
