diff --git a/rpunct/punctuate.py b/rpunct/punctuate.py
index c0143c08..e5f61eb8 100644
--- a/rpunct/punctuate.py
+++ b/rpunct/punctuate.py
@@ -10,14 +10,14 @@ class RestorePuncts:
-    def __init__(self, wrds_per_pred=250):
+    def __init__(self, model='felflare/bert-restore-punctuation', wrds_per_pred=250, use_cuda=False, silent=False):
         self.wrds_per_pred = wrds_per_pred
         self.overlap_wrds = 30
         self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
-        self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels,
-                              args={"silent": True, "max_seq_length": 512})
+        self.model = NERModel("bert", f"{model}", labels=self.valid_labels, use_cuda=use_cuda,
+                              args={"silent": silent, "max_seq_length": 512})
 
-    def punctuate(self, text: str, lang:str=''):
+    def punctuate(self, text: str, lang: str = ''):
         """
         Performs punctuation restoration on arbitrarily large text.
         Detects if input is not English, if non-English was detected terminates predictions.
@@ -38,8 +38,10 @@ def punctuate(self, text: str, lang: str = ''):
         splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
         # predict slices
         # full_preds_lst contains tuple of labels and logits
+        print(f'predicting {len(splits)} slices')
         full_preds_lst = [self.predict(i['text']) for i in splits]
         # extract predictions, and discard logits
+        print(f'combining predictions')
         preds_lst = [i[0][0] for i in full_preds_lst]
         # join text slices
         combined_preds = self.combine_results(text, preds_lst)