From e10739779a8de2d66bd7f558f94428c51eb50cd0 Mon Sep 17 00:00:00 2001 From: Makoto Morishita Date: Thu, 26 Jan 2017 22:39:44 +0900 Subject: [PATCH 1/2] add feature to load vocabularies from the file --- sample_data/sample_config.ini | 5 ++++ src/bin/train.cc | 51 ++++++++++++++++++++++++++++------- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/sample_data/sample_config.ini b/sample_data/sample_config.ini index 62c114d..642e8cb 100644 --- a/sample_data/sample_config.ini +++ b/sample_data/sample_config.ini @@ -56,6 +56,11 @@ target_vocabulary_type=word source_vocabulary_size=4100 target_vocabulary_size=4900 +; If you specified the model directory here, load vocabulary from the specified model. +; If you use this option, the vocabulary type and the vocabulary size specified above will be ignored. +; By writing vocabulary_model=none, this option will be ignored. +vocabulary_model=none + ; Name of the encoder strategy. Available options: ; * bidirectional ... Bidirectional RNN. ; * forward ......... Forward RNN. diff --git a/src/bin/train.cc b/src/bin/train.cc index a17e198..6bd9eb2 100644 --- a/src/bin/train.cc +++ b/src/bin/train.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -333,6 +334,23 @@ float evaluateBLEU( return evaluator->integrate(stats); } +template +void loadArchive( + const FS::path & filepath, + const string & archive_format, + T * obj) { + ifstream ifs(filepath.string()); + NMTKIT_CHECK( + ifs.is_open(), "Could not open file to read: " + filepath.string()); + if (archive_format == "binary") { + boost::archive::binary_iarchive iar(ifs); + iar >> *obj; + } else if (archive_format == "text") { + boost::archive::text_iarchive iar(ifs); + iar >> *obj; + } +} + } // namespace int main(int argc, char * argv[]) { @@ -374,16 +392,29 @@ int main(int argc, char * argv[]) { nmtkit::initialize(global_config); // Creates vocabularies. - boost::scoped_ptr src_vocab( - ::createVocabulary( - config.get("Corpus.train_source"), - config.get("Model.source_vocabulary_type"), - config.get("Model.source_vocabulary_size"))); - boost::scoped_ptr trg_vocab( - ::createVocabulary( - config.get("Corpus.train_target"), - config.get("Model.target_vocabulary_type"), - config.get("Model.target_vocabulary_size"))); + boost::scoped_ptr src_vocab, trg_vocab; + if (config.get("Model.vocabulary_model") == "none") { + logger->info("Making source vocabulary."); + src_vocab.reset(::createVocabulary( + config.get("Corpus.train_source"), + config.get("Model.source_vocabulary_type"), + config.get("Model.source_vocabulary_size"))); + logger->info("Making target vocabulary."); + trg_vocab.reset(::createVocabulary( + config.get("Corpus.train_target"), + config.get("Model.target_vocabulary_type"), + config.get("Model.target_vocabulary_size"))); + } else { + FS::path vocab_model_dir(config.get("Model.vocabulary_model")); + // Parses config file. + PT::ptree vocab_config; + PT::read_ini((vocab_model_dir / "config.ini").string(), vocab_config); + // Archive format to load models. + const string vocab_archive_format = vocab_config.get("Global.archive_format"); + ::loadArchive(vocab_model_dir / "source.vocab", vocab_archive_format, &src_vocab); + ::loadArchive(vocab_model_dir / "target.vocab", vocab_archive_format, &trg_vocab); + logger->info("Loaded vocabularies."); + } ::saveArchive(model_dir / "source.vocab", archive_format, src_vocab); ::saveArchive(model_dir / "target.vocab", archive_format, trg_vocab); From 9417d8c3cdf4fdfdc3eb8f66343259e3656e4a85 Mon Sep 17 00:00:00 2001 From: Makoto Morishita Date: Thu, 26 Jan 2017 23:02:18 +0900 Subject: [PATCH 2/2] fix bugs --- src/bin/train.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bin/train.cc b/src/bin/train.cc index 6bd9eb2..760b938 100644 --- a/src/bin/train.cc +++ b/src/bin/train.cc @@ -339,7 +339,7 @@ void loadArchive( const FS::path & filepath, const string & archive_format, T * obj) { - ifstream ifs(filepath.string()); + std::ifstream ifs(filepath.string()); NMTKIT_CHECK( ifs.is_open(), "Could not open file to read: " + filepath.string()); if (archive_format == "binary") {