Presently, at the end of every epoch, the trained weights are reloaded via a call to Keras.Models.load_weights() three separate times in order to evaluate the accuracy on the shots in the training, validation, and testing sets:
plasma-python/plasma/models/mpi_runner.py
Lines 932 to 965 in c82ba61
```python
# TODO(KGF): is there a way to avoid Keras.Models.load_weights()
# repeated calls throughout mpi_make_pred*() fn calls?
_, _, _, roc_area, loss = mpi_make_predictions_and_evaluate(
    conf, shot_list_validate, loader)
if conf['training']['ranking_difficulty_fac'] != 1.0:
    (_, _, _, roc_area_train,
     loss_train) = mpi_make_predictions_and_evaluate(
         conf, shot_list_train, loader)
    batch_generator = partial(
        loader.training_batch_generator_partial_reset,
        shot_list=shot_list_train)
    mpi_model.batch_iterator = batch_generator
    mpi_model.batch_iterator_func.__exit__()
    mpi_model.num_so_far_accum = mpi_model.num_so_far_indiv
    mpi_model.set_batch_iterator_func()
if ('monitor_test' in conf['callbacks'].keys()
        and conf['callbacks']['monitor_test']):
    times = conf['callbacks']['monitor_times']
    areas, _ = mpi_make_predictions_and_evaluate_multiple_times(
        conf, shot_list_validate, loader, times)
    epoch_str = 'epoch {}, '.format(int(round(e)))
    g.write_unique(epoch_str + ' '.join(
        ['val_roc_{} = {}'.format(t, roc) for t, roc in zip(
            times, areas)]
        ) + '\n')
    if shot_list_test is not None:
        areas, _ = mpi_make_predictions_and_evaluate_multiple_times(
            conf, shot_list_test, loader, times)
        g.write_unique(epoch_str + ' '.join(
            ['test_roc_{} = {}'.format(t, roc) for t, roc in zip(
                times, areas)]
            ) + '\n')
```
Depending on the size of the datasets (number of shots, pulse length, number of signals per shot), the network architecture, and the hardware, this process can take a significant amount of time. It is especially noticeable when the epoch walltimes are relatively short, e.g. due to small batch sizes.
For example, in a recent test with d3d_0D on Traverse with 4x V100s:
```
Finished training epoch 3.01 during this session (1.00 epochs passed) in 87.65 seconds
Finished training of epoch 6.01/1000
Begin evaluation of epoch 6.01/1000
[2] loading from epoch 6
[1] loading from epoch 6
[0] loading from epoch 6
[3] loading from epoch 6
128/894 [===>..........................] - ETA: 1:53
640/894 [====================>.........] - ETA: 13s
896/894 [==============================] - 35s 39ms/step
[0] loading from epoch 6
[3] loading from epoch 6
[1] loading from epoch 6
[2] loading from epoch 6
128/894 [===>..........................] - ETA: 1:53
640/894 [====================>.........] - ETA: 13s
896/894 [==============================] - 35s 39ms/step
epoch 6, val_roc_30 = 0.85346611872694 val_roc_70 = 0.8345022047574768 val_roc_200 = 0.7913309535951044 val_roc_500 = 0.6638869724330323 val_roc_1000 = 0.5480697123316435
[3] loading from epoch 6 [2] loading from epoch 6
[0] loading from epoch 6 [1] loading from epoch 6
128/894 [===>..........................] - ETA: 1:53
640/894 [====================>.........] - ETA: 12s
896/894 [==============================] - 35s 39ms/step
epoch 6, test_roc_30 = 0.8400389140546622 test_roc_70 = 0.8236098866020126 test_roc_200 = 0.7792357036451524 test_roc_500 = 0.6798285349466453 test_roc_1000 = 0.5699692943787431
```
It seems straightforward to deduplicate the 3x repeated weight loads (each with a ~1:53 initial ETA) by introducing a new combined function in place of the two calls to mpi_make_predictions_and_evaluate_multiple_times() plus the one call to mpi_make_predictions_and_evaluate().
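The shape of that combined function could look something like the following minimal, self-contained sketch. All names here (load_weights, evaluate, evaluate_all) are illustrative stand-ins, not the actual plasma-python API; the point is only to show the load hoisted out of the per-dataset calls so it runs once per epoch instead of once per shot list:

```python
load_count = 0  # tracks how often the expensive load runs


def load_weights(epoch):
    """Stand-in for the Keras load_weights() call; expensive in practice."""
    global load_count
    load_count += 1
    return {'epoch': epoch}  # pretend these are the trained weights


def evaluate(weights, shots):
    """Stand-in for prediction + ROC evaluation on one shot list."""
    return sum(shots) / len(shots)  # dummy metric


def evaluate_all(epoch, shot_lists):
    """Combined function: one weight load, then evaluate every shot list."""
    weights = load_weights(epoch)  # single load replaces the 3x repeated loads
    return {name: evaluate(weights, shots)
            for name, shots in shot_lists.items()}


scores = evaluate_all(6, {'train': [1, 2, 3],
                          'validate': [2, 3, 4],
                          'test': [3, 4, 5]})
print(load_count)  # 1 load instead of 3
```

In the real code this would presumably mean passing all three shot lists (and the monitor times) into one mpi_make_predictions_* entry point that loads the checkpoint once and reuses the in-memory model for every prediction pass.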