34 changes: 31 additions & 3 deletions training/checkpoint_utils.py
@@ -45,7 +45,8 @@ def save_checkpoint(
amp,
global_step,
list_head_names,
keep_num=5):
keep_num=5,
optimizer=None):
"""
Save training state, create a separate folder for each step, and subfolders for each PFC head

@@ -58,12 +59,13 @@ def save_checkpoint(
global_step: Current global step count
list_head_names: List of head names
keep_num: Number of checkpoints to keep
optimizer: Optimizer object (AdamW) - CRITICAL for proper resume
"""
# Create folder for current step
step_dir = os.path.join(output_dir, f"{global_step:08d}")
os.makedirs(step_dir, exist_ok=True)

# Save backbone model and optimizer state (only on rank 0)
# Save backbone model and scheduler state (only on rank 0)
if rank == 0:
# Save backbone model (move to CPU)
backbone_path = os.path.join(step_dir, "backbone.pt")
@@ -81,6 +83,21 @@ def save_checkpoint(

logging.info(f"Backbone, scheduler saved at step {global_step}")

# Save optimizer state per-rank (CRITICAL for proper resume)
# Optimizer contains per-rank PFC parameter moments, so each rank saves its own
if optimizer is not None:
optimizer_path = os.path.join(step_dir, f"optimizer_{rank:03d}.pt")
# Move optimizer state tensors to CPU for saving
opt_state_dict = optimizer.state_dict()
cpu_state_dict = {
'state': {k: {sk: sv.cpu() if isinstance(sv, torch.Tensor) else sv
for sk, sv in v.items()}
for k, v in opt_state_dict['state'].items()},
'param_groups': opt_state_dict['param_groups']
}
torch.save(cpu_state_dict, optimizer_path)
logging.info(f"Rank {rank}: Optimizer state saved at step {global_step}")

if isinstance(pfc_modules, list):
# Each rank saves its own PFC module
for head_id, (head_name, pfc) in enumerate(zip(list_head_names, pfc_modules)):
@@ -163,7 +180,7 @@ def clean_old_checkpoints(output_dir, keep_num=5):


def load_checkpoint(output_dir, step, backbone, pfc_modules, lr_scheduler,
amp, list_head_names):
amp, list_head_names, optimizer=None):
"""
Load training state from checkpoint folder at specified step

@@ -175,6 +192,7 @@ def load_checkpoint(output_dir, step, backbone, pfc_modules, lr_scheduler,
lr_scheduler: Learning rate scheduler
amp: Automatic mixed precision object
list_head_names: List of head names
optimizer: Optimizer object (AdamW) - CRITICAL for proper resume

Returns:
dict: Contains restored global step information
@@ -282,6 +300,16 @@ def load_checkpoint(output_dir, step, backbone, pfc_modules, lr_scheduler,
else:
logging.warning(f"AMP state file not found: {amp_file}")

# Load optimizer state per-rank (CRITICAL for proper resume)
if optimizer is not None:
optimizer_file = os.path.join(step_dir, f"optimizer_{rank:03d}.pt")
if os.path.exists(optimizer_file):
optimizer.load_state_dict(torch.load(optimizer_file))
logging.info(f"Rank {rank}: Loaded optimizer state from step {step}")
else:
logging.warning(f"Rank {rank}: Optimizer state file not found: {optimizer_file}")
logging.warning("Training will resume with fresh optimizer moments - expect temporary loss spike")

return {
'global_step': step
}
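For reference, the resume path these changes enable can be sketched as follows. This is not code from the PR: the optimizer must be constructed before load_checkpoint so its per-rank AdamW moments can be restored in place, and params, resume_step, and the placeholder hyperparameters below are assumptions for illustration only.

import torch

# Sketch of a resume flow using the updated helper (not part of this PR).
# backbone, pfc_modules, lr_scheduler, resume_step, and args are assumed to
# carry the same meaning as in train.py; hyperparameters are placeholders.
params = list(backbone.parameters()) + [p for m in pfc_modules for p in m.parameters()]
opt = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.05)

state = load_checkpoint(
    args.output,
    resume_step,
    backbone,
    pfc_modules,
    lr_scheduler,
    None,                  # amp, as in the existing call site
    args.list_head_names,
    optimizer=opt,         # restores per-rank exp_avg / exp_avg_sq moments
)
global_step = state["global_step"]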
3 changes: 3 additions & 0 deletions training/train.py
@@ -362,6 +362,7 @@ def _expand(name, v):
lr_scheduler,
None,
args.list_head_names,
optimizer=opt, # Pass optimizer so its AdamW moments are restored on resume
)
if result is not None:
global_step = result["global_step"]
@@ -691,6 +692,7 @@ def wrap_ddp(model):
global_step=global_step,
list_head_names=args.list_head_names,
keep_num=20,
optimizer=opt, # Save optimizer state for proper resume
)
# Also save in HuggingFace format
save_hf_checkpoint(args.output, backbone, global_step=global_step, image_size=args.image_size[0])
@@ -705,6 +707,7 @@ def wrap_ddp(model):
global_step=global_step,
list_head_names=args.list_head_names,
keep_num=20,
optimizer=opt, # Save optimizer state for proper resume
)
# Also save final model in HuggingFace format
save_hf_checkpoint(args.output, backbone, global_step=global_step, image_size=args.image_size[0])
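A quick per-rank sanity check (again a sketch, not part of this PR) can confirm that the saved optimizer file round-trips the AdamW moments; step_dir, rank, and opt are assumed to carry the same meaning as in the diff above.

import os
import torch

# Hypothetical check: reload this rank's optimizer file and compare its AdamW
# moments against the live optimizer state.
opt_path = os.path.join(step_dir, f"optimizer_{rank:03d}.pt")
reloaded = torch.load(opt_path, map_location="cpu")
for pid, slots in opt.state_dict()["state"].items():
    for key in ("exp_avg", "exp_avg_sq"):
        assert torch.allclose(slots[key].cpu(), reloaded["state"][pid][key])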