diff --git a/sample_data/sample_config.ini b/sample_data/sample_config.ini index 8cbf865..07f870e 100644 --- a/sample_data/sample_config.ini +++ b/sample_data/sample_config.ini @@ -180,6 +180,10 @@ dropout_ratio=0.3 ; Maximum number of batch data to be trained. max_iteration=10000 +; Training will be stopped if the lowest dev ppl. has not been updated for this many hours. +; If this is set to 0, we don't stop training with this criterion. +max_waiting_hour=24 + ; Timing of evaluating dev/test set. Available options: ; * step .... Number of steps (iterations). ; * sample .... Number of samples (sentences). diff --git a/src/bin/train.cc b/src/bin/train.cc index 06f8e61..7f94988 100644 --- a/src/bin/train.cc +++ b/src/bin/train.cc @@ -450,6 +450,10 @@ int main(int argc, char * argv[]) { const float dropout_ratio = config.get("Train.dropout_ratio"); const unsigned max_iteration = config.get("Train.max_iteration"); + const unsigned max_waiting_hour = config.get("Train.max_waiting_hour"); + auto scheduled_training_finish_time = std::chrono::system_clock::to_time_t( + std::chrono::system_clock::now() + std::chrono::hours(max_waiting_hour)); + const string eval_type = config.get("Train.evaluation_type"); const unsigned eval_interval = config.get( "Train.evaluation_interval"); @@ -532,29 +536,29 @@ int main(int argc, char * argv[]) { const float dev_log_ppl = ::evaluateLogPerplexity( encdec, dev_sampler, batch_converter); const auto fmt_dev_log_ppl = boost::format( - "Evaluated: batch=%d words=%d elapsed-time(sec)=%d dev-log-ppl=%.6e") - % iteration % num_trained_words % elapsed_time_seconds % dev_log_ppl; + "Evaluated: batch=%d samples=%d words=%d elapsed-time(sec)=%d dev-log-ppl=%.6e") + % iteration % num_trained_samples % num_trained_words % elapsed_time_seconds % dev_log_ppl; logger->info(fmt_dev_log_ppl.str()); const float dev_bleu = ::evaluateBLEU( *trg_vocab, encdec, dev_sampler, train_max_length); const auto fmt_dev_bleu = boost::format( - "Evaluated: batch=%d words=%d 
elapsed-time(sec)=%d dev-bleu=%.6f") - % iteration % num_trained_words % elapsed_time_seconds % dev_bleu; + "Evaluated: batch=%d samples=%d words=%d elapsed-time(sec)=%d dev-bleu=%.6f") + % iteration % num_trained_samples % num_trained_words % elapsed_time_seconds % dev_bleu; logger->info(fmt_dev_bleu.str()); const float test_log_ppl = ::evaluateLogPerplexity( encdec, test_sampler, batch_converter); const auto fmt_test_log_ppl = boost::format( - "Evaluated: batch=%d words=%d elapsed-time(sec)=%d test-log-ppl=%.6e") - % iteration % num_trained_words % elapsed_time_seconds % test_log_ppl; + "Evaluated: batch=%d samples=%d words=%d elapsed-time(sec)=%d test-log-ppl=%.6e") + % iteration % num_trained_samples % num_trained_words % elapsed_time_seconds % test_log_ppl; logger->info(fmt_test_log_ppl.str()); const float test_bleu = ::evaluateBLEU( *trg_vocab, encdec, test_sampler, train_max_length); const auto fmt_test_bleu = boost::format( - "Evaluated: batch=%d words=%d elapsed-time(sec)=%d test-bleu=%.6f") - % iteration % num_trained_words % elapsed_time_seconds % test_bleu; + "Evaluated: batch=%d samples=%d words=%d elapsed-time(sec)=%d test-bleu=%.6f") + % iteration % num_trained_samples % num_trained_words % elapsed_time_seconds % test_bleu; logger->info(fmt_test_bleu.str()); if (lr_decay_type == "eval") { @@ -576,6 +580,8 @@ int main(int argc, char * argv[]) { FS::copy_file(model_dir / "latest.trainer.params", trainer_path); FS::copy_file(model_dir / "latest.model.params", model_path); logger->info("Saved 'best_dev_log_ppl' model."); + scheduled_training_finish_time = std::chrono::system_clock::to_time_t( + std::chrono::system_clock::now() + std::chrono::hours(max_waiting_hour)); } else { if (lr_decay_type == "logppl") { lr_decay *= lr_decay_ratio; @@ -608,6 +614,12 @@ int main(int argc, char * argv[]) { std::chrono::system_clock::to_time_t( current_time + std::chrono::minutes(eval_interval)); } + + // Training finish check + if (max_waiting_hour != 0 and + 
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) >= scheduled_training_finish_time) { + break; + } } logger->info("Finished.");