From bf2932213a59bcebf6154d3e63d4532229ae6bbb Mon Sep 17 00:00:00 2001 From: MorinoseiMorizo Date: Fri, 6 Jan 2017 11:44:52 +0900 Subject: [PATCH 1/2] add function to stop training when we cannot get lower dev ppl for some hours --- sample_data/sample_config.ini | 4 ++++ src/bin/train.cc | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/sample_data/sample_config.ini b/sample_data/sample_config.ini index 8cbf865..07f870e 100644 --- a/sample_data/sample_config.ini +++ b/sample_data/sample_config.ini @@ -180,6 +180,10 @@ dropout_ratio=0.3 ; Maximum number of batch data to be trained. max_iteration=10000 +; Training will be stopped if the lowest dev ppl. is not improved for this many hours. +; If this is set to 0, training is not stopped by this criterion. +max_waiting_hour=24 + ; Timing of evaluating dev/test set. Available options: ; * step .... Number of steps (iterations). ; * sample .... Number of samples (sentences). diff --git a/src/bin/train.cc b/src/bin/train.cc index 06f8e61..9e31809 100644 --- a/src/bin/train.cc +++ b/src/bin/train.cc @@ -450,6 +450,10 @@ int main(int argc, char * argv[]) { const float dropout_ratio = config.get("Train.dropout_ratio"); const unsigned max_iteration = config.get("Train.max_iteration"); + const unsigned max_waiting_hour = config.get("Train.max_waiting_hour"); + auto scheduled_training_finish_time = std::chrono::system_clock::to_time_t( + std::chrono::system_clock::now() + std::chrono::hours(max_waiting_hour)); + const string eval_type = config.get("Train.evaluation_type"); const unsigned eval_interval = config.get( "Train.evaluation_interval"); @@ -576,6 +580,8 @@ int main(int argc, char * argv[]) { FS::copy_file(model_dir / "latest.trainer.params", trainer_path); FS::copy_file(model_dir / "latest.model.params", model_path); logger->info("Saved 'best_dev_log_ppl' model."); + scheduled_training_finish_time = std::chrono::system_clock::to_time_t( + std::chrono::system_clock::now() + std::chrono::hours(max_waiting_hour)); 
} else { if (lr_decay_type == "logppl") { lr_decay *= lr_decay_ratio; @@ -608,6 +614,12 @@ int main(int argc, char * argv[]) { std::chrono::system_clock::to_time_t( current_time + std::chrono::minutes(eval_interval)); } + + // Training finish check + if (max_waiting_hour != 0 and + std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) >= scheduled_training_finish_time) { + break; + } } logger->info("Finished."); From 3d0b4043742a585c9dd228120662bd13e3054f4a Mon Sep 17 00:00:00 2001 From: MorinoseiMorizo Date: Fri, 6 Jan 2017 16:59:20 +0900 Subject: [PATCH 2/2] add number of trained sample to evaluation log. --- src/bin/train.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/bin/train.cc b/src/bin/train.cc index 9e31809..7f94988 100644 --- a/src/bin/train.cc +++ b/src/bin/train.cc @@ -536,29 +536,29 @@ int main(int argc, char * argv[]) { const float dev_log_ppl = ::evaluateLogPerplexity( encdec, dev_sampler, batch_converter); const auto fmt_dev_log_ppl = boost::format( - "Evaluated: batch=%d words=%d elapsed-time(sec)=%d dev-log-ppl=%.6e") - % iteration % num_trained_words % elapsed_time_seconds % dev_log_ppl; + "Evaluated: batch=%d samples=%d words=%d elapsed-time(sec)=%d dev-log-ppl=%.6e") + % iteration % num_trained_samples % num_trained_words % elapsed_time_seconds % dev_log_ppl; logger->info(fmt_dev_log_ppl.str()); const float dev_bleu = ::evaluateBLEU( *trg_vocab, encdec, dev_sampler, train_max_length); const auto fmt_dev_bleu = boost::format( - "Evaluated: batch=%d words=%d elapsed-time(sec)=%d dev-bleu=%.6f") - % iteration % num_trained_words % elapsed_time_seconds % dev_bleu; + "Evaluated: batch=%d samples=%d words=%d elapsed-time(sec)=%d dev-bleu=%.6f") + % iteration % num_trained_samples % num_trained_words % elapsed_time_seconds % dev_bleu; logger->info(fmt_dev_bleu.str()); const float test_log_ppl = ::evaluateLogPerplexity( encdec, test_sampler, batch_converter); const auto fmt_test_log_ppl = 
boost::format( - "Evaluated: batch=%d words=%d elapsed-time(sec)=%d test-log-ppl=%.6e") - % iteration % num_trained_words % elapsed_time_seconds % test_log_ppl; + "Evaluated: batch=%d samples=%d words=%d elapsed-time(sec)=%d test-log-ppl=%.6e") + % iteration % num_trained_samples % num_trained_words % elapsed_time_seconds % test_log_ppl; logger->info(fmt_test_log_ppl.str()); const float test_bleu = ::evaluateBLEU( *trg_vocab, encdec, test_sampler, train_max_length); const auto fmt_test_bleu = boost::format( - "Evaluated: batch=%d words=%d elapsed-time(sec)=%d test-bleu=%.6f") - % iteration % num_trained_words % elapsed_time_seconds % test_bleu; + "Evaluated: batch=%d samples=%d words=%d elapsed-time(sec)=%d test-bleu=%.6f") + % iteration % num_trained_samples % num_trained_words % elapsed_time_seconds % test_bleu; logger->info(fmt_test_bleu.str()); if (lr_decay_type == "eval") {