diff --git a/docs/guides/llms.md b/docs/guides/llms.md index b7b07f0d..a98237fc 100644 --- a/docs/guides/llms.md +++ b/docs/guides/llms.md @@ -172,32 +172,43 @@ kw_model = KeyLLM(llm) ### **LangChain** -To use LangChain, we can simply load in any LLM and pass that as a QA-chain to KeyLLM. +To use `langchain` LLM client in KeyLLM, we can simply load in any LLM in `langchain` and pass that to KeyLLM. -We install the package first: +We install langchain and corresponding LLM provider package first. Take OpenAI as an example: ```bash pip install langchain +pip install langchain-openai # LLM provider package ``` +> [!NOTE] +> KeyBERT only supports `langchain >= 0.1` + -Then we run LangChain as follows: +Then create your LLM client with `langchain` ```python -from langchain.chains.question_answering import load_qa_chain -from langchain.llms import OpenAI -chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff") +from langchain_openai import ChatOpenAI + +_llm = ChatOpenAI( + model="gpt-4o", + api_key="my-openai-api-key", + temperature=0, +) ``` -Finally, you can pass the chain to KeyBERT as follows: +Finally, pass the `langchain` llm client to KeyBERT as follows: ```python from keybert.llm import LangChain from keybert import KeyLLM # Create your LLM -llm = LangChain(chain) +llm = LangChain(_llm) # Load it in KeyLLM kw_model = KeyLLM(llm) + +# Extract keywords +keywords = kw_model.extract_keywords(MY_DOCUMENTS) ``` diff --git a/keybert/llm/__init__.py b/keybert/llm/__init__.py index abe7c7d6..b222d953 100644 --- a/keybert/llm/__init__.py +++ b/keybert/llm/__init__.py @@ -32,9 +32,16 @@ # LangChain Generator try: from keybert.llm._langchain import LangChain -except ModuleNotFoundError: - msg = "`pip install langchain` \n\n" - LangChain = NotInstalled("langchain", "langchain", custom_msg=msg) +except ModuleNotFoundError as e: + if e.name == "langchain": + msg = "`pip install langchain` \n\n" + LangChain = NotInstalled("langchain", 
"langchain", custom_msg=msg) + elif e.name == "langchain_core": + msg = "`pip install -U langchain` to upgrade to langchain>=0.1\n\n" + LangChain = NotInstalled("langchain", "langchain", custom_msg=msg) + else: + # not caused by importing langchain or langchain_core + raise e # LiteLLM try: diff --git a/keybert/llm/_langchain.py b/keybert/llm/_langchain.py index f786109e..681f2536 100644 --- a/keybert/llm/_langchain.py +++ b/keybert/llm/_langchain.py @@ -1,29 +1,37 @@ -from tqdm import tqdm from typing import List -from langchain.docstore.document import Document + +from langchain.prompts import ChatPromptTemplate, PromptTemplate +from langchain_core.language_models.chat_models import BaseChatModel as LangChainBaseChatModel +from langchain_core.language_models.llms import BaseLLM as LangChainBaseLLM +from langchain_core.output_parsers import StrOutputParser +from tqdm import tqdm + from keybert.llm._base import BaseLLM from keybert.llm._utils import process_candidate_keywords - -DEFAULT_PROMPT = "What is this document about? Please provide keywords separated by commas." +"""NOTE +KeyBERT only supports `langchain >= 0.1` which features: +- [Runnable Interface](https://python.langchain.com/docs/concepts/runnables/) +- [LangChain Expression Language (LCEL)](https://python.langchain.com/docs/concepts/lcel/) +""" class LangChain(BaseLLM): """Using chains in langchain to generate keywords. - Currently, only chains from question answering is implemented. See: - https://langchain.readthedocs.io/en/latest/modules/chains/combine_docs_examples/question_answering.html - - NOTE: The resulting keywords are expected to be separated by commas so - any changes to the prompt will have to make sure that the resulting - keywords are comma-separated. - Arguments: - chain: A langchain chain that has two input parameters, `input_documents` and `query`. + llm: A langchain LLM class. e.g ChatOpenAI, OpenAI, etc. prompt: The prompt to be used in the model. 
If no prompt is given, - `self.default_prompt_` is used instead. + `self.DEFAULT_PROMPT_TEMPLATE` is used instead. + NOTE: The prompt should contain: + 1. Placeholders + - `[DOCUMENT]`: Required. The document to extract keywords from. + - `[CANDIDATES]`: Optional. The candidate keywords to fine-tune the extraction. + 2. Output format instructions + - Include this or something similar in your prompt: + "Extracted keywords must be separated by comma." verbose: Set this to True if you want to see a progress bar for the - keyword extraction. + keyword extraction. Usage: @@ -32,14 +40,18 @@ class LangChain(BaseLLM): like openai: `pip install langchain` - `pip install openai` + `pip install langchain-openai` Then, you can create your chain as follows: ```python - from langchain.chains.question_answering import load_qa_chain - from langchain.llms import OpenAI - chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff") + from langchain_openai import ChatOpenAI + + _llm = ChatOpenAI( + model="gpt-4o", + api_key="my-openai-api-key", + temperature=0, + ) ``` Finally, you can pass the chain to KeyBERT as follows: @@ -49,14 +61,39 @@ class LangChain(BaseLLM): from keybert import KeyLLM # Create your LLM - llm = LangChain(chain) + llm = LangChain(_llm) # Load it in KeyLLM kw_model = KeyLLM(llm) # Extract keywords - document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine." - keywords = kw_model.extract_keywords(document) + docs = [ + "KeyBERT: A minimal method for keyword extraction with BERT. The keyword extraction is done by finding the sub-phrases in a document that are the most similar to the document itself. First, document embeddings are extracted with BERT to get a document-level representation. Then, word embeddings are extracted for N-gram words/phrases. Finally, we use cosine similarity to find the words/phrases that are the most similar to the document. 
The most similar words could then be identified as the words that best describe the entire document.", + "KeyLLM: A minimal method for keyword extraction with Large Language Models (LLM). The keyword extraction is done by simply asking the LLM to extract a number of keywords from a single piece of text.", + ] + keywords = kw_model.extract_keywords(docs=docs) + print(keywords) + + # Output: + # [ + # ['KeyBERT', 'keyword extraction', 'BERT', 'document embeddings', 'word embeddings', 'N-gram phrases', 'cosine similarity', 'document representation'], + # ['KeyLLM', 'keyword extraction', 'Large Language Models', 'LLM', 'minimal method'] + # ] + + + # fine tune with candidate keywords + candidates = [ + ["keyword extraction", "Large Language Models", "LLM", "BERT", "transformer", "embeddings"], + ["keyword extraction", "Large Language Models", "LLM", "BERT", "transformer", "embeddings"], + ] + keywords = kw_model.extract_keywords(docs=docs, candidate_keywords=candidates) + print(keywords) + + # Output: + # [ + # ['keyword extraction', 'BERT', 'document embeddings', 'word embeddings', 'cosine similarity', 'N-gram phrases'], + # ['KeyLLM', 'keyword extraction', 'Large Language Models', 'LLM'] + # ] ``` You can also use a custom prompt: @@ -67,16 +104,35 @@ class LangChain(BaseLLM): ``` """ + DEFAULT_PROMPT_TEMPLATE = """ +# Task +You are provided with a document and possibly a list of candidate keywords. + +If no candidate keywords are provided, your task is to extract keywords from the document. +If candidate keywords are provided, your task is to improve the candidate keywords to best describe the topic of the document. + +# Document +[DOCUMENT] + +# Candidate Keywords +[CANDIDATES] + + +Now extract the keywords from the document. +The keywords must be comma separated. 
+For example: "keyword1, keyword2, keyword3" +""" + def __init__( self, - chain, + llm: LangChainBaseChatModel | LangChainBaseLLM, prompt: str = None, verbose: bool = False, ): - self.chain = chain - self.prompt = prompt if prompt is not None else DEFAULT_PROMPT - self.default_prompt_ = DEFAULT_PROMPT + self.llm = llm + self.prompt = prompt if prompt is not None else self.DEFAULT_PROMPT_TEMPLATE self.verbose = verbose + self.chain = self._get_chain() def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None): """Extract topics. @@ -95,12 +151,19 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s candidate_keywords = process_candidate_keywords(documents, candidate_keywords) for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose): - prompt = self.prompt.replace("[DOCUMENT]", document) - if candidates is not None: - prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates)) - input_document = Document(page_content=document) - keywords = self.chain.run(input_documents=[input_document], question=self.prompt).strip() + keywords = self.chain.invoke({"DOCUMENT": document, "CANDIDATES": candidates}) keywords = [keyword.strip() for keyword in keywords.split(",")] all_keywords.append(keywords) return all_keywords + + def _get_chain(self): + """Get the chain using LLM and prompt.""" + # format prompt for langchain template placeholders + prompt = self.prompt.replace("[DOCUMENT]", "{DOCUMENT}").replace("[CANDIDATES]", "{CANDIDATES}") + # check if the model is a chat model + is_chat_model = isinstance(self.llm, LangChainBaseChatModel) + # langchain prompt template + prompt_template = ChatPromptTemplate([("human", prompt)]) if is_chat_model else PromptTemplate(template=prompt) + # chain + return prompt_template | self.llm | StrOutputParser()