Merged
18 changes: 9 additions & 9 deletions examples/spacy_embeddings.py
```diff
@@ -13,7 +13,7 @@ class SpacyEmbeddings(Embeddings):
     """A class that uses spaCy to generate embeddings for the given input.
 
     Args:
-        model (str): The spaCy model to use, e.g., 'en_core_web_md'.
+        model: The spaCy model to use, e.g., 'en_core_web_md'.
         **kwargs: Additional keyword arguments supported by spaCy's `load` function.
     """
 
@@ -35,43 +35,43 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Call spaCy's model to generate embeddings for a list of documents.
 
         Args:
-            texts (list[str]): The list of texts to embed.
+            texts: The list of texts to embed.
 
         Returns:
-            list[list[float]]: List of embeddings, one for each text.
+            List of embeddings, one for each text.
         """
         return [self._nlp(text).vector.tolist() for text in texts]
 
     async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
         """Asynchronously generate embeddings for a list of documents using spaCy's model.
 
         Args:
-            texts (List[str]): The list of texts to embed.
+            texts: The list of texts to embed.
 
         Returns:
-            list[list[float]]: List of embeddings, one for each text.
+            List of embeddings, one for each text.
         """
         return await asyncio.to_thread(self.embed_documents, texts)
 
     def embed_query(self, text: str) -> list[float]:
         """Generate an embedding for a single query text.
 
         Args:
-            text (str): The text to embed.
+            text: The text to embed.
 
         Returns:
-            list[float]: Embedding for the text.
+            Embedding for the text.
         """
         return self._nlp(text).vector.tolist()
 
     async def aembed_query(self, text: str) -> list[float]:
         """Asynchronously generate embedding for a single query text.
 
         Args:
-            text (str): The text to embed.
+            text: The text to embed.
 
         Returns:
-            list[float]: Embedding for the text.
+            Embedding for the text.
         """
         return await asyncio.to_thread(self.embed_query, text)
```
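For context, a minimal usage sketch of this example class. It is a sketch only: the import path is a placeholder, and it assumes spaCy plus the `en_core_web_md` model are installed (`pip install spacy && python -m spacy download en_core_web_md`).

```python
# Sketch: using the SpacyEmbeddings example class (import path assumed).
from spacy_embeddings import SpacyEmbeddings

embedder = SpacyEmbeddings(model="en_core_web_md")

# Embed several documents at once; returns one vector per text.
vectors = embedder.embed_documents(["CloneLLM is a Python package.", "It builds LLM clones."])
print(len(vectors), len(vectors[0]))  # 2 vectors, 300 dims each for en_core_web_md

# Embed a single query string.
query_vector = embedder.embed_query("What does CloneLLM do?")
```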
79 changes: 39 additions & 40 deletions src/clonellm/core.py
```diff
@@ -45,16 +45,15 @@ class CloneLLM(LiteLLMMixin):
     """Creates an LLM clone of a user based on provided user profile and related context.
 
     Args:
-        model (str): Name of the language model.
-        documents (list[Document | str]): List of documents or strings related to cloning user to use for LLM context.
-        embedding (Embeddings | None): The embedding function to use for RAG. Defaults to None for no embedding, i.e., a summary of `documents` is used for RAG.
-        vector_store (str | RagVectorStore | None): The vector store to use for embedding-based retrieval. Defaults to None for "in-memory" vector store.
-        user_profile (UserProfile | dict[str, Any] | str | None): The profile of the user to be cloned by the language model. Defaults to None.
-        memory (bool | int | None): Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
-        api_key (str | None): The API key to use. Defaults to None.
-        system_prompts (list[str] | None): Additional system prompts (instructions) for the language model. Defaults to None.
-        **kwargs (Any): Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
-
+        model: Name of the language model.
+        documents: List of documents or strings related to cloning user to use for LLM context.
+        embedding: The embedding function to use for RAG. If not provided, a summary of `documents` is used for RAG.
+        vector_store: The vector store to use for embedding-based retrieval. Defaults to None for "in-memory" vector store.
+        user_profile: The profile of the user to be cloned by the language model. Defaults to None.
+        memory: Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
+        api_key: The API key to use. Defaults to None.
+        system_prompts: Additional system prompts (instructions) for the language model. Defaults to None.
+        **kwargs: Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
     """
 
     _VECTOR_STORE_COLLECTION_NAME = "clonellm"
```
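To ground the parameter list above, a minimal construction sketch. Only the keyword names come from this docstring; the import paths, model names, and profile values are placeholder assumptions.

```python
# Sketch: constructing a CloneLLM from raw documents (all values are placeholders).
from clonellm import CloneLLM, LiteLLMEmbeddings  # import paths assumed from this repo's layout

clone = CloneLLM(
    model="gpt-4o-mini",  # any chat model LiteLLM can route to
    documents=["I am a software engineer.", "I write a weekly blog about Rust."],
    embedding=LiteLLMEmbeddings(model="text-embedding-3-small"),  # omit to use a summarized context instead
    user_profile={"first_name": "Ada", "occupation": "software engineer"},  # dicts are accepted per the docstring
    memory=20,  # keep the last 20 messages; True or -1 means unbounded memory
)
```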
```diff
@@ -143,17 +142,17 @@ def from_persist_directory(
         """Creates an instance of CloneLLM by loading a Chroma vector store from a persistent directory.
 
         Args:
-            model (str): Name of the language model.
-            chroma_persist_directory (str): Directory path to the persisted Chroma vector store.
-            embedding (Embeddings | None): The embedding function to use for Chroma store. Defaults to None for no embedding, i.e., a summary of `documents` is used for RAG.
-            user_profile (UserProfile | dict[str, Any] | str | None): The profile of the user to be cloned by the language model. Defaults to None.
-            memory (bool | int | None): Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
-            api_key (str | None): The API key to use. Defaults to None.
-            system_prompts (list[str] | None): Additional system prompts (instructions) for the language model. Defaults to None.
-            **kwargs (Any): Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
+            model: Name of the language model.
+            chroma_persist_directory: Directory path to the persisted Chroma vector store.
+            embedding: The embedding function to use for Chroma store. If not provided, a summary of `documents` is used for RAG.
+            user_profile: The profile of the user to be cloned by the language model. Defaults to None.
+            memory: Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
+            api_key: The API key to use. Defaults to None.
+            system_prompts: Additional system prompts (instructions) for the language model. Defaults to None.
+            **kwargs: Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
 
         Returns:
-            CloneLLM: An instance of CloneLLM with Chroma-based retrieval.
+            An instance of CloneLLM with Chroma-based retrieval.
         """
         kwargs.update(
             {
@@ -193,16 +192,16 @@ def from_context(
         """Creates an instance of CloneLLM using a summarized context string instead of documents.
 
         Args:
-            model (str): Name of the language model.
-            context (str): Pre-summarized context string for the language model.
-            user_profile (UserProfile | dict[str, Any] | str | None): The profile of the user to be cloned by the language model. Defaults to None.
-            memory (bool | int | None): Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
-            api_key (str | None): The API key to use. Defaults to None.
-            system_prompts (list[str] | None): Additional system prompts (instructions) for the language model. Defaults to None.
-            **kwargs (Any): Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
+            model: Name of the language model.
+            context: Pre-summarized context string for the language model.
+            user_profile: The profile of the user to be cloned by the language model. Defaults to None.
+            memory: Maximum number of messages in conversation memory. Defaults to None (or 0) for no memory. -1 or `True` means infinite memory.
+            api_key: The API key to use. Defaults to None.
+            system_prompts: Additional system prompts (instructions) for the language model. Defaults to None.
+            **kwargs: Additional keyword arguments supported by the `langchain_community.chat_models.ChatLiteLLM` class.
 
         Returns:
-            CloneLLM instance using the provided context.
+            An instance of CloneLLM with the provided context.
         """
         kwargs.update({cls._FROM_CLASS_METHOD_KWARG: {"context": context, "_is_fitted": True}})
         return cls(
```
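The two alternate constructors read naturally together; a short sketch under the same placeholder assumptions as the previous one:

```python
# Sketch: the alternate constructors documented above (paths and values are placeholders).
from clonellm import CloneLLM, LiteLLMEmbeddings  # import paths assumed

# Reuse a previously persisted Chroma vector store instead of re-embedding documents.
clone = CloneLLM.from_persist_directory(
    model="gpt-4o-mini",
    chroma_persist_directory="./chroma_db",
    embedding=LiteLLMEmbeddings(model="text-embedding-3-small"),
)

# Or skip documents entirely and pass a pre-summarized context string.
# Per the body above, the instance is created with _is_fitted=True,
# so no separate fit() call is needed.
clone = CloneLLM.from_context(
    model="gpt-4o-mini",
    context="Ada is a systems programmer who writes a weekly blog about Rust.",
)
```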
```diff
@@ -238,7 +237,7 @@ def fit(self) -> Self:
         Embeds the documents for retrieval using the selected vector store or generates a summarized context.
 
         Returns:
-            CloneLLM: Fitted CloneLLM instance.
+            Fitted CloneLLM instance.
         """
         documents = self._get_documents()
         if self.embedding:
@@ -304,10 +303,10 @@ def update(self, documents: list[Document | str]) -> Self:
         """Updates the CloneLLM with additional documents, either embedding them or updating the context.
 
         Args:
-            documents (list[Document | str]): Additional documents to add to the model.
+            documents: Additional documents to add to the model.
 
         Returns:
-            CloneLLM: Updated CloneLLM instance.
+            Updated CloneLLM instance.
         """
         self._check_is_fitted(from_update=True)
         documents_ = self._get_documents(documents)
@@ -322,10 +321,10 @@ async def aupdate(self, documents: list[Document | str]) -> Self:
         """Asynchronously updates the CloneLLM with additional documents, either embedding them or updating the context.
 
         Args:
-            documents (list[Document | str]): Additional documents to add to the model.
+            documents: Additional documents to add to the model.
 
         Returns:
-            CloneLLM: Updated CloneLLM instance.
+            Updated CloneLLM instance.
         """
         self._check_is_fitted(from_update=True)
         documents_ = self._get_documents(documents)
```
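The fit/update lifecycle these hunks document, sketched with the hypothetical `clone` from the earlier examples:

```python
# Sketch: fit once, then append documents later.
clone.fit()  # embeds the documents, or summarizes them when no embedding function is set
clone = clone.update(["I recently started learning OCaml."])  # returns the updated instance (Self)

# aupdate is the async counterpart shown in the diff:
#     clone = await clone.aupdate(["..."])
```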
```diff
@@ -390,10 +389,10 @@ def invoke(self, prompt: str) -> str:
         This method uses the underlying language model to simulate responses as if coming from the cloned user profile.
 
         Args:
-            prompt (str): Input prompt for the cloned language model.
+            prompt: Input prompt for the cloned language model.
 
         Returns:
-            str: The generated response from the language model as the cloned user.
+            The generated response from the language model as the cloned user.
         """
         self._check_is_fitted()
         if self.memory:
@@ -409,10 +408,10 @@ async def ainvoke(self, prompt: str) -> str:
         This method uses the underlying language model to simulate responses as if coming from the cloned user profile.
 
         Args:
-            prompt (str): Input prompt for the cloned language model.
+            prompt: Input prompt for the cloned language model.
 
         Returns:
-            str: The generated response from the language model as the cloned user.
+            The generated response from the language model as the cloned user.
         """
         self._check_is_fitted()
         if self.memory:
```
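And the basic question/answer round-trip, using the signatures shown in these two hunks (still continuing the hypothetical `clone`):

```python
import asyncio

# Sketch: single-shot responses from the fitted clone.
reply = clone.invoke("What programming language do you enjoy most?")
print(reply)

# Async variant:
async def ask() -> None:
    print(await clone.ainvoke("And what do you write about?"))

asyncio.run(ask())
```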
```diff
@@ -428,10 +427,10 @@ def stream(self, prompt: str) -> Iterator[str]:
         """Streams responses from the cloned language model for a given prompt, returning the output in chunks.
 
         Args:
-            prompt (str): Input prompt for the cloned language model.
+            prompt: Input prompt for the cloned language model.
 
         Returns:
-            Iterator[str]: An iterator over the streamed response chunks from the cloned language model.
+            An iterator over the streamed response chunks from the cloned language model.
         """
         self._check_is_fitted()
         if self.memory:
@@ -449,10 +448,10 @@ async def astream(self, prompt: str) -> AsyncIterator[str]:
         """Asynchronously streams responses from the cloned language model for a given prompt, returning the output in chunks.
 
         Args:
-            prompt (str): Input prompt for the cloned language model.
+            prompt: Input prompt for the cloned language model.
 
         Returns:
-            AsyncIterator[str]: An asynchronous iterator over the streamed response chunks from the cloned language model.
+            An asynchronous iterator over the streamed response chunks from the cloned language model.
         """
         self._check_is_fitted()
         if self.memory:
```
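Streaming works the same way but yields chunks; a sketch of both variants under the same assumptions:

```python
import asyncio

# Sketch: consuming the streamed response chunk by chunk.
for chunk in clone.stream("Tell me about your background."):
    print(chunk, end="", flush=True)

# Async variant:
async def stream_reply() -> None:
    async for chunk in clone.astream("Tell me about your background."):
        print(chunk, end="", flush=True)

asyncio.run(stream_reply())
```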
25 changes: 12 additions & 13 deletions src/clonellm/embed.py
```diff
@@ -12,11 +12,10 @@ class LiteLLMEmbeddings(LiteLLMMixin, Embeddings):
     """A class that uses LiteLLM to call an LLM's API to generate embeddings for the given input.
 
     Args:
-        model (str): The embedding model to use.
-        api_key (str | None): The API key to use. Defaults to None.
-        dimensions (int | None): The number of dimensions the resulting output embeddings should have. Defaults to None.
-        **kwargs (Any): Additional keyword arguments supported by the `litellm.embedding` and `litellm.aembedding` functions.
-
+        model: The embedding model to use.
+        api_key: The API key to use. Defaults to None.
+        dimensions: The number of dimensions the resulting output embeddings should have. Defaults to None.
+        **kwargs: Additional keyword arguments supported by the `litellm.embedding` and `litellm.aembedding` functions.
     """
 
     def __init__(self, model: str, api_key: str | None = None, dimensions: int | None = None, **kwargs: Any) -> None:
@@ -27,10 +26,10 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Call out to LLM's embedding endpoint for embedding a list of documents.
 
         Args:
-            texts (list[str]): The list of texts to embed.
+            texts: The list of texts to embed.
 
         Returns:
-            list[list[float]]: List of embeddings, one for each text.
+            List of embeddings, one for each text.
         """
         response = embedding(
             model=self.model, input=texts, api_key=self.api_key, dimensions=self.dimensions, **self._litellm_kwargs
@@ -41,10 +40,10 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
         """Call out to LLM's embedding endpoint async for embedding a list of documents.
 
         Args:
-            texts (list[str]): The list of texts to embed.
+            texts: The list of texts to embed.
 
         Returns:
-            list[list[float]]: List of embeddings, one for each text.
+            List of embeddings, one for each text.
         """
         response = await aembedding(
             model=self.model, input=texts, api_key=self.api_key, dimensions=self.dimensions, **self._litellm_kwargs
@@ -55,21 +54,21 @@ def embed_query(self, text: str) -> list[float]:
         """Call out to LLM's embedding endpoint for embedding query text.
 
         Args:
-            text (str): The text to embed.
+            text: The text to embed.
 
         Returns:
-            list[float]: Embedding for the text.
+            Embedding for the text.
         """
         return self.embed_documents([text])[0]
 
     async def aembed_query(self, text: str) -> list[float]:
         """Call out to LLM's embedding endpoint async for embedding query text.
 
         Args:
-            text (str): The text to embed.
+            text: The text to embed.
 
         Returns:
-            list[float]: Embedding for the text.
+            Embedding for the text.
         """
         embeddings = await self.aembed_documents([text])
         return embeddings[0]
```
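A minimal usage sketch of this embeddings class. The model name is a placeholder, and `dimensions` only takes effect for providers that support shortened embeddings.

```python
# Sketch: using LiteLLMEmbeddings directly (values are placeholders).
from clonellm import LiteLLMEmbeddings  # import path assumed

embedder = LiteLLMEmbeddings(
    model="text-embedding-3-small",
    dimensions=256,  # optional; forwarded to litellm.embedding/aembedding
)
doc_vectors = embedder.embed_documents(["first document", "second document"])
query_vector = embedder.embed_query("a query")  # embeds via embed_documents([text])[0], per the diff
```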
2 changes: 1 addition & 1 deletion src/clonellm/memory.py
```diff
@@ -28,7 +28,7 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:
         """Add a list of messages to the store
 
         Args:
-            messages (Sequence[BaseMessage]): A list of BaseMessage objects to store.
+            messages: A list of BaseMessage objects to store.
         """
         for message in messages:
             self.messages.append(message)
```
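Since the hunk shows only the loop body, here is a self-contained sketch of the `add_messages` contract it documents. The `InMemoryHistory` name and constructor are stand-ins, not this PR's code; only the `add_messages` body mirrors the diff.

```python
# Self-contained sketch of the add_messages contract documented above.
from collections.abc import Sequence

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage


class InMemoryHistory:  # stand-in class; the real store lives in src/clonellm/memory.py
    def __init__(self) -> None:
        self.messages: list[BaseMessage] = []

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Add a list of messages to the store"""
        for message in messages:
            self.messages.append(message)


history = InMemoryHistory()
history.add_messages([HumanMessage(content="Hi!"), AIMessage(content="Hello!")])
print(len(history.messages))  # 2
```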