From 498e323aa3bb5af0b655aa2d3b86804d489d5a0c Mon Sep 17 00:00:00 2001
From: Benjamin Bennett <125957+benbennett@users.noreply.github.com>
Date: Wed, 24 Sep 2025 19:49:59 -0500
Subject: [PATCH] Add training state checkpointing

---
 atme.py              | 40 +++++++++++---------
 models/base_model.py | 87 +++++++++++++++++++++++++++++++++++---------
 simple.py            | 41 +++++++++++----------
 3 files changed, 113 insertions(+), 55 deletions(-)

diff --git a/atme.py b/atme.py
index 5188374..fbc92b2 100644
--- a/atme.py
+++ b/atme.py
@@ -48,13 +48,15 @@ def train(opt):
     dataset_size = len(dataset)
     print('The number of training images = %d' % dataset_size)
 
-    model = create_model(opt, dataset)
-    model.setup(opt)
-    visualizer = Visualizer(opt)
-    total_iters = 0
-
-
-    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
+    model = create_model(opt, dataset)
+    model.setup(opt)
+    visualizer = Visualizer(opt)
+    start_epoch = getattr(model, 'start_epoch', opt.epoch_count)
+    total_iters = getattr(model, 'start_iter', 0)
+    opt.epoch_count = start_epoch
+
+
+    for epoch in range(start_epoch, opt.n_epochs + opt.n_epochs_decay + 1):
         epoch_start_time = time.time()
         iter_data_time = time.time()
         epoch_iter = 0
@@ -85,19 +87,21 @@ def train(opt):
             # if opt.display_id > 0:
             #     visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
 
-            if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
-                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
-                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
-                model.save_networks(save_suffix)
-                visuals = model.get_current_visuals()
-                slice_num = i if opt.batch_size == 1 else random.randint(0, opt.batch_size)
-                save_atme_images(visuals, save_fig_dir, slice_num, iter_num=total_iters, epoch=epoch)
+            if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
+                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
+                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
+                model.save_networks(save_suffix)
+                visuals = model.get_current_visuals()
+                slice_num = i if opt.batch_size == 1 else random.randint(0, opt.batch_size)
+                save_atme_images(visuals, save_fig_dir, slice_num, iter_num=total_iters, epoch=epoch)
+                model.save_training_state(epoch, total_iters)
 
             iter_data_time = time.time()
-        if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
-            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
-            model.save_networks('latest')
-            model.save_networks(epoch)
+        if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
+            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
+            model.save_networks('latest')
+            model.save_networks(epoch)
+            model.save_training_state(epoch + 1, total_iters)
 
         # Save D_real and D_fake
         visualizer.save_D_losses(model.get_current_losses())
diff --git a/models/base_model.py b/models/base_model.py
index d84ded8..ab8478a 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -38,10 +38,12 @@ def __init__(self, opt):
             torch.backends.cudnn.benchmark = True
         self.loss_names = []
         self.model_names = []
-        self.visual_names = []
-        self.optimizers = []
-        self.image_paths = []
-        self.metric = 0  # used for learning rate policy 'plateau'
+        self.visual_names = []
+        self.optimizers = []
+        self.image_paths = []
+        self.metric = 0  # used for learning rate policy 'plateau'
+        self.start_epoch = opt.epoch_count
+        self.start_iter = 0
 
     @abstractmethod
     def set_input(self, input):
@@ -68,12 +70,14 @@ def setup(self, opt):
 
         Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
-        if self.isTrain:
-            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
-        self.print_networks(opt.verbose)
-        if not self.isTrain or opt.continue_train:
-            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
-            self.load_networks(load_suffix, pre_train_G_path=opt.pre_train_G_path)
+        if self.isTrain:
+            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+        self.print_networks(opt.verbose)
+        if not self.isTrain or opt.continue_train:
+            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
+            self.load_networks(load_suffix, pre_train_G_path=opt.pre_train_G_path)
+        if self.isTrain and opt.continue_train:
+            self.load_training_state()
 
     def eval(self):
@@ -159,14 +163,61 @@ def save_specific_networks(self, networks_names, epoch):
                 save_path = os.path.join(self.save_dir, save_filename)
                 net = getattr(self, 'net' + name)
 
-                if len(self.gpu_ids) > 1 and torch.cuda.is_available():
-                    torch.save(net.module.cpu().state_dict(), save_path)
-                    net.cuda(self.gpu_ids[0])
-                else:
-                    torch.save(net.state_dict(), save_path)
-
-    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
-        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+                if len(self.gpu_ids) > 1 and torch.cuda.is_available():
+                    torch.save(net.module.cpu().state_dict(), save_path)
+                    net.cuda(self.gpu_ids[0])
+                else:
+                    torch.save(net.state_dict(), save_path)
+
+    def save_training_state(self, epoch, total_iters):
+        """Save optimizer and scheduler states for resuming training later."""
+        if not self.isTrain:
+            return
+
+        state = {
+            'epoch': int(epoch),
+            'total_iters': int(total_iters),
+            'optimizer_states': [optimizer.state_dict() for optimizer in self.optimizers],
+        }
+
+        if hasattr(self, 'schedulers'):
+            state['scheduler_states'] = [scheduler.state_dict() for scheduler in getattr(self, 'schedulers', [])]
+
+        os.makedirs(self.save_dir, exist_ok=True)
+        save_path = os.path.join(self.save_dir, 'training_state.pth')
+        torch.save(state, save_path)
+
+    def load_training_state(self):
+        """Load optimizer and scheduler states to resume training."""
+        load_path = os.path.join(self.save_dir, 'training_state.pth')
+        if not os.path.isfile(load_path):
+            return
+
+        print('loading training state from %s' % load_path)
+        state = torch.load(load_path, map_location='cpu')
+
+        optimizer_states = state.get('optimizer_states', [])
+        if len(optimizer_states) != len(self.optimizers):
+            print('Warning: number of optimizers does not match when loading training state.')
+        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
+            optimizer.load_state_dict(opt_state)
+            for opt_state_value in optimizer.state.values():
+                for k, v in opt_state_value.items():
+                    if isinstance(v, torch.Tensor):
+                        opt_state_value[k] = v.to(self.device)
+
+        if hasattr(self, 'schedulers'):
+            scheduler_states = state.get('scheduler_states', [])
+            if len(scheduler_states) != len(getattr(self, 'schedulers', [])):
+                print('Warning: number of schedulers does not match when loading training state.')
+            for scheduler, scheduler_state in zip(getattr(self, 'schedulers', []), scheduler_states):
+                scheduler.load_state_dict(scheduler_state)
+
+        self.start_epoch = state.get('epoch', self.start_epoch)
+        self.start_iter = state.get('total_iters', self.start_iter)
+
+    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
         key = keys[i]
         if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
             if module.__class__.__name__.startswith('InstanceNorm') and \
diff --git a/simple.py b/simple.py
index 643b3ba..2d3d323 100644
--- a/simple.py
+++ b/simple.py
@@ -23,21 +23,23 @@ def train(opt):
     opt.save_dir = os.path.join(opt.main_root, opt.model_root, opt.exp_name)
     mkdir(opt.save_dir)
 
-    model = create_model(opt)
-    model.setup(opt)
-    visualizer = Visualizer(opt)
-
-    train_loader = create_simple_train_dataset(opt)
-    print('prepare data_loader done')
-
-    total_iters = 0
-
-    figures_path = os.path.join(opt.save_dir, 'figures', 'train')
-    mkdir(figures_path)
-
-    slice_index = int(opt.patch_size / 2)
-
-    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
+    model = create_model(opt)
+    model.setup(opt)
+    visualizer = Visualizer(opt)
+
+    train_loader = create_simple_train_dataset(opt)
+    print('prepare data_loader done')
+
+    start_epoch = getattr(model, 'start_epoch', opt.epoch_count)
+    total_iters = getattr(model, 'start_iter', 0)
+    opt.epoch_count = start_epoch
+
+    figures_path = os.path.join(opt.save_dir, 'figures', 'train')
+    mkdir(figures_path)
+
+    slice_index = int(opt.patch_size / 2)
+
+    for epoch in range(start_epoch, opt.n_epochs + opt.n_epochs_decay + 1):
         epoch_start_time = time.time()
         iter_data_time = time.time()
         epoch_iter = 0
@@ -64,10 +66,11 @@ def train(opt):
                 visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
 
             iter_data_time = time.time()
-        if epoch % opt.save_epoch_freq == 0:
-            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
-            model.save_networks('latest')
-            model.save_networks(epoch)
+        if epoch % opt.save_epoch_freq == 0:
+            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
+            model.save_networks('latest')
+            model.save_networks(epoch)
+            model.save_training_state(epoch + 1, total_iters)
 
         losses = model.get_current_losses()
         visualizer.save_to_tensorboard_writer(epoch, losses)
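
Note on resuming: with this patch, restarting atme.py or simple.py with the option that sets opt.continue_train (typically --continue_train) makes BaseModel.setup() call load_training_state(), and the training loops then start from model.start_epoch / model.start_iter. Below is a minimal sketch for inspecting the saved training_state.pth; the checkpoint path and the helper itself are illustrative only and not defined by this patch, substitute your own opt.save_dir.

    # inspect_training_state.py: hypothetical helper, path below is an example
    import torch

    state = torch.load('path/to/save_dir/training_state.pth', map_location='cpu')
    print('resume epoch :', state['epoch'])                          # written by save_training_state()
    print('total iters  :', state['total_iters'])
    print('optimizers   :', len(state['optimizer_states']))          # one state dict per optimizer in model.optimizers
    print('schedulers   :', len(state.get('scheduler_states', [])))  # present only when schedulers were created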