Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,42 @@
# Bert-Pytorch-Chinese-TextClassification
PyTorch BERT fine-tuning for Chinese text classification

---

### Update

* Added save / load support so that you can save and reload the model you trained

When training finishes, the model will be saved in `/src`.

```python
# added optional arg --from_trained (bool, default False)
# if True, invoke:
model = torch.load(args.init_checkpoint)

# if False, invoke as in the original code:
model = BertForSequenceClassification(...)
model.bert.load_state_dict(...)
```

* Added more evaluation metrics (for binary classification tasks, such as telling real from fake):
* precision
* recall
* F1

```python
# added optional arg --is_bi_classification (bool, default True)
# if True, enable the calculation
if args.is_bi_classification:
bin_out = np.argmax(logits, axis=1)
TP += ((bin_out == label_ids) & (label_ids == 0)).sum().item()
FP += ((bin_out != label_ids) & (label_ids == 0)).sum().item()
TN += ((bin_out == label_ids) & (label_ids == 1)).sum().item()
FN += ((bin_out != label_ids) & (label_ids == 1)).sum().item()
```

---

### Step 1

Download the pretrained TensorFlow model: [chinese_L-12_H-768_A-12](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)
Expand Down
33 changes: 30 additions & 3 deletions src/run_classifier_word.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,14 @@ def main():
default=None,
type=str,
help="Initial checkpoint (usually from a pre-trained BERT model).")
parser.add_argument("--from_trained",
default=False,
type=bool,
help="Whether the checkpoint you load is trained")
parser.add_argument("--is_bi_classification",
default=True,
type=bool,
help="Whether the task only have two kinds of label(like real and fake)")
parser.add_argument("--do_lower_case",
default=False,
action='store_true',
Expand Down Expand Up @@ -550,7 +558,10 @@ def main():
# Prepare model
model = BertForSequenceClassification(bert_config, len(label_list))
if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
if args.from_trained:
model = torch.load(args.init_checkpoint, map_location='cpu')
else:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
if args.fp16:
model.half()
model.to(device)
Expand Down Expand Up @@ -638,6 +649,7 @@ def main():
optimizer.step()
model.zero_grad()
global_step += 1
torch.save(model, 'news_model.bin')

if args.do_eval:
eval_examples = processor.get_dev_examples(args.data_dir)
Expand All @@ -660,7 +672,8 @@ def main():
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

model.eval()
eval_loss, eval_accuracy = 0, 0
eval_loss, eval_accuracy, eval_precision, eval_recall, eval_f1 = 0, 0, 0, 0, 0
TP, FP, TN, FN = 0, 0, 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
input_ids = input_ids.to(device)
Expand All @@ -675,6 +688,13 @@ def main():
label_ids = label_ids.to('cpu').numpy()
tmp_eval_accuracy = accuracy(logits, label_ids)

if args.is_bi_classification:
bin_out = np.argmax(logits, axis=1)
TP += ((bin_out == label_ids) & (label_ids == 0)).sum().item()
FP += ((bin_out != label_ids) & (label_ids == 0)).sum().item()
TN += ((bin_out == label_ids) & (label_ids == 1)).sum().item()
FN += ((bin_out != label_ids) & (label_ids == 1)).sum().item()

eval_loss += tmp_eval_loss.mean().item()
eval_accuracy += tmp_eval_accuracy

Expand All @@ -684,10 +704,17 @@ def main():
eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples

eval_precision = TP / (TP + FP)
eval_recall = TP / (TP + FN)
eval_f1 = 2 * eval_precision * eval_recall / (eval_precision + eval_recall)

result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'global_step': global_step,
'loss': tr_loss / nb_tr_steps}
'loss': tr_loss / nb_tr_steps,
'eval_precision': eval_precision,
'eval_recall': eval_recall,
'eval_f1': eval_f1}

output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
Expand Down