
Speed up multiple inference steps at the end of each epoch #61

@felker


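Presently, at the end of every epoch, the trained weights are reloaded via a call to Keras Model.load_weights() three separate times in order to evaluate the accuracy on the shots in the validation and testing sets (the validation set is evaluated twice). A fourth reload is added for the training set when conf['training']['ranking_difficulty_fac'] != 1.0: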

```python
# TODO(KGF): is there a way to avoid Keras.Models.load_weights()
# repeated calls throughout mpi_make_pred*() fn calls?
_, _, _, roc_area, loss = mpi_make_predictions_and_evaluate(
    conf, shot_list_validate, loader)
if conf['training']['ranking_difficulty_fac'] != 1.0:
    (_, _, _, roc_area_train,
     loss_train) = mpi_make_predictions_and_evaluate(
         conf, shot_list_train, loader)
    batch_generator = partial(
        loader.training_batch_generator_partial_reset,
        shot_list=shot_list_train)
    mpi_model.batch_iterator = batch_generator
    mpi_model.batch_iterator_func.__exit__()
    mpi_model.num_so_far_accum = mpi_model.num_so_far_indiv
    mpi_model.set_batch_iterator_func()
if ('monitor_test' in conf['callbacks'].keys()
        and conf['callbacks']['monitor_test']):
    times = conf['callbacks']['monitor_times']
    areas, _ = mpi_make_predictions_and_evaluate_multiple_times(
        conf, shot_list_validate, loader, times)
    epoch_str = 'epoch {}, '.format(int(round(e)))
    g.write_unique(epoch_str + ' '.join(
        ['val_roc_{} = {}'.format(t, roc) for t, roc in zip(
            times, areas)]
        ) + '\n')
    if shot_list_test is not None:
        areas, _ = mpi_make_predictions_and_evaluate_multiple_times(
            conf, shot_list_test, loader, times)
        g.write_unique(epoch_str + ' '.join(
            ['test_roc_{} = {}'.format(t, roc) for t, roc in zip(
                times, areas)]
            ) + '\n')
```
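Tracing the branches of this excerpt, the number of end-of-epoch inference passes depends on two config flags and on whether a test set exists. Below is a minimal sketch of that count. It assumes that every mpi_make_predictions_and_evaluate*() call performs exactly one weight reload followed by one full pass over its shot list, which is consistent with the "loading from epoch" lines in the log. count_eval_passes is a hypothetical helper, not part of the codebase.

```python
def count_eval_passes(ranking_difficulty_fac: float,
                      monitor_test: bool,
                      has_test_set: bool) -> dict:
    """Count end-of-epoch inference passes per shot list.

    Hypothetical helper mirroring the branches of the excerpt above;
    assumes each mpi_make_predictions_and_evaluate*() call does one
    weight reload plus one full pass over its shot list.
    """
    # mpi_make_predictions_and_evaluate() on the validation set always runs
    passes = {'validate': 1, 'train': 0, 'test': 0}
    if ranking_difficulty_fac != 1.0:
        passes['train'] += 1
    if monitor_test:
        # ..._multiple_times() predicts the validation shots a second time
        passes['validate'] += 1
        if has_test_set:
            passes['test'] += 1
    return passes


# A configuration consistent with the three passes seen in the log.
p = count_eval_passes(1.0, monitor_test=True, has_test_set=True)
print(p, sum(p.values()))  # {'validate': 2, 'train': 0, 'test': 1} 3
```

The point the count makes explicit is that, with monitor_test enabled, the validation shots are pushed through the model twice per epoch.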

Depending on the size of the datasets (number of shots, pulse length, number of signals per shot), the network architecture, and the hardware, this process can take a significant amount of time. This is especially noticeable when the epoch walltimes are relatively short, e.g. due to small batch sizes.

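For example, in a recent test with d3d_0D on Traverse (4x V100 GPUs), the training epoch took 87.65 s, while each of the three evaluation passes that followed took about 35 s: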

```
Finished training epoch 3.01 during this session (1.00 epochs passed) in 87.65 seconds
Finished training of epoch 6.01/1000
Begin evaluation of epoch 6.01/1000
[2] loading from epoch 6
[1] loading from epoch 6
[0] loading from epoch 6
[3] loading from epoch 6
128/894 [===>..........................] - ETA: 1:53
640/894 [====================>.........] - ETA: 13s
896/894 [==============================] - 35s 39ms/step
[0] loading from epoch 6
[3] loading from epoch 6
[1] loading from epoch 6
[2] loading from epoch 6
128/894 [===>..........................] - ETA: 1:53
640/894 [====================>.........] - ETA: 13s
896/894 [==============================] - 35s 39ms/step
epoch 6, val_roc_30 = 0.85346611872694 val_roc_70 = 0.8345022047574768 val_roc_200 = 0.7913309535951044 val_roc_500 = 0.6638869724330323 val_roc_1000 = 0.5480697123316435
[3] loading from epoch 6
[2] loading from epoch 6
[0] loading from epoch 6
[1] loading from epoch 6
128/894 [===>..........................] - ETA: 1:53
640/894 [====================>.........] - ETA: 12s
896/894 [==============================] - 35s 39ms/step
epoch 6, test_roc_30 = 0.8400389140546622 test_roc_70 = 0.8236098866020126 test_roc_200 = 0.7792357036451524 test_roc_500 = 0.6798285349466453 test_roc_1000 = 0.5699692943787431
```
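The 1:53 ETA on the first progress update is an estimate, not a measured load time. The elapsed time at each update can be back-solved from the log if we assume the classic Keras Progbar estimate, ETA = (elapsed / current) * (target - current). This is only a rough sketch: more recent tf.keras versions compute the estimate differently, and the printed ETAs are rounded, so the numbers are indicative at best.

```python
def implied_elapsed(current: int, target: int, eta_s: float) -> float:
    """Back-solve elapsed seconds from a Keras-style ETA.

    Assumes ETA = (elapsed / current) * (target - current), i.e. an
    estimate that averages over all units processed so far.
    """
    return eta_s * current / (target - current)


target = 894
# (current, ETA in seconds) pairs read off one evaluation pass in the log
updates = [(128, 113.0), (640, 13.0)]
for current, eta in updates:
    print(f"{current}/{target}: ~{implied_elapsed(current, target, eta):.1f} s elapsed")
# 128/894: ~18.9 s elapsed
# 640/894: ~32.8 s elapsed
```

Since each pass finished at about 35 s, under this assumption roughly 19 s were spent before the first 128 units and only about 16 s on the remaining 766. That pattern points at per-call setup cost (weight reload, first-batch overhead) rather than steady-state inference.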

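It seems straightforward to deduplicate this work via a new combined function that replaces the 2x calls to mpi_make_predictions_and_evaluate_multiple_times() and the 1x call to mpi_make_predictions_and_evaluate(). Each of the three passes above takes about 35 s. The inflated 1:53 ETA at the first progress update of each pass suggests that a large share of that time is repeated setup (weight reload, first-batch overhead) rather than inference. In addition, the validation shots are currently predicted twice: once for roc_area and once for the val_roc_* values.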
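The combined function proposed above could look roughly like the following. This is a sketch under stated assumptions, not the plasma-python API: load_weights, predict, and the metric callables are hypothetical stand-ins. The MPI distribution of shots across ranks and the training-set batch-iterator reset are ignored. The idea is to reload weights once per epoch, run one inference pass per distinct shot list, and compute every requested metric (the single ROC area as well as the per-alarm-time ROC areas) from the cached predictions.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical signatures, not the plasma-python API:
#   load_weights() -> None          reload the model weights once
#   predict(shots) -> list          one inference pass over a shot list
#   metric(preds, shots) -> float   e.g. ROC area at a given alarm time


def evaluate_all(load_weights: Callable[[], None],
                 predict: Callable[[Sequence], List],
                 shot_lists: Dict[str, Sequence],
                 metrics: Dict[str, Dict[str, Callable]]
                 ) -> Dict[str, Dict[str, float]]:
    """Reload weights once, predict each shot list once, then compute
    all metrics for that shot list from the cached predictions."""
    load_weights()  # single reload instead of one per evaluation call
    results = {}
    for name, shots in shot_lists.items():
        preds = predict(shots)  # each shot list is predicted exactly once
        results[name] = {m: fn(preds, shots)
                         for m, fn in metrics.get(name, {}).items()}
    return results


# Toy usage with counters standing in for the real model.
calls = {'load': 0, 'predict': 0}

def fake_load():
    calls['load'] += 1

def fake_predict(shots):
    calls['predict'] += 1
    return [0.5 for _ in shots]

def mean_pred(preds, shots):  # toy stand-in for a ROC-area metric
    return sum(preds) / len(preds)

res = evaluate_all(fake_load, fake_predict,
                   {'validate': [1, 2, 3], 'test': [4, 5]},
                   {'validate': {'roc': mean_pred, 'roc_30': mean_pred},
                    'test': {'roc_30': mean_pred}})
print(calls)  # {'load': 1, 'predict': 2}
```

Keying the metrics by shot list lets the validation set serve both the roc_area used for checkpointing and the val_roc_* monitoring values from a single pass. Whether the one remaining reload per epoch can also be avoided is a separate question.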
