From 7465f443653c4e743945fe214f5c78f3a13b7272 Mon Sep 17 00:00:00 2001
From: beshr eldebuch
Date: Thu, 28 Jul 2022 17:32:50 +0300
Subject: [PATCH] Fix max_seq_length issue

---
 keybert/_model.py | 24 +++++++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/keybert/_model.py b/keybert/_model.py
index 904fcec7..137ce96d 100644
--- a/keybert/_model.py
+++ b/keybert/_model.py
@@ -65,6 +65,7 @@ def extract_keywords(
         vectorizer: CountVectorizer = None,
         highlight: bool = False,
         seed_keywords: List[str] = None,
+        max_seq_length: int = None,
     ) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
         """Extract keywords and/or keyphrases
 
@@ -151,7 +152,28 @@
         df = count.transform(docs)
 
         # Extract embeddings
-        doc_embeddings = self.model.embed(docs)
+
+        # If a document is longer than the model's max_seq_length, everything
+        # past that limit is silently truncated. To avoid this, the document
+        # is split into chunks of max_seq_length words (an approximation of
+        # the model's token limit), each chunk is embedded separately, and
+        # the chunk embeddings are averaged into a single document embedding.
+        if max_seq_length:
+            doc_embeddings = []
+            for doc in docs:
+                split_doc = doc.split()
+                # Temporarily holds the embedding of each max_seq_length chunk
+                chunk_embeddings = []
+                for i in range(0, len(split_doc), max_seq_length):
+                    input_text = " ".join(split_doc[i : i + max_seq_length])
+                    chunk_embeddings.append(self.model.embed([input_text])[0])
+                # Final embedding of the document: the mean of its chunk embeddings
+                doc_embeddings.append(np.mean(chunk_embeddings, axis=0))
+            doc_embeddings = np.vstack(doc_embeddings)
+        else:
+            doc_embeddings = self.model.embed(docs)
+
+        # Extract embeddings for the keyword candidates
         word_embeddings = self.model.embed(words)
 
         # Find keywords
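
Usage note, not part of the patch: a minimal sketch of calling the new
parameter, assuming the default KeyBERT model; the sample document and
the chunk size of 256 are illustrative only.

    from keybert import KeyBERT

    # A long document that would overflow a typical 512-token
    # max_seq_length and be silently truncated without this patch.
    long_doc = " ".join(["keyword extraction with document embeddings"] * 300)

    kw_model = KeyBERT()

    # The document is split into 256-word chunks, each chunk is embedded
    # separately, and the chunk embeddings are averaged before keywords
    # are scored against the resulting document embedding.
    keywords = kw_model.extract_keywords(long_doc, max_seq_length=256)
    print(keywords)

With max_seq_length left at its default of None, the previous behaviour
(a single embed call over the full document) is unchanged, so existing
callers are unaffected.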