From bf8e7a48113bc0ae7a878101f77d71bce8f214c4 Mon Sep 17 00:00:00 2001 From: Benjamin Bennett <125957+benbennett@users.noreply.github.com> Date: Wed, 24 Sep 2025 19:48:53 -0500 Subject: [PATCH] Enhance learning rate scheduling flexibility --- models/base_model.py | 28 ++++--- models/networks.py | 83 ++++++++++++++------ options/atme_options.py | 138 ++++++++++++++++++---------------- options/simple_options.py | 154 ++++++++++++++++++++------------------ 4 files changed, 233 insertions(+), 170 deletions(-) diff --git a/models/base_model.py b/models/base_model.py index d84ded8..64acb4e 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -103,15 +103,25 @@ def get_image_paths(self): def update_learning_rate(self): """Update learning rates for all the networks; called at the end of every epoch""" - old_lr = self.optimizers[0].param_groups[0]['lr'] - for scheduler in self.schedulers: - if self.opt.lr_policy == 'plateau': - scheduler.step(self.metric) - else: - scheduler.step() - - lr = self.optimizers[0].param_groups[0]['lr'] - print('learning rate %.7f -> %.7f' % (old_lr, lr)) + old_lrs = [optimizer.param_groups[0]['lr'] for optimizer in self.optimizers] + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + if hasattr(self.opt, 'lr_min') and self.opt.lr_min not in (None, 0): + for optimizer in self.optimizers: + for param_group in optimizer.param_groups: + param_group['lr'] = max(param_group['lr'], self.opt.lr_min) + + new_lrs = [optimizer.param_groups[0]['lr'] for optimizer in self.optimizers] + if len(self.optimizers) == 1: + print('learning rate %.7f -> %.7f' % (old_lrs[0], new_lrs[0])) + else: + lr_changes = ', '.join( + '%.7f -> %.7f' % (old, new) for old, new in zip(old_lrs, new_lrs) + ) + print(f'learning rates {lr_changes}') def get_current_visuals(self): """Return visualization images. 
simple.py will display these images with visdom, and save the images to a HTML""" diff --git a/models/networks.py b/models/networks.py index b23aac5..a30961f 100644 --- a/models/networks.py +++ b/models/networks.py @@ -46,30 +46,67 @@ def norm_layer(x): def get_scheduler(optimizer, opt): """Return a learning rate scheduler - Parameters: - optimizer -- the optimizer of the network - opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  - opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine - - For 'linear', we keep the same learning rate for the first epochs - and linearly decay the rate to zero over the next epochs. - For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. - See https://pytorch.org/docs/stable/optim.html for more details. + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine | cosine_restart | exponential | poly + + For 'linear', we keep the same learning rate for the first epochs + and linearly decay the rate to zero over the next epochs. + For other schedulers, we use the default PyTorch schedulers (step, plateau, cosine, + cosine with warm restarts, exponential decay) or the provided polynomial decay rule. + See https://pytorch.org/docs/stable/optim.html for more details. 
""" - if opt.lr_policy == 'linear': - def lambda_rule(epoch): - lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) - return lr_l - scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) - elif opt.lr_policy == 'step': - scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) - elif opt.lr_policy == 'plateau': - scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) - elif opt.lr_policy == 'cosine': - scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) - else: - return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) - return scheduler + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) + min_scale = opt.lr_min / opt.lr if getattr(opt, 'lr_min', 0) and opt.lr > 0 else 0.0 + return max(lr_l, min_scale) + + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR( + optimizer, + step_size=opt.lr_decay_iters, + gamma=getattr(opt, 'lr_decay_gamma', 0.1), + ) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau( + optimizer, + mode='min', + factor=getattr(opt, 'lr_plateau_factor', 0.2), + threshold=getattr(opt, 'lr_plateau_threshold', 0.01), + patience=getattr(opt, 'lr_plateau_patience', 5), + ) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=opt.n_epochs, + eta_min=getattr(opt, 'lr_min', 0), + ) + elif opt.lr_policy == 'cosine_restart': + scheduler = lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0=getattr(opt, 'lr_restart_period', opt.n_epochs), + T_mult=getattr(opt, 'lr_restart_mult', 1), + eta_min=getattr(opt, 'lr_min', 0), + ) + elif opt.lr_policy == 'exponential': + scheduler = 
lr_scheduler.ExponentialLR(optimizer, gamma=getattr(opt, 'lr_decay_gamma', 0.1)) + elif opt.lr_policy == 'poly': + total_epochs = max(1, opt.n_epochs + opt.n_epochs_decay) + + def poly_rule(epoch): + progress = min(epoch, total_epochs) / float(total_epochs) + lr_scale = (1 - progress) ** getattr(opt, 'lr_poly_power', 0.9) + if getattr(opt, 'lr_min', 0) and opt.lr > 0: + lr_scale = max(lr_scale, opt.lr_min / opt.lr) + return lr_scale + + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=poly_rule) + else: + raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) + return scheduler def init_weights(net, init_type='normal', init_gain=0.02): diff --git a/options/atme_options.py b/options/atme_options.py index 1d87408..fa65215 100644 --- a/options/atme_options.py +++ b/options/atme_options.py @@ -1,65 +1,73 @@ -from .base_options import BaseOptions -import argparse - - -class AtmeOptions(BaseOptions): - """This class includes training options. - - It also includes shared options defined in BaseOptions. 
- """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) - # visdom and HTML visualization parameters - parser.add_argument('--plane', type=str, required=True, default='coronal', help='define the plane the atme is trained on') - parser.add_argument('--model_root', type=str, default='atme_coronal_output', help='path to atme coronal images (should have subfolders trainA, trainB, valA, valB, etc)') - parser.add_argument('--TestAfterTrain', default=True, action=argparse.BooleanOptionalAction, help='specify if to test immediatly after train') - parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') - parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') - parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') - parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') - parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') - parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') - parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') - parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') - parser.add_argument('--no_html', type=bool, default=False, help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') - # model parameters - parser.add_argument('--model', type=str, default='atme', help='chooses which model to use. 
[atme, simple]') - parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') - parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels: 3 for RGB and 1 for grayscale') - parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale') - parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') - parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') - parser.add_argument('--netD', type=str, default='n_layers', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') - parser.add_argument('--netG', type=str, default='unet_256_ddm', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') - parser.add_argument('--n_layers_D', type=int, default=4, help='only used if netD==n_layers') - parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') - parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') - parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') - parser.add_argument('--no_dropout', default=True, action=argparse.BooleanOptionalAction, help='no dropout for the generator') - # dataset parameters - parser.add_argument('--batch_size', type=int, default=1, help='input batch size') - parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') - parser.add_argument('--crop_val', type=int, default=7, help='value for cropping the volume') - parser.add_argument('--stride', type=int, default=7, help='value for cropping the volume') - 
parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') - parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') - # network saving and loading parameters - parser.add_argument('--pre_train_W_path', type=str, default='', help='load path for pre-trained W model') - parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') - parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') - parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') - parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') - parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') - parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') - # training parameters - parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate') - parser.add_argument('--n_epochs_decay', type=int, default=500, help='number of epochs to linearly decay learning rate to zero') - parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') - parser.add_argument('--lr', type=float, default=0.00001, help='initial learning rate for adam') - parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') - parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') - parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]') - parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') - - return parser \ No newline at end of file +from .base_options import BaseOptions +import argparse + + +class AtmeOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # visdom and HTML visualization parameters + parser.add_argument('--plane', type=str, required=True, default='coronal', help='define the plane the atme is trained on') + parser.add_argument('--model_root', type=str, default='atme_coronal_output', help='path to atme coronal images (should have subfolders trainA, trainB, valA, valB, etc)') + parser.add_argument('--TestAfterTrain', default=True, action=argparse.BooleanOptionalAction, help='specify if to test immediately after train') + parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') + parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') + parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') + parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') + parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') + parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') + parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + parser.add_argument('--no_html', type=bool, 
default=False, help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + # model parameters + parser.add_argument('--model', type=str, default='atme', help='chooses which model to use. [atme, simple]') + parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') + parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') + parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') + parser.add_argument('--netD', type=str, default='n_layers', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') + parser.add_argument('--netG', type=str, default='unet_256_ddm', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') + parser.add_argument('--n_layers_D', type=int, default=4, help='only used if netD==n_layers') + parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') + parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') + parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + parser.add_argument('--no_dropout', default=True, action=argparse.BooleanOptionalAction, help='no dropout for the generator') + # dataset parameters + parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') + 
parser.add_argument('--crop_val', type=int, default=7, help='value for cropping the volume') + parser.add_argument('--stride', type=int, default=7, help='value for cropping the volume') + parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') + parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') + # network saving and loading parameters + parser.add_argument('--pre_train_W_path', type=str, default='', help='load path for pre-trained W model') + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + # training parameters + parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate') + parser.add_argument('--n_epochs_decay', type=int, default=500, help='number of epochs to linearly decay learning rate to zero') + parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + parser.add_argument('--lr', type=float, default=0.00001, help='initial learning rate for adam') + parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. 
vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') + parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') + parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine | cosine_restart | exponential | poly]') + parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') + parser.add_argument('--lr_decay_gamma', type=float, default=0.1, help='multiplicative factor of learning rate decay for step/exponential policies') + parser.add_argument('--lr_min', type=float, default=0.0, help='minimum learning rate allowed by schedulers') + parser.add_argument('--lr_plateau_factor', type=float, default=0.2, help='multiplicative factor for ReduceLROnPlateau') + parser.add_argument('--lr_plateau_patience', type=int, default=5, help='patience in epochs for ReduceLROnPlateau') + parser.add_argument('--lr_plateau_threshold', type=float, default=0.01, help='threshold for measuring new optimum in ReduceLROnPlateau') + parser.add_argument('--lr_restart_period', type=int, default=50, help='number of epochs before the first cosine restart') + parser.add_argument('--lr_restart_mult', type=int, default=1, help='multiplicative factor for increasing restart period in cosine_restart policy') + parser.add_argument('--lr_poly_power', type=float, default=0.9, help='power used for polynomial learning rate decay') + + return parser diff --git a/options/simple_options.py b/options/simple_options.py index 1e758b9..4918d6b 100644 --- a/options/simple_options.py +++ b/options/simple_options.py @@ -1,73 +1,81 @@ -from .base_options import BaseOptions -import argparse - - -class SimpleOptions(BaseOptions): - """This class includes training options. - - It also includes shared options defined in BaseOptions. 
- """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) - # visdom and HTML visualization parameters - parser.add_argument('--model_root', default='simple_output', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') - parser.add_argument('--planes_number', required=True, default=2, help='number of planes to train simple') - parser.add_argument('--atme_cor_root', default='atme_coronal_output', help='main path to atme coronal model') - parser.add_argument('--atme_ax_root', default='atme_axial_output', help='main path to atme axial model') - parser.add_argument('--atme_sag_root', default='atme_sagittal_output', help='main path to atme sagittal model') - parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') - parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') - parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') - parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') - parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') - parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') - parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') - parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') - parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') - # model parameters - parser.add_argument('--model', type=str, default='simple', help='chooses which model to use. 
[atme, simple]') - parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') - parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels: 3 for RGB and 1 for grayscale') - parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale') - parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') - parser.add_argument('--ndf_cor', type=int, default=16, help='# of discrim filters in the first conv layer') - parser.add_argument('--ndf_ax', type=int, default=64, help='# of discrim filters in the first conv layer') - parser.add_argument('--netD_cor', type=str, default='n_layers', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') - parser.add_argument('--netD_ax', type=str, default='n_layers', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator') - parser.add_argument('--netG', type=str, default='unet_64', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') - parser.add_argument('--n_layers_D_cor', type=int, default=1, help='only used if netD==n_layers') - parser.add_argument('--n_layers_D_ax', type=int, default=2, help='only used if netD==n_layers') - parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') - parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') - parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') - parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') - # dataset parameters - parser.add_argument('--batch_size', type=int, default=16, help='input batch size') - parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') - parser.add_argument('--patch_size', type=int, default=64, help='patch size') - parser.add_argument('--overlap_ratio', type=float, default=0.125, help='patches overlap ratio') - parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') - parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single | colorization]') - # network saving and loading parameters - parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') - parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') - parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') - parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') - parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') - parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') - # training parameters - parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate') - parser.add_argument('--n_epochs_decay', type=int, default=500, help='number of epochs to linearly decay learning rate to zero') - parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') - parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') - parser.add_argument('--cor_coef', type=float, default=0.5, help='coefficient for coronal component in G loss') - parser.add_argument('--ax_coef', type=float, default=0.5, help='coefficient for axial component in G loss') - parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') - parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') - parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | step | plateau | cosine]') - parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') - - self.isTrain = True - - return parser \ No newline at end of file +from .base_options import BaseOptions +import argparse + + +class SimpleOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # visdom and HTML visualization parameters + parser.add_argument('--model_root', default='simple_output', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') + parser.add_argument('--planes_number', required=True, default=2, help='number of planes to train simple') + parser.add_argument('--atme_cor_root', default='atme_coronal_output', help='main path to atme coronal model') + parser.add_argument('--atme_ax_root', default='atme_axial_output', help='main path to atme axial model') + parser.add_argument('--atme_sag_root', default='atme_sagittal_output', help='main path to atme sagittal model') + parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') + parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') + parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') + parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') + parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') + parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') + parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to 
html') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + # model parameters + parser.add_argument('--model', type=str, default='simple', help='chooses which model to use. [atme, simple]') + parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') + parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') + parser.add_argument('--ndf_cor', type=int, default=16, help='# of discrim filters in the first conv layer') + parser.add_argument('--ndf_ax', type=int, default=64, help='# of discrim filters in the first conv layer') + parser.add_argument('--netD_cor', type=str, default='n_layers', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') + parser.add_argument('--netD_ax', type=str, default='n_layers', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator') + parser.add_argument('--netG', type=str, default='unet_64', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') + parser.add_argument('--n_layers_D_cor', type=int, default=1, help='only used if netD==n_layers') + parser.add_argument('--n_layers_D_ax', type=int, default=2, help='only used if netD==n_layers') + parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') + parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') + parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') + # dataset parameters + parser.add_argument('--batch_size', type=int, default=16, help='input batch size') + parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') + parser.add_argument('--patch_size', type=int, default=64, help='patch size') + parser.add_argument('--overlap_ratio', type=float, default=0.125, help='patches overlap ratio') + parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') + parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single | colorization]') + # network saving and loading parameters + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + # training parameters + parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate') + parser.add_argument('--n_epochs_decay', type=int, default=500, help='number of epochs to linearly decay learning rate to zero') + parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') + parser.add_argument('--cor_coef', type=float, default=0.5, help='coefficient for coronal component in G loss') + parser.add_argument('--ax_coef', type=float, default=0.5, help='coefficient for axial component in G loss') + parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') + parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | step | plateau | cosine | cosine_restart | exponential | poly]') + parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') + parser.add_argument('--lr_decay_gamma', type=float, default=0.1, help='multiplicative factor of learning rate decay for step/exponential policies') + parser.add_argument('--lr_min', type=float, default=0.0, help='minimum learning rate allowed by schedulers') + parser.add_argument('--lr_plateau_factor', type=float, default=0.2, help='multiplicative factor for ReduceLROnPlateau') + parser.add_argument('--lr_plateau_patience', type=int, default=5, help='patience in epochs for ReduceLROnPlateau') + parser.add_argument('--lr_plateau_threshold', type=float, default=0.01, help='threshold for measuring new optimum in ReduceLROnPlateau') + parser.add_argument('--lr_restart_period', type=int, default=50, help='number of epochs before the first cosine restart') + parser.add_argument('--lr_restart_mult', type=int, default=1, help='multiplicative factor for increasing restart period in cosine_restart policy') + parser.add_argument('--lr_poly_power', type=float, default=0.9, help='power used for polynomial learning rate decay') + + self.isTrain = True + + return parser