models/base_model.py: 28 changed lines (19 additions, 9 deletions)
@@ -103,15 +103,25 @@ def get_image_paths(self):

     def update_learning_rate(self):
         """Update learning rates for all the networks; called at the end of every epoch"""
-        old_lr = self.optimizers[0].param_groups[0]['lr']
-        for scheduler in self.schedulers:
-            if self.opt.lr_policy == 'plateau':
-                scheduler.step(self.metric)
-            else:
-                scheduler.step()
-
-        lr = self.optimizers[0].param_groups[0]['lr']
-        print('learning rate %.7f -> %.7f' % (old_lr, lr))
+        old_lrs = [optimizer.param_groups[0]['lr'] for optimizer in self.optimizers]
+        for scheduler in self.schedulers:
+            if self.opt.lr_policy == 'plateau':
+                scheduler.step(self.metric)
+            else:
+                scheduler.step()
+        if hasattr(self.opt, 'lr_min') and self.opt.lr_min not in (None, 0):
+            for optimizer in self.optimizers:
+                for param_group in optimizer.param_groups:
+                    param_group['lr'] = max(param_group['lr'], self.opt.lr_min)
+
+        new_lrs = [optimizer.param_groups[0]['lr'] for optimizer in self.optimizers]
+        if len(self.optimizers) == 1:
+            print('learning rate %.7f -> %.7f' % (old_lrs[0], new_lrs[0]))
+        else:
+            lr_changes = ', '.join(
+                '%.7f -> %.7f' % (old, new) for old, new in zip(old_lrs, new_lrs)
+            )
+            print(f'learning rates {lr_changes}')
 
     def get_current_visuals(self):
         """Return visualization images. simple.py will display these images with visdom, and save the images to a HTML"""
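To make the intent of the new update_learning_rate concrete, here is a minimal standalone sketch, not taken from the PR, of the same two ideas: log each optimizer's learning rate individually, and clamp the decayed rate to an opt.lr_min floor. The two Linear modules, the StepLR settings, and the SimpleNamespace options are illustrative stand-ins rather than names from the repository.

# Standalone sketch (illustrative only): two optimizers, a StepLR schedule,
# and an assumed options namespace that carries the lr_min floor.
from types import SimpleNamespace

from torch import nn, optim
from torch.optim import lr_scheduler

opt = SimpleNamespace(lr_policy='step', lr=0.01, lr_min=1e-4)  # assumed option values

optimizers = [optim.Adam(nn.Linear(4, 4).parameters(), lr=opt.lr),
              optim.Adam(nn.Linear(4, 4).parameters(), lr=opt.lr)]
schedulers = [lr_scheduler.StepLR(o, step_size=1, gamma=0.1) for o in optimizers]

for epoch in range(5):
    old_lrs = [o.param_groups[0]['lr'] for o in optimizers]
    for optimizer in optimizers:
        optimizer.step()                    # a real epoch of training would happen here
    for scheduler in schedulers:
        scheduler.step()                    # decay each optimizer's learning rate
    if getattr(opt, 'lr_min', 0):           # clamp so the rate never drops below lr_min
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] = max(param_group['lr'], opt.lr_min)
    new_lrs = [o.param_groups[0]['lr'] for o in optimizers]
    print(', '.join('%.7f -> %.7f' % (old, new) for old, new in zip(old_lrs, new_lrs)))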
models/networks.py: 83 changed lines (60 additions, 23 deletions)
@@ -46,30 +46,67 @@ def norm_layer(x):
 def get_scheduler(optimizer, opt):
     """Return a learning rate scheduler
 
-    Parameters:
-        optimizer -- the optimizer of the network
-        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
-            opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
-
-    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
-    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
-    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
-    See https://pytorch.org/docs/stable/optim.html for more details.
+    Parameters:
+        optimizer -- the optimizer of the network
+        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+            opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine | cosine_restart | exponential | poly
+
+    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
+    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
+    For other schedulers, we use the default PyTorch schedulers (step, plateau, cosine,
+    cosine with warm restarts, exponential decay) or the provided polynomial decay rule.
+    See https://pytorch.org/docs/stable/optim.html for more details.
     """
-    if opt.lr_policy == 'linear':
-        def lambda_rule(epoch):
-            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
-            return lr_l
-        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
-    elif opt.lr_policy == 'step':
-        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
-    elif opt.lr_policy == 'plateau':
-        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
-    elif opt.lr_policy == 'cosine':
-        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
-    else:
-        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
-    return scheduler
+    if opt.lr_policy == 'linear':
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
+            min_scale = opt.lr_min / opt.lr if getattr(opt, 'lr_min', 0) and opt.lr > 0 else 0.0
+            return max(lr_l, min_scale)
+
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+    elif opt.lr_policy == 'step':
+        scheduler = lr_scheduler.StepLR(
+            optimizer,
+            step_size=opt.lr_decay_iters,
+            gamma=getattr(opt, 'lr_decay_gamma', 0.1),
+        )
+    elif opt.lr_policy == 'plateau':
+        scheduler = lr_scheduler.ReduceLROnPlateau(
+            optimizer,
+            mode='min',
+            factor=getattr(opt, 'lr_plateau_factor', 0.2),
+            threshold=getattr(opt, 'lr_plateau_threshold', 0.01),
+            patience=getattr(opt, 'lr_plateau_patience', 5),
+        )
+    elif opt.lr_policy == 'cosine':
+        scheduler = lr_scheduler.CosineAnnealingLR(
+            optimizer,
+            T_max=opt.n_epochs,
+            eta_min=getattr(opt, 'lr_min', 0),
+        )
+    elif opt.lr_policy == 'cosine_restart':
+        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer,
+            T_0=getattr(opt, 'lr_restart_period', opt.n_epochs),
+            T_mult=getattr(opt, 'lr_restart_mult', 1),
+            eta_min=getattr(opt, 'lr_min', 0),
+        )
+    elif opt.lr_policy == 'exponential':
+        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=getattr(opt, 'lr_decay_gamma', 0.1))
+    elif opt.lr_policy == 'poly':
+        total_epochs = max(1, opt.n_epochs + opt.n_epochs_decay)
+
+        def poly_rule(epoch):
+            progress = min(epoch, total_epochs) / float(total_epochs)
+            lr_scale = (1 - progress) ** getattr(opt, 'lr_poly_power', 0.9)
+            if getattr(opt, 'lr_min', 0) and opt.lr > 0:
+                lr_scale = max(lr_scale, opt.lr_min / opt.lr)
+            return lr_scale
+
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=poly_rule)
+    else:
+        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+    return scheduler
 
 
 def init_weights(net, init_type='normal', init_gain=0.02):
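A similar hedged sketch exercises the new 'poly' branch end to end: it builds a throwaway options namespace carrying the fields the diff reads via getattr (lr, lr_min, lr_poly_power, n_epochs, n_epochs_decay) and steps the scheduler across a full run. The SimpleNamespace stands in for a parsed BaseOptions object and the import assumes the repository root is on sys.path; neither detail is specified by the PR itself.

# Usage sketch (illustrative only): drive get_scheduler's 'poly' policy with
# hand-built options instead of the project's parsed BaseOptions.
from types import SimpleNamespace

from torch import nn, optim

from models.networks import get_scheduler  # assumes the repo root is importable

opt = SimpleNamespace(
    lr_policy='poly',
    lr=0.0002,           # base learning rate the optimizer starts from
    lr_min=1e-6,         # floor enforced inside poly_rule
    lr_poly_power=0.9,   # exponent of the polynomial decay
    n_epochs=100,
    n_epochs_decay=100,
)

optimizer = optim.Adam(nn.Conv2d(3, 8, 3).parameters(), lr=opt.lr)
scheduler = get_scheduler(optimizer, opt)

for epoch in range(opt.n_epochs + opt.n_epochs_decay):
    optimizer.step()     # the actual training update would go here
    scheduler.step()     # rescales the base lr by poly_rule at the current epoch
print('final lr: %.7g' % optimizer.param_groups[0]['lr'])  # ends at the lr_min floor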