From 16cc293215bdcddd468c2cdc6a27fb8adba26abd Mon Sep 17 00:00:00 2001 From: nbertagnolli Date: Sat, 22 Jan 2022 11:33:05 -0700 Subject: [PATCH] Added ability to pass additional parameters to simpletransformer ner model. Fixing GPU only bug. --- rpunct/punctuate.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/rpunct/punctuate.py b/rpunct/punctuate.py index c0143c08..c8902cd7 100644 --- a/rpunct/punctuate.py +++ b/rpunct/punctuate.py @@ -7,15 +7,19 @@ import logging from langdetect import detect from simpletransformers.ner import NERModel +from typing import Any, Dict class RestorePuncts: - def __init__(self, wrds_per_pred=250): + def __init__(self, wrds_per_pred=250, ner_args: Dict[str, Any]=None): self.wrds_per_pred = wrds_per_pred self.overlap_wrds = 30 self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U'] + + if ner_args is None: + ner_args = {} self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels, - args={"silent": True, "max_seq_length": 512}) + args={"silent": True, "max_seq_length": 512}, **ner_args) def punctuate(self, text: str, lang:str=''): """