From 07bdd06092e43db17b5c57b1d08fbfaf3bb65f9d Mon Sep 17 00:00:00 2001
From: Makoto Morishita
Date: Sun, 8 Jan 2017 09:36:34 +0900
Subject: [PATCH 1/3] fix BLEU score calculation during training

---
 src/bin/train.cc                          |  24 +-
 src/include/nmtkit/basic_types.h          |  14 ++
 src/include/nmtkit/character_vocabulary.h |   4 +
 src/include/nmtkit/test_corpus.h          | 120 +++++++++
 src/include/nmtkit/test_sampler.h         |  43 ++++++
 src/include/nmtkit/vocabulary.h           |  20 +++
 src/include/nmtkit/word_vocabulary.h      |   4 +
 src/lib/Makefile.am                       |   2 +
 src/lib/character_vocabulary.cc           |  18 +++
 src/lib/test_corpus.cc                    | 108 +++++++++
 src/lib/test_sampler.cc                   |  68 +++++++
 src/lib/word_vocabulary.cc                |  20 +++
 src/test/Makefile.am                      |   6 +
 src/test/character_vocabulary_test.cc     |  72 +++++++
 src/test/test_corpus_test.cc              | 175 ++++++++++++++++
 src/test/test_sampler_test.cc             | 148 ++++++++++++
 src/test/word_vocabulary_test.cc          |  52 +++++++
 17 files changed, 888 insertions(+), 10 deletions(-)
 create mode 100644 src/include/nmtkit/test_corpus.h
 create mode 100644 src/include/nmtkit/test_sampler.h
 create mode 100644 src/lib/test_corpus.cc
 create mode 100644 src/lib/test_sampler.cc
 create mode 100644 src/test/test_corpus_test.cc
 create mode 100644 src/test/test_sampler_test.cc

diff --git a/src/bin/train.cc b/src/bin/train.cc
index 06f8e61..5036088 100644
--- a/src/bin/train.cc
+++ b/src/bin/train.cc
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -30,6 +31,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -270,7 +272,7 @@ void saveArchive(
 //   The log perplexity score.
 float evaluateLogPerplexity(
     nmtkit::EncoderDecoder & encdec,
-    nmtkit::MonotoneSampler & sampler,
+    nmtkit::TestSampler & sampler,
     nmtkit::BatchConverter & converter) {
   unsigned num_outputs = 0;
   float total_loss = 0.0f;
@@ -302,15 +304,16 @@ float evaluateBLEU(
     const nmtkit::Vocabulary & trg_vocab,
     nmtkit::EncoderDecoder & encdec,
-    nmtkit::MonotoneSampler & sampler,
+    nmtkit::TestSampler & sampler,
     const unsigned max_length) {
   const auto evaluator = MTEval::EvaluatorFactory::create("BLEU");
   const unsigned bos_id = trg_vocab.getID("<s>");
   const unsigned eos_id = trg_vocab.getID("</s>");
+  MTEval::Dictionary dict;
   MTEval::Statistics stats;
   sampler.rewind();
   while (sampler.hasSamples()) {
-    vector<nmtkit::Sample> samples = sampler.getSamples();
+    vector<nmtkit::TestSample> samples = sampler.getTestSamples();
     nmtkit::InferenceGraph ig = encdec.infer(
         samples[0].source, bos_id, eos_id, max_length, 1, 0.0f);
     const auto hyp_nodes = ig.findOneBestPath(bos_id, eos_id);
@@ -319,7 +322,10 @@ float evaluateBLEU(
     for (unsigned i = 1; i < hyp_nodes.size() - 1; ++i) {
       hyp_ids.emplace_back(hyp_nodes[i]->label().word_id);
     }
-    MTEval::Sample eval_sample {hyp_ids, {samples[0].target}};
+    const string string_hyp = trg_vocab.convertToSentence(hyp_ids);
+    const MTEval::Sentence sent_hyp = dict.getSentence(string_hyp);
+    const MTEval::Sentence sent_ref = dict.getSentence(samples[0].target_string);
+    MTEval::Sample eval_sample {sent_hyp, {sent_ref}};
     stats += evaluator->map(eval_sample);
   }
   return evaluator->integrate(stats);
@@ -381,10 +387,8 @@ int main(int argc, char * argv[]) {
 
   // Maximum lengths
   const unsigned train_max_length = config.get<unsigned>("Batch.max_length");
-  const unsigned test_max_length = 1024;
   const float train_max_length_ratio = config.get<float>(
       "Batch.max_length_ratio");
-  const float test_max_length_ratio = 1e10;
 
   // Creates samplers and batch converter.
   nmtkit::SortedRandomSampler train_sampler(
@@ -398,15 +402,15 @@ int main(int argc, char * argv[]) {
       config.get<unsigned>("Global.random_seed"));
   const unsigned corpus_size = train_sampler.getNumSamples();
   logger->info("Loaded 'train' corpus.");
-  nmtkit::MonotoneSampler dev_sampler(
+  nmtkit::TestSampler dev_sampler(
       config.get<string>("Corpus.dev_source"),
       config.get<string>("Corpus.dev_target"),
-      *src_vocab, *trg_vocab, test_max_length, test_max_length_ratio, 1);
+      *src_vocab, *trg_vocab, 1);
   logger->info("Loaded 'dev' corpus.");
-  nmtkit::MonotoneSampler test_sampler(
+  nmtkit::TestSampler test_sampler(
       config.get<string>("Corpus.test_source"),
       config.get<string>("Corpus.test_target"),
-      *src_vocab, *trg_vocab, test_max_length, test_max_length_ratio, 1);
+      *src_vocab, *trg_vocab, 1);
   logger->info("Loaded 'test' corpus.");
   const auto fmt_corpus_size = boost::format(
       "Cleaned corpus size: train=%d dev=%d test=%d")
diff --git a/src/include/nmtkit/basic_types.h b/src/include/nmtkit/basic_types.h
index 4ea1c1a..b0081a3 100644
--- a/src/include/nmtkit/basic_types.h
+++ b/src/include/nmtkit/basic_types.h
@@ -14,6 +14,20 @@ struct Sample {
   std::vector<unsigned> target;
 };
 
+struct TestSample {
+  // Source sentence with word IDs.
+  std::vector<unsigned> source;
+
+  // Target sentence with word IDs.
+  std::vector<unsigned> target;
+
+  // Source sentence string.
+  std::string source_string;
+
+  // Target sentence string.
+  std::string target_string;
+};
+
 struct Batch {
   // Source word ID table with shape (max_source_length, batch_size).
   std::vector<std::vector<unsigned>> source_ids;
diff --git a/src/include/nmtkit/character_vocabulary.h b/src/include/nmtkit/character_vocabulary.h
index b1554ea..b61fc50 100644
--- a/src/include/nmtkit/character_vocabulary.h
+++ b/src/include/nmtkit/character_vocabulary.h
@@ -41,8 +41,12 @@ class CharacterVocabulary : public Vocabulary {
   unsigned getFrequency(const unsigned id) const override;
   std::vector<unsigned> convertToIDs(
       const std::string & sentence) const override;
+  std::vector<std::string> convertToTokens(
+      const std::string & sentence) const override;
   std::string convertToSentence(
       const std::vector<unsigned> & word_ids) const override;
+  std::vector<std::string> convertToTokenizedSentence(
+      const std::vector<unsigned> & word_ids) const override;
   unsigned size() const override;
 
 private:
diff --git a/src/include/nmtkit/test_corpus.h b/src/include/nmtkit/test_corpus.h
new file mode 100644
index 0000000..b145930
--- /dev/null
+++ b/src/include/nmtkit/test_corpus.h
@@ -0,0 +1,120 @@
+#ifndef NMTKIT_TEST_CORPUS_H_
+#define NMTKIT_TEST_CORPUS_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace nmtkit {
+
+class TestCorpus : public Corpus {
+  TestCorpus() = delete;
+  TestCorpus(const TestCorpus &) = delete;
+  TestCorpus(TestCorpus &&) = delete;
+  TestCorpus & operator=(const TestCorpus &) = delete;
+  TestCorpus & operator=(TestCorpus &&) = delete;
+
+public:
+  // Reads one line from the input stream.
+  //
+  // Arguments:
+  //   is: Target input stream.
+  //   line: Placeholder to store new string. Old data will be deleted before
+  //         storing the new value.
+  //
+  // Returns:
+  //   true if reading completed successfully, false otherwise (e.g. EOF).
+  static bool readLine(std::istream * is, std::string * line);
+
+  // Reads one line from the input stream and converts it into word IDs.
+  //
+  // Arguments:
+  //   vocab: Vocabulary object to be used to convert words into word IDs.
+  //   is: Target input stream.
+  //   word_ids: Placeholder to store new words. Old data will be deleted
+  //             automatically before storing new samples.
+  //   sent_string: Placeholder to store new sentences. Old data will be
+  //                deleted automatically before storing new samples.
+  //
+  // Returns:
+  //   true if reading completed successfully, false otherwise (e.g. EOF).
+  static bool readTokens(
+      const Vocabulary & vocab,
+      std::istream * is,
+      std::vector<unsigned> * word_ids,
+      std::string * sent_string);
+
+  // Loads all samples in the tokenized corpus.
+  //
+  // Arguments:
+  //   filepath: Location of the corpus file.
+  //   vocab: Vocabulary object for the corpus language.
+  //   result: Placeholder to store new samples. Old data will be deleted
+  //           automatically before storing new samples.
+  //   string_result: Placeholder to store new sample sentences.
+  //                  Old data will be deleted automatically before storing
+  //                  new samples.
+  static void loadSingleSentences(
+      const std::string & filepath,
+      const Vocabulary & vocab,
+      std::vector<std::vector<unsigned>> * result,
+      std::vector<std::string> * string_result);
+
+  // Loads tokenized parallel corpus.
+  //
+  // Arguments:
+  //   src_filepath: Location of the source corpus file.
+  //   trg_filepath: Location of the target corpus file.
+  //   src_vocab: Vocabulary object for the source language.
+  //   trg_vocab: Vocabulary object for the target language.
+  //   max_length: Maximum number of words in a sentence. Samples which exceed
+  //               this value will be skipped.
+  //   max_length_ratio: Maximum ratio of lengths in source/target sentences.
+  //                     Samples which exceed this value will be skipped.
+  //   src_result: Placeholder to store new source samples. Old data will be
+  //               deleted automatically before storing new samples.
+  //   trg_result: Placeholder to store new target samples. Old data will be
+  //               deleted automatically before storing new samples.
+  //   src_string_result: Placeholder to store new source sample strings.
+  //                      Old data will be deleted automatically before
+  //                      storing new samples.
+  //   trg_string_result: Placeholder to store new target sample strings.
+  //                      Old data will be deleted automatically before
+  //                      storing new samples.
+  static void loadParallelSentences(
+      const std::string & src_filepath,
+      const std::string & trg_filepath,
+      const Vocabulary & src_vocab,
+      const Vocabulary & trg_vocab,
+      std::vector<std::vector<unsigned>> * src_result,
+      std::vector<std::vector<unsigned>> * trg_result,
+      std::vector<std::string> * src_string_result,
+      std::vector<std::string> * trg_string_result);
+
+  // Loads tokenized parallel corpus directly into TestSample objects.
+  //
+  // Arguments:
+  //   src_filepath: Location of the source corpus file.
+  //   trg_filepath: Location of the target corpus file.
+  //   src_vocab: Vocabulary object for the source language.
+  //   trg_vocab: Vocabulary object for the target language.
+  //   max_length: Maximum number of words in a sentence. Samples which exceed
+  //               this value will be skipped.
+  //   max_length_ratio: Maximum ratio of lengths in source/target sentences.
+  //                     Samples which exceed this value will be skipped.
+  //   result: Placeholder to store new source/target samples. Old data will be
+  //           deleted automatically before storing new samples.
+  static void loadParallelSentences(
+      const std::string & src_filepath,
+      const std::string & trg_filepath,
+      const Vocabulary & src_vocab,
+      const Vocabulary & trg_vocab,
+      std::vector<TestSample> * result);
+};
+
+}  // namespace nmtkit
+
+#endif  // NMTKIT_TEST_CORPUS_H_
diff --git a/src/include/nmtkit/test_sampler.h b/src/include/nmtkit/test_sampler.h
new file mode 100644
index 0000000..b932b12
--- /dev/null
+++ b/src/include/nmtkit/test_sampler.h
@@ -0,0 +1,43 @@
+#ifndef NMTKIT_TEST_SAMPLER_H_
+#define NMTKIT_TEST_SAMPLER_H_
+
+#include
+#include
+
+namespace nmtkit {
+
+class TestSampler : public Sampler {
+  TestSampler() = delete;
+  TestSampler(const TestSampler &) = delete;
+  TestSampler(TestSampler &&) = delete;
+  TestSampler & operator=(const TestSampler &) = delete;
+  TestSampler & operator=(TestSampler &&) = delete;
+
+public:
+  TestSampler(
+      const std::string & src_filepath,
+      const std::string & trg_filepath,
+      const Vocabulary & src_vocab,
+      const Vocabulary & trg_vocab,
+      unsigned batch_size);
+
+  ~TestSampler() override {}
+
+  void rewind() override;
+  std::vector<Sample> getSamples() override;
+  std::vector<TestSample> getTestSamples();
+  unsigned getNumSamples() override;
+  bool hasSamples() const override;
+
+private:
+  std::vector<std::string> src_samples_string_;
+  std::vector<std::string> trg_samples_string_;
+  std::vector<std::vector<unsigned>> src_samples_;
+  std::vector<std::vector<unsigned>> trg_samples_;
+  unsigned batch_size_;
+  unsigned current_;
+};
+
+}  // namespace nmtkit
+
+#endif  // NMTKIT_TEST_SAMPLER_H_
diff --git a/src/include/nmtkit/vocabulary.h b/src/include/nmtkit/vocabulary.h
index 2c7cdc4..2788169 100644
--- a/src/include/nmtkit/vocabulary.h
+++ b/src/include/nmtkit/vocabulary.h
@@ -55,6 +55,16 @@ class Vocabulary {
   //   List of word IDs that represent the given sentence.
   virtual std::vector<unsigned> convertToIDs(
       const std::string & sentence) const = 0;
+
+  // Converts a sentence into a list of words (tokens).
+  //
+  // Arguments:
+  //   sentence: A sentence string.
+  //
+  // Returns:
+  //   List of sentence words.
+  virtual std::vector<std::string> convertToTokens(
+      const std::string & sentence) const = 0;
 
   // Converts a list of word IDs into a sentence.
   //
@@ -65,6 +75,16 @@ class Vocabulary {
   //   Generated sentence string.
   virtual std::string convertToSentence(
       const std::vector<unsigned> & word_ids) const = 0;
+
+  // Converts a list of word IDs into a list of words (tokens).
+  //
+  // Arguments:
+  //   word_ids: A list of word IDs.
+  //
+  // Returns:
+  //   Generated list of sentence words.
+  virtual std::vector<std::string> convertToTokenizedSentence(
+      const std::vector<unsigned> & word_ids) const = 0;
 
   // Retrieves the size of the vocabulary.
 //
diff --git a/src/include/nmtkit/word_vocabulary.h b/src/include/nmtkit/word_vocabulary.h
index e4ca18e..65ec884 100644
--- a/src/include/nmtkit/word_vocabulary.h
+++ b/src/include/nmtkit/word_vocabulary.h
@@ -41,8 +41,12 @@ class WordVocabulary : public Vocabulary {
   unsigned getFrequency(const unsigned id) const override;
   std::vector<unsigned> convertToIDs(
       const std::string & sentence) const override;
+  std::vector<std::string> convertToTokens(
+      const std::string & sentence) const override;
   std::string convertToSentence(
       const std::vector<unsigned> & word_ids) const override;
+  std::vector<std::string> convertToTokenizedSentence(
+      const std::vector<unsigned> & word_ids) const override;
   unsigned size() const override;
 
 private:
diff --git a/src/lib/Makefile.am b/src/lib/Makefile.am
index 79f3192..2525857 100644
--- a/src/lib/Makefile.am
+++ b/src/lib/Makefile.am
@@ -34,6 +34,8 @@ libnmtkit_la_SOURCES = \
 	single_text_formatter.cc \
 	softmax_predictor.cc \
 	sorted_random_sampler.cc \
+	test_corpus.cc \
+	test_sampler.cc \
 	vocabulary.cc \
 	word_vocabulary.cc
 
diff --git a/src/lib/character_vocabulary.cc b/src/lib/character_vocabulary.cc
index d91d5ce..c5d5a03 100644
--- a/src/lib/character_vocabulary.cc
+++ b/src/lib/character_vocabulary.cc
@@ -122,6 +122,15 @@ vector<unsigned> CharacterVocabulary::convertToIDs(
   return ids;
 }
 
+vector<string> CharacterVocabulary::convertToTokens(
+    const string & sentence) const {
+  vector<string> tokens;
+  for (const string & letter : ::convertToLetters(sentence)) {
+    tokens.emplace_back(getWord(getID(letter)));
+  }
+  return tokens;
+}
+
 string CharacterVocabulary::convertToSentence(
     const vector<unsigned> & word_ids) const {
   vector<string> letters;
@@ -131,6 +140,15 @@ string CharacterVocabulary::convertToSentence(
   return boost::join(letters, "");
 }
 
+vector<string> CharacterVocabulary::convertToTokenizedSentence(
+    const vector<unsigned> & word_ids) const {
+  vector<string> letters;
+  for (const unsigned word_id : word_ids) {
+    letters.emplace_back(getWord(word_id));
+  }
+  return letters;
+}
+
 unsigned CharacterVocabulary::size() const {
   return itos_.size();
 }
diff --git a/src/lib/test_corpus.cc b/src/lib/test_corpus.cc
new file mode 100644
index 0000000..e441858
--- /dev/null
+++ b/src/lib/test_corpus.cc
@@ -0,0 +1,108 @@
+#include "config.h"
+
+#include
+
+#include
+#include
+#include
+#include
+
+using namespace std;
+
+namespace nmtkit {
+
+bool TestCorpus::readLine(istream * is, string * line) {
+  if (!getline(*is, *line)) return false;
+  boost::trim(*line);
+  return true;
+}
+
+bool TestCorpus::readTokens(
+    const Vocabulary & vocab,
+    istream * is,
+    vector<unsigned> * word_ids,
+    string * sent_string) {
+  string line;
+  if (!readLine(is, &line)) return false;
+  *word_ids = vocab.convertToIDs(line);
+  *sent_string = line;
+  return true;
+}
+
+void TestCorpus::loadSingleSentences(
+    const string & filepath,
+    const Vocabulary & vocab,
+    vector<vector<unsigned>> * result,
+    vector<string> * string_result) {
+  ifstream ifs(filepath);
+  NMTKIT_CHECK(
+      ifs.is_open(), "Could not open the corpus file to load: " + filepath);
+
+  result->clear();
+  string_result->clear();
+  vector<unsigned> word_ids;
+  string sent_string;
+  while (readTokens(vocab, &ifs, &word_ids, &sent_string)) {
+    result->emplace_back(std::move(word_ids));
+    string_result->emplace_back(std::move(sent_string));
+  }
+}
+
+void TestCorpus::loadParallelSentences(
+    const string & src_filepath,
+    const string & trg_filepath,
+    const Vocabulary & src_vocab,
+    const Vocabulary & trg_vocab,
+    vector<vector<unsigned>> * src_result,
+    vector<vector<unsigned>> * trg_result,
+    vector<string> * src_string_result,
+    vector<string> * trg_string_result) {
+  ifstream src_ifs(src_filepath), trg_ifs(trg_filepath);
+  NMTKIT_CHECK(
+      src_ifs.is_open(),
+ "Could not open the source corpus file to load: " + src_filepath); + NMTKIT_CHECK( + trg_ifs.is_open(), + "Could not open the target corpus file to load: " + trg_filepath); + + src_result->clear(); + trg_result->clear(); + src_string_result->clear(); + trg_string_result->clear(); + vector src_ids, trg_ids; + string src_string, trg_string; + while ( + readTokens(src_vocab, &src_ifs, &src_ids, &src_string) and + readTokens(trg_vocab, &trg_ifs, &trg_ids, &trg_string)) { + src_result->emplace_back(std::move(src_ids)); + trg_result->emplace_back(std::move(trg_ids)); + src_string_result->emplace_back(std::move(src_string)); + trg_string_result->emplace_back(std::move(trg_string)); + } +} + +void TestCorpus::loadParallelSentences( + const string & src_filepath, + const string & trg_filepath, + const Vocabulary & src_vocab, + const Vocabulary & trg_vocab, + vector * result) { + ifstream src_ifs(src_filepath), trg_ifs(trg_filepath); + NMTKIT_CHECK( + src_ifs.is_open(), + "Could not open the source corpus file to load: " + src_filepath); + NMTKIT_CHECK( + trg_ifs.is_open(), + "Could not open the target corpus file to load: " + trg_filepath); + + result->clear(); + vector src_ids, trg_ids; + string src_string, trg_string; + while ( + readTokens(src_vocab, &src_ifs, &src_ids, &src_string) and + readTokens(trg_vocab, &trg_ifs, &trg_ids, &trg_string)) { + result->emplace_back(TestSample {std::move(src_ids), std::move(trg_ids), + std::move(src_string), std::move(trg_string)}); + } +} + +} // namespace nmtkit diff --git a/src/lib/test_sampler.cc b/src/lib/test_sampler.cc new file mode 100644 index 0000000..471e13f --- /dev/null +++ b/src/lib/test_sampler.cc @@ -0,0 +1,68 @@ +#include "config.h" + +#include + +#include +#include + +using namespace std; + +namespace nmtkit { + +TestSampler::TestSampler( + const string & src_filepath, + const string & trg_filepath, + const Vocabulary & src_vocab, + const Vocabulary & trg_vocab, + unsigned batch_size) +: batch_size_(batch_size) { + TestCorpus::loadParallelSentences( + src_filepath, trg_filepath, + src_vocab, trg_vocab, + &src_samples_, &trg_samples_, &src_samples_string_, &trg_samples_string_); + NMTKIT_CHECK(src_samples_.size() > 0, "Corpus files are empty."); + NMTKIT_CHECK(batch_size_ > 0, "batch_size should be greater than 0."); + + rewind(); +} + +void TestSampler::rewind() { + current_ = 0; +} + +vector TestSampler::getSamples() { + NMTKIT_CHECK(hasSamples(), "No more samples."); + + vector result; + for (unsigned i = 0; i < batch_size_ && hasSamples(); ++i) { + result.emplace_back( + Sample {src_samples_[current_], trg_samples_[current_]}); + ++current_; + } + + return result; +} + +vector TestSampler::getTestSamples() { + NMTKIT_CHECK(hasSamples(), "No more samples."); + + vector result; + for (unsigned i = 0; i < batch_size_ && hasSamples(); ++i) { + result.emplace_back( + TestSample {src_samples_[current_], trg_samples_[current_], + src_samples_string_[current_], trg_samples_string_[current_]}); + ++current_; + } + + return result; +} + +unsigned TestSampler::getNumSamples() { + return src_samples_.size(); +} + +bool TestSampler::hasSamples() const { + return current_ < src_samples_.size(); +} + +} // namespace nmtkit diff --git a/src/lib/word_vocabulary.cc b/src/lib/word_vocabulary.cc index e57230f..e7964ba 100644 --- a/src/lib/word_vocabulary.cc +++ b/src/lib/word_vocabulary.cc @@ -89,6 +89,17 @@ vector WordVocabulary::convertToIDs(const string & sentence) const { return ids; } +vector WordVocabulary::convertToTokens(const string & sentence) 
+    const {
+  vector<string> words;
+  boost::split(
+      words, sentence, boost::is_space(), boost::algorithm::token_compress_on);
+  vector<string> tokens;
+  for (const string & word : words) {
+    tokens.emplace_back(getWord(getID(word)));
+  }
+  return tokens;
+}
+
 string WordVocabulary::convertToSentence(
     const vector<unsigned> & word_ids) const {
   vector<string> words;
@@ -98,6 +109,15 @@ string WordVocabulary::convertToSentence(
   return boost::join(words, " ");
 }
 
+vector<string> WordVocabulary::convertToTokenizedSentence(
+    const vector<unsigned> & word_ids) const {
+  vector<string> words;
+  for (const unsigned word_id : word_ids) {
+    words.emplace_back(getWord(word_id));
+  }
+  return words;
+}
+
 unsigned WordVocabulary::size() const {
   return itos_.size();
 }
diff --git a/src/test/Makefile.am b/src/test/Makefile.am
index 706f90f..e437c25 100644
--- a/src/test/Makefile.am
+++ b/src/test/Makefile.am
@@ -18,8 +18,10 @@ check_PROGRAMS = \
 	batch_converter_test \
 	character_vocabulary_test \
 	corpus_test \
+	test_corpus_test \
 	inference_graph_test \
 	monotone_sampler_test \
+	test_sampler_test \
 	random_test \
 	sorted_random_sampler_test \
 	word_vocabulary_test
@@ -29,8 +31,10 @@ TESTS = \
 	batch_converter_test \
 	character_vocabulary_test \
 	corpus_test \
+	test_corpus_test \
 	inference_graph_test \
 	monotone_sampler_test \
+	test_sampler_test \
 	random_test \
 	sorted_random_sampler_test \
 	word_vocabulary_test
@@ -39,8 +43,10 @@ array_test_SOURCES = array_test.cc
 batch_converter_test_SOURCES = batch_converter_test.cc
 character_vocabulary_test_SOURCES = character_vocabulary_test.cc
 corpus_test_SOURCES = corpus_test.cc
+test_corpus_test_SOURCES = test_corpus_test.cc
 inference_graph_test_SOURCES = inference_graph_test.cc
 monotone_sampler_test_SOURCES = monotone_sampler_test.cc
+test_sampler_test_SOURCES = test_sampler_test.cc
 random_test_SOURCES = random_test.cc
 sorted_random_sampler_test_SOURCES = sorted_random_sampler_test.cc
 word_vocabulary_test_SOURCES = word_vocabulary_test.cc
diff --git a/src/test/character_vocabulary_test.cc b/src/test/character_vocabulary_test.cc
index 8fee28b..d0c8315 100644
--- a/src/test/character_vocabulary_test.cc
+++ b/src/test/character_vocabulary_test.cc
@@ -120,6 +120,39 @@ BOOST_AUTO_TEST_CASE(CheckConvertingToIDs) {
   }
 }
 
+BOOST_AUTO_TEST_CASE(CheckConvertingToTokens) {
+  nmtkit::CharacterVocabulary vocab;
+  ::loadArchive("data/small.en.char.vocab", &vocab);
+  const vector<string> sentences {
+    "anything that can go wrong , will go wrong .",
+    "there is always light behind the clouds .",
+    "and yet it moves .",
+    "これ は 日本 語 の テスト 文 で す 。",
+  };
+  const vector<vector<string>> expected {
+    {"a", "n", "y", "t", "h", "i", "n", "g", " ", "t", "h", "a",
+     "t", " ", "c", "a", "n", " ", "g", "o", " ", "w", "r", "o",
+     "n", "g", " ", ",", " ", "w", "i", "l", "l", " ", "g", "o",
+     " ", "w", "r", "o", "n", "g", " ", "."},
+    {"t", "h", "e", "r", "e", " ", "i", "s", " ", "a", "l", "w",
+     "a", "y", "s", " ", "l", "i", "g", "h", "t", " ", "b", "e",
+     "h", "i", "n", "d", " ", "t", "h", "e", " ", "c", "l", "o",
+     "u", "d", "s", " ", "."},
+    {"a", "n", "d", " ", "y", "e", "t", " ", "i", "t", " ", "m",
+     "o", "v", "e", "s", " ", "."},
+    {"<unk>", "<unk>", " ", "<unk>", " ", "<unk>", "<unk>", " ",
+     "<unk>", " ", "<unk>", " ", "<unk>", "<unk>", "<unk>", " ",
+     "<unk>", " ", "<unk>", " ", "<unk>", " ", "<unk>"}
+  };
+
+  for (unsigned i = 0; i < sentences.size(); ++i) {
+    vector<string> observed = vocab.convertToTokens(sentences[i]);
+    BOOST_CHECK_EQUAL_COLLECTIONS(
+        expected[i].begin(), expected[i].end(),
+        observed.begin(), observed.end());
+  }
+}
+
 BOOST_AUTO_TEST_CASE(CheckConvertingToSentence) {
   nmtkit::CharacterVocabulary vocab;
::loadArchive("data/small.en.char.vocab", &vocab); @@ -149,4 +182,43 @@ BOOST_AUTO_TEST_CASE(CheckConvertingToSentence) { } } +BOOST_AUTO_TEST_CASE(CheckConvertingToTokenizedSentence) { + nmtkit::CharacterVocabulary vocab; + ::loadArchive("data/small.en.char.vocab", &vocab); + const vector> word_ids { + { 7, 11, 17, 5, 9, 8, 11, 21, 3, 5, 9, 7, 5, 3, 20, 7, + 11, 3, 21, 6, 3, 18, 12, 6, 11, 21, 3, 29, 3, 18, 8, 13, + 13, 3, 21, 6, 3, 18, 12, 6, 11, 21, 3, 14}, + { 5, 9, 4, 12, 4, 3, 8, 10, 3, 7, 13, 18, 7, 17, 10, 3, + 13, 8, 21, 9, 5, 3, 24, 4, 9, 8, 11, 15, 3, 5, 9, 4, + 3, 20, 13, 6, 16, 15, 10, 3, 14}, + { 7, 11, 15, 3, 17, 4, 5, 3, 8, 5, 3, 19, 6, 26, 4, 10, + 3, 14}, + { 0, 0, 3, 0, 3, 0, 0, 3, 0, 3, 0, 3, 0, 0, 0, 3, + 0, 3, 0, 3, 0, 3, 0}, + }; + const vector> expected { + {"a", "n", "y", "t", "h", "i", "n", "g", " ", "t", "h", "a", + "t", " ", "c", "a", "n", " ", "g", "o", " ", "w", "r", "o", + "n", "g", " ", ",", " ", "w", "i", "l", "l", " ", "g", "o", + " ", "w", "r", "o", "n", "g", " ", "."}, + {"t", "h", "e", "r", "e", " ", "i", "s", " ", "a", "l", "w", + "a", "y", "s", " ", "l", "i", "g", "h", "t", " ", "b", "e", + "h", "i", "n", "d", " ", "t", "h", "e", " ", "c", "l", "o", + "u", "d", "s", " ", "."}, + {"a", "n", "d", " ", "y", "e", "t", " ", "i", "t", " ", "m", + "o", "v", "e", "s", " ", "."}, + {"", "", " ", "", " ", "", "", " ", + "", " ", "", " ", "", "", "", " ", + "", " ", "", " ", "", " ", ""} + }; + + for (unsigned i = 0; i < word_ids.size(); ++i) { + vector observed = vocab.convertToTokenizedSentence(word_ids[i]); + BOOST_CHECK_EQUAL_COLLECTIONS( + expected[i].begin(), expected[i].end(), + observed.begin(), observed.end()); + } +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/test_corpus_test.cc b/src/test/test_corpus_test.cc new file mode 100644 index 0000000..bdd9c65 --- /dev/null +++ b/src/test/test_corpus_test.cc @@ -0,0 +1,175 @@ +#include "config.h" + +#define BOOST_TEST_MAIN +#include + +#include +#include +#include +#include +#include +#include + +using namespace std; + +namespace { + +const string src_tok_filename = "data/small.en.tok"; +const string trg_tok_filename = "data/small.ja.tok"; +const string src_vocab_filename = "data/small.en.vocab"; +const string trg_vocab_filename = "data/small.ja.vocab"; + +template +void loadArchive(const string & filepath, T * obj) { + ifstream ifs(filepath); + boost::archive::binary_iarchive iar(ifs); + iar >> *obj; +} + +} // namespace + +BOOST_AUTO_TEST_SUITE(TestCorpusTest) + +BOOST_AUTO_TEST_CASE(CheckLoadingSingle) { + const unsigned expected_num_sents = 500; + const unsigned expected_num_words = 3871; + const vector> expected_words { + {6, 41, 17, 90, 106, 37, 0, 364, 3}, + {159, 0, 13, 130, 0, 101, 332, 3}, + {6, 75, 12, 4, 145, 0, 3}, + {0, 219, 228, 3}, + }; + const vector expected_strings { + "i can 't tell who will arrive first .", + "many animals have been destroyed by men .", + "i 'm in the tennis club .", + "emi looks happy ." 
+  };
+
+  nmtkit::WordVocabulary vocab;
+  ::loadArchive(::src_vocab_filename, &vocab);
+  vector<vector<unsigned>> result;
+  vector<string> result_string;
+  nmtkit::TestCorpus::loadSingleSentences(
+      ::src_tok_filename, vocab, &result, &result_string);
+
+  BOOST_CHECK_EQUAL(expected_num_sents, result.size());
+
+  unsigned num_words = 0;
+  for (const auto & sent : result) {
+    num_words += sent.size();
+  }
+  BOOST_CHECK_EQUAL(expected_num_words, num_words);
+
+  for (unsigned i = 0; i < expected_words.size(); ++i) {
+    BOOST_CHECK_EQUAL_COLLECTIONS(
+        expected_words[i].begin(), expected_words[i].end(),
+        result[i].begin(), result[i].end());
+    BOOST_CHECK_EQUAL(expected_strings[i], result_string[i]);
+  }
+}
+
+BOOST_AUTO_TEST_CASE(CheckLoadingParallel) {
+  const vector<vector<unsigned>> expected_src_words {
+    {6, 41, 17, 90, 106, 37, 0, 364, 3},
+    {159, 0, 13, 130, 0, 101, 332, 3},
+    {6, 75, 12, 4, 145, 0, 3},
+    {0, 219, 228, 3},
+  };
+  const vector<string> expected_src_strings {
+    "i can 't tell who will arrive first .",
+    "many animals have been destroyed by men .",
+    "i 'm in the tennis club .",
+    "emi looks happy ."
+  };
+
+  const vector<vector<unsigned>> expected_trg_words {
+    {86, 13, 202, 6, 138, 30, 22, 18, 6, 4, 310, 38, 20, 46, 29, 3},
+    {298, 9, 0, 13, 325, 6, 33, 15, 10, 0, 69, 88, 8, 3},
+    {18, 4, 158, 416, 12, 19, 3},
+    {0, 4, 0, 164, 6, 242, 20, 19, 3},
+  };
+  const vector<string> expected_trg_strings {
+    "誰 が 一番 に 着 く か 私 に は 分か り ま せ ん 。",
+    "多く の 動物 が 人間 に よ っ て 滅ぼ さ れ た 。",
+    "私 は テニス 部員 で す 。",
+    "エミ は 幸せ そう に 見え ま す 。"
+  };
+
+  nmtkit::WordVocabulary src_vocab, trg_vocab;
+  ::loadArchive(::src_vocab_filename, &src_vocab);
+  ::loadArchive(::trg_vocab_filename, &trg_vocab);
+  vector<vector<unsigned>> src_result, trg_result;
+  vector<string> src_string_result;
+  vector<string> trg_string_result;
+
+  nmtkit::TestCorpus::loadParallelSentences(
+      ::src_tok_filename, ::trg_tok_filename,
+      src_vocab, trg_vocab,
+      &src_result, &trg_result,
+      &src_string_result, &trg_string_result);
+
+  BOOST_CHECK_EQUAL(src_result.size(), trg_result.size());
+  BOOST_CHECK_EQUAL(src_string_result.size(), trg_string_result.size());
+
+  for (unsigned i = 0; i < expected_src_words.size(); ++i) {
+    BOOST_CHECK_EQUAL_COLLECTIONS(
+        expected_src_words[i].begin(), expected_src_words[i].end(),
+        src_result[i].begin(), src_result[i].end());
+    BOOST_CHECK_EQUAL(expected_src_strings[i], src_string_result[i]);
+    BOOST_CHECK_EQUAL_COLLECTIONS(
+        expected_trg_words[i].begin(), expected_trg_words[i].end(),
+        trg_result[i].begin(), trg_result[i].end());
+    BOOST_CHECK_EQUAL(expected_trg_strings[i], trg_string_result[i]);
+  }
+}
+
+BOOST_AUTO_TEST_CASE(CheckLoadingParallel2) {
+  const vector<vector<unsigned>> expected_src_words {
+    {6, 41, 17, 90, 106, 37, 0, 364, 3},
+    {159, 0, 13, 130, 0, 101, 332, 3},
+    {6, 75, 12, 4, 145, 0, 3},
+    {0, 219, 228, 3},
+  };
+  const vector<string> expected_src_strings {
+    "i can 't tell who will arrive first .",
+    "many animals have been destroyed by men .",
+    "i 'm in the tennis club .",
+    "emi looks happy ."
+  };
+
+  const vector<vector<unsigned>> expected_trg_words {
+    {86, 13, 202, 6, 138, 30, 22, 18, 6, 4, 310, 38, 20, 46, 29, 3},
+    {298, 9, 0, 13, 325, 6, 33, 15, 10, 0, 69, 88, 8, 3},
+    {18, 4, 158, 416, 12, 19, 3},
+    {0, 4, 0, 164, 6, 242, 20, 19, 3},
+  };
+  const vector<string> expected_trg_strings {
+    "誰 が 一番 に 着 く か 私 に は 分か り ま せ ん 。",
+    "多く の 動物 が 人間 に よ っ て 滅ぼ さ れ た 。",
+    "私 は テニス 部員 で す 。",
+    "エミ は 幸せ そう に 見え ま す 。"
+  };
+
+  nmtkit::WordVocabulary src_vocab, trg_vocab;
+  ::loadArchive(::src_vocab_filename, &src_vocab);
+  ::loadArchive(::trg_vocab_filename, &trg_vocab);
+  vector<nmtkit::TestSample> result;
+
+  nmtkit::TestCorpus::loadParallelSentences(
+      ::src_tok_filename, ::trg_tok_filename,
+      src_vocab, trg_vocab,
+      &result);
+
+  for (unsigned i = 0; i < expected_src_words.size(); ++i) {
+    BOOST_CHECK_EQUAL_COLLECTIONS(
+        expected_src_words[i].begin(), expected_src_words[i].end(),
+        result[i].source.begin(), result[i].source.end());
+    BOOST_CHECK_EQUAL(expected_src_strings[i], result[i].source_string);
+    BOOST_CHECK_EQUAL_COLLECTIONS(
+        expected_trg_words[i].begin(), expected_trg_words[i].end(),
+        result[i].target.begin(), result[i].target.end());
+    BOOST_CHECK_EQUAL(expected_trg_strings[i], result[i].target_string);
+  }
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/test/test_sampler_test.cc b/src/test/test_sampler_test.cc
new file mode 100644
index 0000000..0cebdca
--- /dev/null
+++ b/src/test/test_sampler_test.cc
@@ -0,0 +1,148 @@
+#include "config.h"
+
+#define BOOST_TEST_MAIN
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+using namespace std;
+
+namespace {
+
+const string src_tok_filename = "data/small.en.tok";
+const string trg_tok_filename = "data/small.ja.tok";
+const string src_vocab_filename = "data/small.en.vocab";
+const string trg_vocab_filename = "data/small.ja.vocab";
+const unsigned corpus_size = 500;  // #samples in the sample corpus
+const unsigned batch_size = 64;
+const unsigned tail_size = corpus_size % batch_size;
+
+const vector<vector<unsigned>> expected_src {
+  {6, 41, 17, 90, 106, 37, 0, 364, 3},
+  {159, 0, 13, 130, 0, 101, 332, 3},
+  {6, 75, 12, 4, 145, 0, 3},
+  {0, 219, 228, 3},
+};
+const vector<string> expected_src_strings {
+  "i can 't tell who will arrive first .",
+  "many animals have been destroyed by men .",
+  "i 'm in the tennis club .",
+  "emi looks happy ."
+};
+const vector<vector<unsigned>> expected_trg {
+  {86, 13, 202, 6, 138, 30, 22, 18, 6, 4, 310, 38, 20, 46, 29, 3},
+  {298, 9, 0, 13, 325, 6, 33, 15, 10, 0, 69, 88, 8, 3},
+  {18, 4, 158, 416, 12, 19, 3},
+  {0, 4, 0, 164, 6, 242, 20, 19, 3},
+};
+const vector<string> expected_trg_strings {
+  "誰 が 一番 に 着 く か 私 に は 分か り ま せ ん 。",
+  "多く の 動物 が 人間 に よ っ て 滅ぼ さ れ た 。",
+  "私 は テニス 部員 で す 。",
+  "エミ は 幸せ そう に 見え ま す 。"
+};
+
+template <class T>
+void loadArchive(const string & filepath, T * obj) {
+  ifstream ifs(filepath);
+  boost::archive::binary_iarchive iar(ifs);
+  iar >> *obj;
+}
+
+}  // namespace
+
+BOOST_AUTO_TEST_SUITE(TestSamplerTest)
+
+BOOST_AUTO_TEST_CASE(CheckIteration) {
+  nmtkit::WordVocabulary src_vocab, trg_vocab;
+  ::loadArchive(::src_vocab_filename, &src_vocab);
+  ::loadArchive(::trg_vocab_filename, &trg_vocab);
+  nmtkit::TestSampler sampler(
+      ::src_tok_filename, ::trg_tok_filename,
+      src_vocab, trg_vocab, ::batch_size);
+
+  BOOST_CHECK(sampler.hasSamples());
+
+  // Checks head samples.
+  {
+    vector<nmtkit::Sample> samples = sampler.getSamples();
+    BOOST_CHECK_EQUAL(::batch_size, samples.size());
+    for (unsigned i = 0; i < ::expected_src.size(); ++i) {
+      BOOST_CHECK_EQUAL_COLLECTIONS(
+          ::expected_src[i].begin(), ::expected_src[i].end(),
+          samples[i].source.begin(), samples[i].source.end());
+      BOOST_CHECK_EQUAL_COLLECTIONS(
+          ::expected_trg[i].begin(), ::expected_trg[i].end(),
+          samples[i].target.begin(), samples[i].target.end());
+    }
+  }
+
+  // Checks rewinding.
+  sampler.rewind();
+  BOOST_CHECK(sampler.hasSamples());
+
+  // Re-checks head samples.
+  {
+    vector<nmtkit::Sample> samples = sampler.getSamples();
+    BOOST_CHECK_EQUAL(::batch_size, samples.size());
+    for (unsigned i = 0; i < ::expected_src.size(); ++i) {
+      BOOST_CHECK_EQUAL_COLLECTIONS(
+          ::expected_src[i].begin(), ::expected_src[i].end(),
+          samples[i].source.begin(), samples[i].source.end());
+      BOOST_CHECK_EQUAL_COLLECTIONS(
+          ::expected_trg[i].begin(), ::expected_trg[i].end(),
+          samples[i].target.begin(), samples[i].target.end());
+    }
+  }
+}
+
+BOOST_AUTO_TEST_CASE(CheckIteration2) {
+  nmtkit::WordVocabulary src_vocab, trg_vocab;
+  ::loadArchive(::src_vocab_filename, &src_vocab);
+  ::loadArchive(::trg_vocab_filename, &trg_vocab);
+  nmtkit::TestSampler sampler(
+      ::src_tok_filename, ::trg_tok_filename,
+      src_vocab, trg_vocab, ::batch_size);
+
+  BOOST_CHECK(sampler.hasSamples());
+
+  // Checks head samples.
+  {
+    vector<nmtkit::TestSample> samples = sampler.getTestSamples();
+    BOOST_CHECK_EQUAL(::batch_size, samples.size());
+    for (unsigned i = 0; i < ::expected_src.size(); ++i) {
+      BOOST_CHECK_EQUAL_COLLECTIONS(
+          ::expected_src[i].begin(), ::expected_src[i].end(),
+          samples[i].source.begin(), samples[i].source.end());
+      BOOST_CHECK_EQUAL(expected_src_strings[i], samples[i].source_string);
+      BOOST_CHECK_EQUAL_COLLECTIONS(
+          ::expected_trg[i].begin(), ::expected_trg[i].end(),
+          samples[i].target.begin(), samples[i].target.end());
+      BOOST_CHECK_EQUAL(expected_trg_strings[i], samples[i].target_string);
+    }
+  }
+
+  // Checks rewinding.
+  sampler.rewind();
+  BOOST_CHECK(sampler.hasSamples());
+
+  // Re-checks head samples.
+  {
+    vector<nmtkit::Sample> samples = sampler.getSamples();
+    BOOST_CHECK_EQUAL(::batch_size, samples.size());
+    for (unsigned i = 0; i < ::expected_src.size(); ++i) {
+      BOOST_CHECK_EQUAL_COLLECTIONS(
+          ::expected_src[i].begin(), ::expected_src[i].end(),
+          samples[i].source.begin(), samples[i].source.end());
+      BOOST_CHECK_EQUAL_COLLECTIONS(
+          ::expected_trg[i].begin(), ::expected_trg[i].end(),
+          samples[i].target.begin(), samples[i].target.end());
+    }
+  }
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/test/word_vocabulary_test.cc b/src/test/word_vocabulary_test.cc
index facf548..ee8b2e7 100644
--- a/src/test/word_vocabulary_test.cc
+++ b/src/test/word_vocabulary_test.cc
@@ -114,6 +114,32 @@ BOOST_AUTO_TEST_CASE(CheckConvertingToIDs) {
   }
 }
 
+BOOST_AUTO_TEST_CASE(CheckConvertingToTokens) {
+  nmtkit::WordVocabulary vocab;
+  ::loadArchive("data/small.en.vocab", &vocab);
+  const vector<string> sentences {
+    "anything that can go wrong , will go wrong .",
+    "there is always light behind the clouds .",
+    "and yet it moves .",
+    "これ は 日本 語 の テスト 文 で す 。",
+  };
+  const vector<vector<string>> expected {
+    {"<unk>", "that", "can", "go", "wrong", ",", "will", "go",
+     "wrong", "."},
+    {"there", "is", "always", "<unk>", "behind", "the", "<unk>", "."},
+    {"and", "yet", "it", "<unk>", "."},
+    {"<unk>", "<unk>", "<unk>", "<unk>", "<unk>", "<unk>", "<unk>",
+     "<unk>", "<unk>", "<unk>"}
+  };
+
+  for (unsigned i = 0; i < sentences.size(); ++i) {
+    vector<string> observed = vocab.convertToTokens(sentences[i]);
+    BOOST_CHECK_EQUAL_COLLECTIONS(
+        expected[i].begin(), expected[i].end(),
+        observed.begin(), observed.end());
+  }
+}
+
 BOOST_AUTO_TEST_CASE(CheckConvertingToSentence) {
   nmtkit::WordVocabulary vocab;
   ::loadArchive("data/small.en.vocab", &vocab);
@@ -136,4 +162,30 @@ BOOST_AUTO_TEST_CASE(CheckConvertingToSentence) {
   }
 }
 
+BOOST_AUTO_TEST_CASE(CheckConvertingToTokenizedSentence) {
+  nmtkit::WordVocabulary vocab;
+  ::loadArchive("data/small.en.vocab", &vocab);
+  const vector<vector<unsigned>> word_ids {
+    {0, 20, 41, 45, 134, 31, 37, 45, 134, 3},
+    {39, 9, 85, 0, 400, 4, 0, 3},
+    {56, 183, 16, 0, 3},
+    {0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+  };
+  const vector<vector<string>> expected {
+    {"<unk>", "that", "can", "go", "wrong", ",", "will", "go",
+     "wrong", "."},
+    {"there", "is", "always", "<unk>", "behind", "the", "<unk>", "."},
+    {"and", "yet", "it", "<unk>", "."},
+    {"<unk>", "<unk>", "<unk>", "<unk>", "<unk>", "<unk>", "<unk>",
+     "<unk>", "<unk>", "<unk>"}
+  };
+
+  for (unsigned i = 0; i < word_ids.size(); ++i) {
+    vector<string> observed = vocab.convertToTokenizedSentence(word_ids[i]);
+    BOOST_CHECK_EQUAL_COLLECTIONS(
+        expected[i].begin(), expected[i].end(),
+        observed.begin(), observed.end());
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END()

From ec3e4a6dacd01f11185a08bda2691f20820f3c60 Mon Sep 17 00:00:00 2001
From: Makoto Morishita
Date: Sun, 8 Jan 2017 10:06:41 +0900
Subject: [PATCH 2/3] fix comment.

---
 src/include/nmtkit/test_corpus.h | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/include/nmtkit/test_corpus.h b/src/include/nmtkit/test_corpus.h
index b145930..ff39129 100644
--- a/src/include/nmtkit/test_corpus.h
+++ b/src/include/nmtkit/test_corpus.h
@@ -101,10 +101,6 @@ class TestCorpus : public Corpus {
   //   trg_filepath: Location of the target corpus file.
   //   src_vocab: Vocabulary object for the source language.
   //   trg_vocab: Vocabulary object for the target language.
-  //   max_length: Maximum number of words in a sentence. Samples which exceed
-  //               this value will be skipped.
-  //   max_length_ratio: Maximum ratio of lengths in source/target sentences.
-  //                     Samples which exceed this value will be skipped.
   //   result: Placeholder to store new source/target samples. Old data will be
   //           deleted automatically before storing new samples.
   static void loadParallelSentences(

From 031d7cf60f395a550ea7052dd73162772e35bb69 Mon Sep 17 00:00:00 2001
From: Makoto Morishita
Date: Sun, 8 Jan 2017 12:58:16 +0900
Subject: [PATCH 3/3] fix Makefile.am

---
 src/include/Makefile.am | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/include/Makefile.am b/src/include/Makefile.am
index 1088ae3..5237f7e 100644
--- a/src/include/Makefile.am
+++ b/src/include/Makefile.am
@@ -31,5 +31,7 @@ nobase_include_HEADERS = \
 	nmtkit/single_text_formatter.h \
 	nmtkit/softmax_predictor.h \
 	nmtkit/sorted_random_sampler.h \
+	nmtkit/test_corpus.h \
+	nmtkit/test_sampler.h \
 	nmtkit/vocabulary.h \
 	nmtkit/word_vocabulary.h
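
Why the Dictionary round-trip in evaluateBLEU() matters: before this series,
BLEU statistics were collected directly over NMTKit-internal word IDs
(MTEval::Sample {hyp_ids, {samples[0].target}}), so every out-of-vocabulary
word in both hypothesis and reference collapsed to the same <unk> ID and was
scored as an n-gram match, which could inflate the dev/test BLEU reported
during training. The patch instead detokenizes the hypothesis with
Vocabulary::convertToSentence() and maps both surface strings through one
shared MTEval::Dictionary, so matching happens on surface tokens. TestSampler
also drops the max_length/max_length_ratio filtering that MonotoneSampler
applied, so the score covers the full dev/test corpus.

Below is a minimal standalone sketch of the fixed scoring path. It reuses only
the MTEval calls that appear in the patch (Dictionary::getSentence,
EvaluatorFactory::create, and the map/integrate flow); the mteval header paths
and the two example sentences are assumptions for illustration, not part of
the patch:

    #include <mteval/Dictionary.h>
    #include <mteval/EvaluatorFactory.h>

    #include <iostream>
    #include <string>

    int main() {
      // One shared dictionary guarantees that identical surface tokens in the
      // hypothesis and the reference receive identical integer IDs, unlike the
      // old code, where two different OOV words shared the single <unk> ID.
      MTEval::Dictionary dict;
      const std::string hyp = "many animals were destroyed by men .";
      const std::string ref = "many animals have been destroyed by men .";
      const MTEval::Sentence sent_hyp = dict.getSentence(hyp);
      const MTEval::Sentence sent_ref = dict.getSentence(ref);

      // Same accumulate-then-integrate flow as evaluateBLEU() in train.cc.
      const auto evaluator = MTEval::EvaluatorFactory::create("BLEU");
      MTEval::Statistics stats;
      MTEval::Sample sample {sent_hyp, {sent_ref}};
      stats += evaluator->map(sample);
      std::cout << "BLEU: " << evaluator->integrate(stats) << std::endl;
      return 0;
    }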