From 41e568c61c833e0a99519aebfb04a223116fc6c7 Mon Sep 17 00:00:00 2001 From: oceanfish <807544076@qq.com> Date: Sun, 7 Apr 2024 16:56:39 +0800 Subject: [PATCH 1/2] update for save/load --- README.md | 36 ++++++++++++++++++++++++++++++++++++ src/run_classifier_word.py | 33 ++++++++++++++++++++++++++++++--- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4479db7..f5806ca 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,42 @@ # Bert-Pytorch-Chinese-TextClassification Pytorch Bert Finetune in Chinese Text Classification +--- + +### Update + +* Add save / load functions to save and load the model that you trained + +When training finishes, the whole model is saved as `news_model.bin` in the current working directory (`/src` when the script is run from there) + +```python +# add optional arg --from_trained (bool, default False) +# if True, load the whole trained model: +model = torch.load(arg.init_checkpoint) + +# if False, build the model and load the BERT weights as in the original: +model = BertForSequenceClassification(...) +model.bert.load_state_dict(...) +``` + +* Add more evaluation metrics (for binary classification tasks, e.g. telling real from fake): + * precision + * recall + * F1 + +```python +# add optional arg --is_bi-classification (bool, default True) +# if True, accumulate the confusion-matrix counts used for precision / recall / F1 +if args.is_bi-classification: + bin_out = np.argmax(logits, axis=1) + TP += ((bin_out == label_ids) & (label_ids == 0)).sum().item() + FP += ((bin_out != label_ids) & (label_ids == 0)).sum().item() + TN += ((bin_out == label_ids) & (label_ids == 1)).sum().item() + FN += ((bin_out != label_ids) & (label_ids == 1)).sum().item() +``` + +--- + ### Step 1 Download the pretrained TensorFlow model:[chinese_L-12_H-768_A-12](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip) diff --git a/src/run_classifier_word.py b/src/run_classifier_word.py index e195fd8..14e743f 100644 --- a/src/run_classifier_word.py +++ b/src/run_classifier_word.py @@ -407,6 +407,14 @@ def main(): default=None, type=str, help="Initial checkpoint (usually from a pre-trained BERT 
model).") + parser.add_argument("--from_trained", + default=False, + type=bool, + help="Whether the checkpoint you load is trained") + parser.add_argument("--is_bi-classification", + default=True, + type=bool, + help="Whether the task only have two kinds of label(like real and fake)") parser.add_argument("--do_lower_case", default=False, action='store_true', @@ -550,7 +558,10 @@ def main(): # Prepare model model = BertForSequenceClassification(bert_config, len(label_list)) if args.init_checkpoint is not None: - model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) + if args.from_trained: + model = torch.load(args.init_checkpoint, map_location='cpu') + else: + model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) if args.fp16: model.half() model.to(device) @@ -638,6 +649,7 @@ def main(): optimizer.step() model.zero_grad() global_step += 1 + torch.save(model, 'news_model.bin') if args.do_eval: eval_examples = processor.get_dev_examples(args.data_dir) @@ -660,7 +672,8 @@ def main(): eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) model.eval() - eval_loss, eval_accuracy = 0, 0 + eval_loss, eval_accuracy, eval_precision, eval_recall, eval_f1 = 0, 0, 0, 0, 0 + TP, FP, TN, FN = 0, 0, 0, 0 nb_eval_steps, nb_eval_examples = 0, 0 for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: input_ids = input_ids.to(device) @@ -675,6 +688,13 @@ def main(): label_ids = label_ids.to('cpu').numpy() tmp_eval_accuracy = accuracy(logits, label_ids) + if args.is_bi-classification: + bin_out = np.argmax(logits, axis=1) + TP += ((bin_out == label_ids) & (label_ids == 0)).sum().item() + FP += ((bin_out != label_ids) & (label_ids == 0)).sum().item() + TN += ((bin_out == label_ids) & (label_ids == 1)).sum().item() + FN += ((bin_out != label_ids) & (label_ids == 1)).sum().item() + eval_loss += tmp_eval_loss.mean().item() eval_accuracy += tmp_eval_accuracy @@ -684,10 +704,17 @@ def 
main(): eval_loss = eval_loss / nb_eval_steps eval_accuracy = eval_accuracy / nb_eval_examples + eval_precision = TP / (TP + FP) + eval_recall = TP / (TP + FN) + eval_f1 = 2 * eval_precision * eval_recall / (eval_precision + eval_recall) + result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy, 'global_step': global_step, - 'loss': tr_loss / nb_tr_steps} + 'loss': tr_loss / nb_tr_steps, + 'eval_precision': eval_precision, + 'eval_recall': eval_recall, + 'eval_f1': eval_f1} output_eval_file = os.path.join(args.output_dir, "eval_results.txt") with open(output_eval_file, "w") as writer: From 008bb9bb71e237443fbc78a4cb1d911a979256cc Mon Sep 17 00:00:00 2001 From: OceanFish <74172075+807544076@users.noreply.github.com> Date: Thu, 11 Apr 2024 22:10:36 +0800 Subject: [PATCH 2/2] fix arg is_bi_classification --- src/run_classifier_word.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/run_classifier_word.py b/src/run_classifier_word.py index 14e743f..538a8f5 100644 --- a/src/run_classifier_word.py +++ b/src/run_classifier_word.py @@ -411,7 +411,7 @@ def main(): default=False, type=bool, help="Whether the checkpoint you load is trained") - parser.add_argument("--is_bi-classification", + parser.add_argument("--is_bi_classification", default=True, type=bool, help="Whether the task only have two kinds of label(like real and fake)") @@ -688,7 +688,7 @@ def main(): label_ids = label_ids.to('cpu').numpy() tmp_eval_accuracy = accuracy(logits, label_ids) - if args.is_bi-classification: + if args.is_bi_classification: bin_out = np.argmax(logits, axis=1) TP += ((bin_out == label_ids) & (label_ids == 0)).sum().item() FP += ((bin_out != label_ids) & (label_ids == 0)).sum().item()