diff --git a/atme.py b/atme.py
index 5188374..89cc5bf 100644
--- a/atme.py
+++ b/atme.py
@@ -56,14 +56,13 @@ def train(opt):
     for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
         epoch_start_time = time.time()
-        iter_data_time = time.time()
-        epoch_iter = 0
-        visualizer.reset()
-        model.update_learning_rate()
-        for i, data in enumerate(dataset):
-            iter_start_time = time.time()
-            if total_iters % opt.print_freq == 0:
-                t_data = iter_start_time - iter_data_time
+        iter_data_time = time.time()
+        epoch_iter = 0
+        visualizer.reset()
+        for i, data in enumerate(dataset):
+            iter_start_time = time.time()
+            if total_iters % opt.print_freq == 0:
+                t_data = iter_start_time - iter_data_time
 
             total_iters += opt.batch_size
             epoch_iter += opt.batch_size
@@ -105,7 +104,8 @@ def train(opt):
             losses = model.get_current_losses()
             visualizer.save_to_tensorboard_writer(epoch, losses)
 
-        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
+        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
+        model.update_learning_rate(epoch)
 
 def test(opt):
     opt.isTrain = False
diff --git a/models/base_model.py b/models/base_model.py
index d84ded8..a95c92d 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -101,17 +101,32 @@ def get_image_paths(self):
         """ Return image paths that are used to load current data"""
         return self.image_paths
 
-    def update_learning_rate(self):
-        """Update learning rates for all the networks; called at the end of every epoch"""
-        old_lr = self.optimizers[0].param_groups[0]['lr']
-        for scheduler in self.schedulers:
-            if self.opt.lr_policy == 'plateau':
-                scheduler.step(self.metric)
-            else:
-                scheduler.step()
-
-        lr = self.optimizers[0].param_groups[0]['lr']
-        print('learning rate %.7f -> %.7f' % (old_lr, lr))
+    def update_learning_rate(self, epoch=None):
+        """Update learning rates for all the networks; called at the end of every epoch"""
+        old_lr = self.optimizers[0].param_groups[0]['lr']
+        if epoch is not None:
+            # match the behaviour of stepping at the *start* of an epoch by
+            # offsetting the scheduler's epoch index. This keeps learning rate
+            # milestones identical to the previous implementation while
+            # allowing the update to happen after ``optimizer.step()`` has been
+            # called at least once.
+            scheduler_epoch = max(0, epoch - self.opt.epoch_count + 1)
+        else:
+            scheduler_epoch = None
+        for scheduler in self.schedulers:
+            if self.opt.lr_policy == 'plateau':
+                if scheduler_epoch is None:
+                    scheduler.step(self.metric)
+                else:
+                    scheduler.step(self.metric, epoch=scheduler_epoch)
+            else:
+                if scheduler_epoch is None:
+                    scheduler.step()
+                else:
+                    scheduler.step(scheduler_epoch)
+
+        lr = self.optimizers[0].param_groups[0]['lr']
+        print('learning rate %.7f -> %.7f' % (old_lr, lr))
 
     def get_current_visuals(self):
         """Return visualization images. simple.py will display these images with visdom, and save the images to a HTML"""
diff --git a/simple.py b/simple.py
index 643b3ba..d8992a1 100644
--- a/simple.py
+++ b/simple.py
@@ -42,12 +42,10 @@ def train(opt):
         iter_data_time = time.time()
         epoch_iter = 0
         visualizer.reset()
-        if epoch > 1:
-            model.update_learning_rate()
-        for i, data in enumerate(train_loader):
-            iter_start_time = time.time()
-            if total_iters % opt.print_freq == 0:
-                t_data = iter_start_time - iter_data_time
+        for i, data in enumerate(train_loader):
+            iter_start_time = time.time()
+            if total_iters % opt.print_freq == 0:
+                t_data = iter_start_time - iter_data_time
 
             total_iters += opt.batch_size
             epoch_iter += opt.batch_size
@@ -72,7 +70,8 @@ def train(opt):
             losses = model.get_current_losses()
             visualizer.save_to_tensorboard_writer(epoch, losses)
 
-        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
+        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
+        model.update_learning_rate(epoch)
 
 def test(opt):
     opt.isTrain = False
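
Below is a minimal standalone sketch (not part of the patch) of the scheduler mechanics that the new `update_learning_rate(epoch)` relies on: passing an explicit epoch index to `scheduler.step()` pins the scheduler's internal counter to that value, so stepping at the end of epoch `e` with `max(0, e - epoch_count + 1)` walks the counter through the same sequence that the old start-of-epoch `step()` calls produced, without triggering PyTorch's "step before optimizer.step()" warning. The `StepLR` policy and the toy hyper-parameters are assumptions for illustration only; the explicit `epoch` argument to `step()` is deprecated in recent PyTorch but still accepted.

import torch

# Toy model and optimizer; StepLR is an assumed stand-in for opt.lr_policy.
net = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

epoch_count = 1  # stand-in for opt.epoch_count in the patch
for epoch in range(epoch_count, 6):
    optimizer.step()  # placeholder for one epoch of training
    # End-of-epoch update with an explicit (deprecated but accepted) epoch
    # index, mirroring BaseModel.update_learning_rate(epoch) above.
    scheduler_epoch = max(0, epoch - epoch_count + 1)
    scheduler.step(scheduler_epoch)
    print('epoch %d -> lr %.4f' % (epoch, optimizer.param_groups[0]['lr']))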