diff --git a/examples/spacy_embeddings.py b/examples/spacy_embeddings.py
index aa89e0a..55876fd 100644
--- a/examples/spacy_embeddings.py
+++ b/examples/spacy_embeddings.py
@@ -13,7 +13,7 @@ class SpacyEmbeddings(Embeddings):
     """A class that uses spaCy to generate embeddings for the given input.
 
     Args:
-        model (str): The spaCy model to use, e.g., 'en_core_web_md'.
+        model: The spaCy model to use, e.g., 'en_core_web_md'.
         **kwargs: Additional keyword arguments supported by spaCy's `load` function.
     """
 
@@ -35,10 +35,10 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Call spaCy's model to generate embeddings for a list of documents.
 
         Args:
-            texts (list[str]): The list of texts to embed.
+            texts: The list of texts to embed.
 
         Returns:
-            list[list[float]]: List of embeddings, one for each text.
+            List of embeddings, one for each text.
         """
         return [self._nlp(text).vector.tolist() for text in texts]
 
@@ -46,10 +46,10 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
         """Asynchronously generate embeddings for a list of documents using spaCy's model.
 
         Args:
-            texts (List[str]): The list of texts to embed.
+            texts: The list of texts to embed.
 
         Returns:
-            list[list[float]]: List of embeddings, one for each text.
+            List of embeddings, one for each text.
         """
         return await asyncio.to_thread(self.embed_documents, texts)
 
@@ -57,10 +57,10 @@ def embed_query(self, text: str) -> list[float]:
         """Generate an embedding for a single query text.
 
         Args:
-            text (str): The text to embed.
+            text: The text to embed.
 
         Returns:
-            list[float]: Embedding for the text.
+            Embedding for the text.
         """
         return self._nlp(text).vector.tolist()
 
@@ -68,10 +68,10 @@ async def aembed_query(self, text: str) -> list[float]:
         """Asynchronously generate embedding for a single query text.
 
         Args:
-            text (str): The text to embed.
+            text: The text to embed.
 
         Returns:
-            list[float]: Embedding for the text.
+            Embedding for the text.
         """
         return await asyncio.to_thread(self.embed_query, text)
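Reviewer note: a minimal usage sketch for the example class above, not part of the patch. It assumes the `en_core_web_md` model has been downloaded and that the module is imported from the examples directory:

# Minimal sketch; assumes the model is installed:
#   python -m spacy download en_core_web_md
from spacy_embeddings import SpacyEmbeddings  # examples/spacy_embeddings.py

embeddings = SpacyEmbeddings(model="en_core_web_md")

# One embedding vector per input text.
vectors = embeddings.embed_documents(["Hello, world!", "How are you?"])
print(len(vectors), len(vectors[0]))  # 2 300 (en_core_web_md vectors are 300-dimensional)

# A single query string yields a single vector.
query_vector = embeddings.embed_query("What is new?")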
diff --git a/src/clonellm/core.py b/src/clonellm/core.py
index 65ce384..c0b5a2b 100644
--- a/src/clonellm/core.py
+++ b/src/clonellm/core.py
@@ -45,16 +45,15 @@ class CloneLLM(LiteLLMMixin):
     """Creates an LLM clone of a user based on provided user profile and related context.
 
     Args:
-        model (str): Name of the language model.
-        documents (list[Document | str]): List of documents or strings related to cloning user to use for LLM context.
-        embedding (Embeddings | None): The embedding function to use for RAG. Defaults to None for no embedding, i.e., a summary of `documents` is used for RAG.
-        vector_store (str | RagVectorStore | None): The vector store to use for embedding-based retrieval. Defaults to None for "in-memory" vector store.
-        user_profile (UserProfile | dict[str, Any] | str | None): The profile of the user to be cloned by the language model. Defaults to None.
-        memory (bool | int | None): Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
-        api_key (str | None): The API key to use. Defaults to None.
-        system_prompts (list[str] | None): Additional system prompts (instructions) for the language model. Defaults to None.
-        **kwargs (Any): Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
-
+        model: Name of the language model.
+        documents: List of documents or strings related to the user being cloned, to use as LLM context.
+        embedding: The embedding function to use for RAG. If not provided, a summary of `documents` is used for RAG.
+        vector_store: The vector store to use for embedding-based retrieval. Defaults to None for "in-memory" vector store.
+        user_profile: The profile of the user to be cloned by the language model. Defaults to None.
+        memory: Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
+        api_key: The API key to use. Defaults to None.
+        system_prompts: Additional system prompts (instructions) for the language model. Defaults to None.
+        **kwargs: Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
     """
 
     _VECTOR_STORE_COLLECTION_NAME = "clonellm"
@@ -143,17 +142,17 @@ def from_persist_directory(
         """Creates an instance of CloneLLM by loading a Chroma vector store from a persistent directory.
 
         Args:
-            model (str): Name of the language model.
-            chroma_persist_directory (str): Directory path to the persisted Chroma vector store.
-            embedding (Embeddings | None): The embedding function to use for Chroma store. Defaults to None for no embedding, i.e., a summary of `documents` is used for RAG.
-            user_profile (UserProfile | dict[str, Any] | str | None): The profile of the user to be cloned by the language model. Defaults to None.
-            memory (bool | int | None): Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
-            api_key (str | None): The API key to use. Defaults to None.
-            system_prompts (list[str] | None): Additional system prompts (instructions) for the language model. Defaults to None.
-            **kwargs (Any): Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
+            model: Name of the language model.
+            chroma_persist_directory: Directory path to the persisted Chroma vector store.
+            embedding: The embedding function to use for the Chroma store. If not provided, a summary of `documents` is used for RAG.
+            user_profile: The profile of the user to be cloned by the language model. Defaults to None.
+            memory: Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
+            api_key: The API key to use. Defaults to None.
+            system_prompts: Additional system prompts (instructions) for the language model. Defaults to None.
+            **kwargs: Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
 
         Returns:
-            CloneLLM: An instance of CloneLLM with Chroma-based retrieval.
+            An instance of CloneLLM with Chroma-based retrieval.
         """
         kwargs.update(
             {
@@ -193,16 +192,16 @@ def from_context(
         """Creates an instance of CloneLLM using a summarized context string instead of documents.
 
         Args:
-            model (str): Name of the language model.
-            context (str): Pre-summarized context string for the language model.
-            user_profile (UserProfile | dict[str, Any] | str | None): The profile of the user to be cloned by the language model. Defaults to None.
-            memory (bool | int | None): Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
-            api_key (str | None): The API key to use. Defaults to None.
-            system_prompts (list[str] | None): Additional system prompts (instructions) for the language model. Defaults to None.
-            **kwargs (Any): Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
+            model: Name of the language model.
+            context: Pre-summarized context string for the language model.
+            user_profile: The profile of the user to be cloned by the language model. Defaults to None.
+            memory: Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
+            api_key: The API key to use. Defaults to None.
+            system_prompts: Additional system prompts (instructions) for the language model. Defaults to None.
+            **kwargs: Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
 
         Returns:
-            CloneLLM instance using the provided context.
+            An instance of CloneLLM with the provided context.
         """
         kwargs.update({cls._FROM_CLASS_METHOD_KWARG: {"context": context, "_is_fitted": True}})
         return cls(
@@ -238,7 +237,7 @@ def fit(self) -> Self:
         Embeds the documents for retrieval using the selected vector store or generates a summarized context.
 
         Returns:
-            CloneLLM: Fitted CloneLLM instance.
+            Fitted CloneLLM instance.
         """
         documents = self._get_documents()
         if self.embedding:
@@ -304,10 +303,10 @@ def update(self, documents: list[Document | str]) -> Self:
         """Updates the CloneLLM with additional documents, either embedding them or updating the context.
 
         Args:
-            documents (list[Document | str]): Additional documents to add to the model.
+            documents: Additional documents to add to the model.
 
         Returns:
-            CloneLLM: Updated CloneLLM instance.
+            Updated CloneLLM instance.
         """
         self._check_is_fitted(from_update=True)
         documents_ = self._get_documents(documents)
@@ -322,10 +321,10 @@ async def aupdate(self, documents: list[Document | str]) -> Self:
         """Asynchronously updates the CloneLLM with additional documents, either embedding them or updating the context.
 
         Args:
-            documents (list[Document | str]): Additional documents to add to the model.
+            documents: Additional documents to add to the model.
 
         Returns:
-            CloneLLM: Updated CloneLLM instance.
+            Updated CloneLLM instance.
         """
         self._check_is_fitted(from_update=True)
         documents_ = self._get_documents(documents)
@@ -390,10 +389,10 @@ def invoke(self, prompt: str) -> str:
         This method uses the underlying language model to simulate responses as if coming from the cloned user profile.
 
         Args:
-            prompt (str): Input prompt for the cloned language model.
+            prompt: Input prompt for the cloned language model.
 
         Returns:
-            str: The generated response from the language model as the cloned user.
+            The generated response from the language model as the cloned user.
         """
         self._check_is_fitted()
         if self.memory:
@@ -409,10 +408,10 @@ async def ainvoke(self, prompt: str) -> str:
         This method uses the underlying language model to simulate responses as if coming from the cloned user profile.
 
         Args:
-            prompt (str): Input prompt for the cloned language model.
+            prompt: Input prompt for the cloned language model.
 
         Returns:
-            str: The generated response from the language model as the cloned user.
+            The generated response from the language model as the cloned user.
         """
         self._check_is_fitted()
         if self.memory:
@@ -428,10 +427,10 @@ def stream(self, prompt: str) -> Iterator[str]:
         """Streams responses from the cloned language model for a given prompt, returning the output in chunks.
 
         Args:
-            prompt (str): Input prompt for the cloned language model.
+            prompt: Input prompt for the cloned language model.
 
         Returns:
-            Iterator[str]: An iterator over the streamed response chunks from the cloned language model.
+            An iterator over the streamed response chunks from the cloned language model.
         """
         self._check_is_fitted()
         if self.memory:
@@ -449,10 +448,10 @@ async def astream(self, prompt: str) -> AsyncIterator[str]:
         """Asynchronously streams responses from the cloned language model for a given prompt, returning the output in chunks.
 
         Args:
-            prompt (str): Input prompt for the cloned language model.
+            prompt: Input prompt for the cloned language model.
 
         Returns:
-            AsyncIterator[str]: An asynchronous iterator over the streamed response chunks from the cloned language model.
+            An asynchronous iterator over the streamed response chunks from the cloned language model.
         """
         self._check_is_fitted()
         if self.memory:
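Reviewer note: a usage sketch of the CloneLLM API whose docstrings are edited above, not part of the patch. The model name, documents, and the top-level `clonellm` import are placeholders assumed for illustration:

# Sketch only; adjust model, documents, and credentials to your setup.
from clonellm import CloneLLM

clone = CloneLLM(
    model="gpt-4o-mini",  # any LiteLLM-supported chat model
    documents=["I am a software engineer who enjoys hiking and sci-fi novels."],
    memory=10,  # keep up to 10 messages of conversation memory
)
clone.fit()  # with no embedding set, a summary of `documents` is used for RAG

print(clone.invoke("What do you do for a living?"))

# stream() yields the response in chunks as they arrive.
for chunk in clone.stream("Tell me about your hobbies."):
    print(chunk, end="", flush=True)

Per the docstrings above, `CloneLLM.from_context(model=..., context=...)` skips fitting by accepting a pre-summarized context string, and `from_persist_directory` restores a previously persisted Chroma store.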
""" self._check_is_fitted() if self.memory: @@ -449,10 +448,10 @@ async def astream(self, prompt: str) -> AsyncIterator[str]: """Asynchronously streams responses from the cloned language model for a given prompt, returning the output in chunks. Args: - prompt (str): Input prompt for the cloned language model. + prompt: Input prompt for the cloned language model. Returns: - AsyncIterator[str]: An asynchronous iterator over the streamed response chunks from the cloned language model. + An asynchronous iterator over the streamed response chunks from the cloned language model. """ self._check_is_fitted() if self.memory: diff --git a/src/clonellm/embed.py b/src/clonellm/embed.py index bd02225..d66af57 100644 --- a/src/clonellm/embed.py +++ b/src/clonellm/embed.py @@ -12,11 +12,10 @@ class LiteLLMEmbeddings(LiteLLMMixin, Embeddings): """A class that uses LiteLLM to call an LLM's API to generate embeddings for the given input. Args: - model (str): The embedding model to use. - api_key (str | None): The API key to use. Defaults to None. - dimensions (int | None): The number of dimensions the resulting output embeddings should have. Defaults to None. - **kwargs (Any): Additional keyword arguments supported by the `litellm.embedding` and `litellm.aembedding` functions. - + model: The embedding model to use. + api_key: The API key to use. Defaults to None. + dimensions: The number of dimensions the resulting output embeddings should have. Defaults to None. + **kwargs: Additional keyword arguments supported by the `litellm.embedding` and `litellm.aembedding` functions. """ def __init__(self, model: str, api_key: str | None = None, dimensions: int | None = None, **kwargs: Any) -> None: @@ -27,10 +26,10 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: """Call out to LLM's embedding endpoint for embedding a list of documents. Args: - texts (list[str]): The list of texts to embed. + texts: The list of texts to embed. Returns: - list[list[float]]: List of embeddings, one for each text. + List of embeddings, one for each text. """ response = embedding( model=self.model, input=texts, api_key=self.api_key, dimensions=self.dimensions, **self._litellm_kwargs @@ -41,10 +40,10 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Call out to LLM's embedding endpoint async for embedding a list of documents. Args: - texts (list[str]): The list of texts to embed. + texts: The list of texts to embed. Returns: - list[list[float]]: List of embeddings, one for each text. + List of embeddings, one for each text. """ response = await aembedding( model=self.model, input=texts, api_key=self.api_key, dimensions=self.dimensions, **self._litellm_kwargs @@ -55,10 +54,10 @@ def embed_query(self, text: str) -> list[float]: """Call out to LLM's embedding endpoint for embedding query text. Args: - text (str): The text to embed. + text: The text to embed. Returns: - list[float]: Embedding for the text. + Embedding for the text. """ return self.embed_documents([text])[0] @@ -66,10 +65,10 @@ async def aembed_query(self, text: str) -> list[float]: """Call out to LLM's embedding endpoint async for embedding query text. Args: - text (str): The text to embed. + text: The text to embed. Returns: - list[float]: Embedding for the text. + Embedding for the text. 
""" embeddings = await self.aembed_documents([text]) return embeddings[0] diff --git a/src/clonellm/memory.py b/src/clonellm/memory.py index a630bbb..142dee2 100644 --- a/src/clonellm/memory.py +++ b/src/clonellm/memory.py @@ -28,7 +28,7 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None: """Add a list of messages to the store Args: - messages (Sequence[BaseMessage]): A list of BaseMessage objects to store. + messages: A list of BaseMessage objects to store. """ for message in messages: self.messages.append(message)