@@ -148,7 +148,7 @@ async def build_request(
         model=request_config.model,
     )
 
-    logging.info(
+    logger.info(
         "chat message params budgeted; message count: %d, total token count: %d",
         len(chat_message_params),
         total_token_count,
24 changes: 21 additions & 3 deletions assistants/codespace-assistant/assistant/response/response.py
@@ -2,7 +2,14 @@
 from contextlib import AsyncExitStack
 from typing import Any
 
-from assistant_extensions.chat_context_toolkit.archive import ArchiveTaskQueues
+from assistant_extensions.attachments import get_attachments
+from assistant_extensions.chat_context_toolkit.archive import (
+    ArchiveTaskQueues,
+    construct_archive_summarizer,
+)
+from assistant_extensions.chat_context_toolkit.message_history import (
+    construct_attachment_summarizer,
+)
 from assistant_extensions.mcp import (
     MCPClientSettings,
     MCPServerConnectionError,
@@ -166,8 +173,19 @@ async def message_handler(message) -> None:
     # enqueue an archive task for this conversation
     await archive_task_queues.enqueue_run(
         context=context,
-        service_config=service_config,
-        request_config=request_config,
+        attachments=list(
+            await get_attachments(
+                context,
+                summarizer=construct_attachment_summarizer(
+                    service_config=service_config,
+                    request_config=request_config,
+                ),
+            )
+        ),
+        archive_summarizer=construct_archive_summarizer(
+            service_config=service_config,
+            request_config=request_config,
+        ),
         archive_task_config=ArchiveTaskConfig(
             chunk_token_count_threshold=config.chat_context_config.archive_token_threshold
         ),
17 changes: 14 additions & 3 deletions assistants/codespace-assistant/assistant/response/step_handler.py
@@ -4,7 +4,11 @@
 from typing import Any, List
 
 import deepmerge
-from assistant_extensions.chat_context_toolkit.message_history import chat_context_toolkit_message_provider_for
+from assistant_extensions.attachments import get_attachments
+from assistant_extensions.chat_context_toolkit.message_history import (
+    chat_context_toolkit_message_provider_for,
+    construct_attachment_summarizer,
+)
 from assistant_extensions.chat_context_toolkit.virtual_filesystem import (
     archive_file_source_mount,
     attachments_file_source_mount,
@@ -100,8 +104,15 @@ async def handle_error(error_message: str, error_debug: dict[str, Any] | None =
     history_message_provider = chat_context_toolkit_message_provider_for(
         context=context,
         tool_abbreviations=abbreviations.tool_abbreviations,
-        service_config=service_config,
-        request_config=request_config,
+        attachments=list(
+            await get_attachments(
+                context,
+                summarizer=construct_attachment_summarizer(
+                    service_config=service_config,
+                    request_config=request_config,
+                ),
+            )
+        ),
     )
 
     build_request_result = await build_request(
20 changes: 15 additions & 5 deletions assistants/document-assistant/assistant/response/responder.py
@@ -9,7 +9,11 @@
 
 import deepmerge
 import pendulum
-from assistant_extensions.chat_context_toolkit.message_history import chat_context_toolkit_message_provider_for
+from assistant_extensions.attachments import get_attachments
+from assistant_extensions.chat_context_toolkit.message_history import (
+    chat_context_toolkit_message_provider_for,
+    construct_attachment_summarizer,
+)
 from assistant_extensions.chat_context_toolkit.virtual_filesystem import (
     archive_file_source_mount,
 )
@@ -400,9 +404,15 @@ async def _construct_prompt(self) -> tuple[list, list[ChatCompletionMessageParam
         message_provider = chat_context_toolkit_message_provider_for(
             context=self.context,
             tool_abbreviations=tool_abbreviations,
-            # use the fast client config for the attachment summarization that the message provider does
-            service_config=self.config.generative_ai_fast_client_config.service_config,
-            request_config=self.config.generative_ai_fast_client_config.request_config,
+            attachments=list(
+                await get_attachments(
+                    self.context,
+                    summarizer=construct_attachment_summarizer(
+                        service_config=self.config.generative_ai_fast_client_config.service_config,
+                        request_config=self.config.generative_ai_fast_client_config.request_config,
+                    ),
+                )
+            ),
         )
         system_prompt_token_count = num_tokens_from_message(main_system_prompt, model="gpt-4o")
         tool_token_count = num_tokens_from_tools(tools, model="gpt-4o")
@@ -421,7 +431,7 @@ async def _construct_prompt(self) -> tuple[list, list[ChatCompletionMessageParam
         chat_history: list[ChatCompletionMessageParam] = list(budgeted_messages_result.messages)
         chat_history.insert(0, main_system_prompt)
 
-        logging.info("The system prompt has been constructed.")
+        logger.info("The system prompt has been constructed.")
         # Update telemetry for inspector
         self.latest_telemetry.system_prompt_tokens = system_prompt_token_count
         self.latest_telemetry.tool_tokens = tool_token_count
assistant_extensions/attachments/_attachments.py
@@ -118,7 +118,7 @@ async def get_completion_messages_for_attachments(
         self,
         context: ConversationContext,
         config: AttachmentsConfigModel,
-        include_filenames: list[str] | None = None,
+        include_filenames: list[str] = [],
         exclude_filenames: list[str] = [],
         summarizer: Summarizer | None = None,
     ) -> Sequence[CompletionMessage]:
@@ -143,6 +143,7 @@ async def get_completion_messages_for_attachments(
             error_handler=self._error_handler,
             include_filenames=include_filenames,
             exclude_filenames=exclude_filenames,
+            summarizer=summarizer,
         )
 
         if not attachments:
@@ -159,14 +160,14 @@ async def get_completion_messages_for_attachments(
     async def get_attachment_filenames(
         self,
         context: ConversationContext,
-        include_filenames: list[str] | None = None,
+        include_filenames: list[str] = [],
         exclude_filenames: list[str] = [],
     ) -> list[str]:
         files_response = await context.list_files()
 
         # for all files, get the attachment
         for file in files_response.files:
-            if include_filenames is not None and file.filename not in include_filenames:
+            if include_filenames and file.filename not in include_filenames:
                 continue
             if file.filename in exclude_filenames:
                 continue
@@ -226,33 +227,33 @@ async def default_error_handler(context: ConversationContext, filename: str, e:
 
 async def get_attachments(
     context: ConversationContext,
-    include_filenames: list[str] | None,
-    exclude_filenames: list[str],
+    exclude_filenames: list[str] = [],
+    include_filenames: list[str] = [],
     error_handler: AttachmentProcessingErrorHandler = default_error_handler,
     summarizer: Summarizer | None = None,
-) -> Sequence[Attachment]:
+) -> list[Attachment]:
     """
     Gets all attachments for the current state of the conversation, updating the cache as needed.
     """
 
     # get all files in the conversation
     files_response = await context.list_files()
 
+    # delete cached attachments that are no longer in the conversation
+    filenames = {file.filename for file in files_response.files}
+    asyncio.create_task(_delete_attachments_not_in(context, filenames))
+
     attachments = []
     # for all files, get the attachment
     for file in files_response.files:
-        if include_filenames is not None and file.filename not in include_filenames:
+        if include_filenames and file.filename not in include_filenames:
             continue
         if file.filename in exclude_filenames:
             continue
 
         attachment = await _get_attachment_for_file(context, file, {}, error_handler, summarizer=summarizer)
         attachments.append(attachment)
 
-    # delete cached attachments that are no longer in the conversation
-    filenames = {file.filename for file in files_response.files}
-    await _delete_attachments_not_in(context, filenames)
-
     return attachments
 
 
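The net effect in this file: the filename filters now default to empty lists (an empty include list now means "include everything", where previously the None sentinel carried that meaning), the return type is a concrete list, and stale-cache cleanup is kicked off concurrently via asyncio.create_task instead of blocking before the return. A minimal sketch of the updated call shape, assuming a ConversationContext plus OpenAI service/request configs are already in hand (the collect_attachments wrapper is hypothetical, not part of this PR):

    from assistant_extensions.attachments import get_attachments
    from assistant_extensions.chat_context_toolkit.message_history import (
        construct_attachment_summarizer,
    )


    async def collect_attachments(context, service_config, request_config):
        # Filename filters default to [] (no filtering), so only the
        # summarizer needs to be supplied for the common case.
        return await get_attachments(
            context,
            summarizer=construct_attachment_summarizer(
                service_config=service_config,
                request_config=request_config,
            ),
        )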
assistant_extensions/chat_context_toolkit/archive/__init__.py
@@ -2,8 +2,9 @@
 Provides the ArchiveTaskQueues class, for integrating with the chat context toolkit's archiving functionality.
 """
 
-from ._archive import ArchiveTaskQueues
+from ._archive import ArchiveTaskQueues, construct_archive_summarizer
 
 __all__ = [
     "ArchiveTaskQueues",
+    "construct_archive_summarizer",
 ]
assistant_extensions/chat_context_toolkit/archive/_archive.py
@@ -3,11 +3,12 @@
 from chat_context_toolkit.archive import ArchiveReader, ArchiveTaskConfig, ArchiveTaskQueue, StorageProvider
 from chat_context_toolkit.archive import MessageProvider as ArchiveMessageProvider
 from chat_context_toolkit.archive.summarization import LLMArchiveSummarizer, LLMArchiveSummarizerConfig
-from chat_context_toolkit.history.tool_abbreviations import ToolAbbreviations
 from openai_client import OpenAIRequestConfig, ServiceConfig, create_client
 from openai_client.tokens import num_tokens_from_messages
 from semantic_workbench_assistant.assistant_app import ConversationContext, storage_directory_for_context
 
+from assistant_extensions.attachments._model import Attachment
+
 from ..message_history import chat_context_toolkit_message_provider_for
 
 
@@ -46,21 +47,30 @@ async def list_files(self, relative_directory_path: PurePath) -> list[PurePath]:
 
 
 def archive_message_provider_for(
-    context: ConversationContext, service_config: ServiceConfig, request_config: OpenAIRequestConfig
+    context: ConversationContext,
+    attachments: list[Attachment],
 ) -> ArchiveMessageProvider:
     """Create an archive message provider for the provided context."""
     return chat_context_toolkit_message_provider_for(
         context=context,
-        tool_abbreviations=ToolAbbreviations(),
-        service_config=service_config,
-        request_config=request_config,
+        attachments=attachments,
     )
 
 
-def _archive_task_queue_for(
-    context: ConversationContext,
+def construct_archive_summarizer(
     service_config: ServiceConfig,
     request_config: OpenAIRequestConfig,
+) -> LLMArchiveSummarizer:
+    return LLMArchiveSummarizer(
+        client_factory=lambda: create_client(service_config),
+        llm_config=LLMArchiveSummarizerConfig(model=request_config.model),
+    )
+
+
+def _archive_task_queue_for(
+    context: ConversationContext,
+    attachments: list[Attachment],
+    archive_summarizer: LLMArchiveSummarizer,
     archive_task_config: ArchiveTaskConfig = ArchiveTaskConfig(),
     token_counting_model: str = "gpt-4o",
     archive_storage_sub_directory: str = "archives",
@@ -71,13 +81,11 @@ def _archive_task_queue_for(
     return ArchiveTaskQueue(
         storage_provider=ArchiveStorageProvider(context=context, sub_directory=archive_storage_sub_directory),
         message_provider=archive_message_provider_for(
-            context=context, service_config=service_config, request_config=request_config
+            context=context,
+            attachments=attachments,
         ),
         token_counter=lambda messages: num_tokens_from_messages(messages=messages, model=token_counting_model),
-        summarizer=LLMArchiveSummarizer(
-            client_factory=lambda: create_client(service_config),
-            llm_config=LLMArchiveSummarizerConfig(model=request_config.model),
-        ),
+        summarizer=archive_summarizer,
         config=archive_task_config,
     )
 
@@ -93,17 +101,17 @@ def __init__(self) -> None:
     async def enqueue_run(
         self,
         context: ConversationContext,
-        service_config: ServiceConfig,
-        request_config: OpenAIRequestConfig,
+        attachments: list[Attachment],
+        archive_summarizer: LLMArchiveSummarizer,
         archive_task_config: ArchiveTaskConfig = ArchiveTaskConfig(),
     ) -> None:
        """Get the archive task queue for the given context, creating it if it does not exist."""
        context_id = context.id
        if context_id not in self._queues:
            self._queues[context_id] = _archive_task_queue_for(
                context=context,
-                service_config=service_config,
-                request_config=request_config,
+                attachments=attachments,
+                archive_summarizer=archive_summarizer,
                archive_task_config=archive_task_config,
            )
        await self._queues[context_id].enqueue_run()
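With this refactor, the archive task queue no longer builds its own LLM summarizer from service/request configs; callers construct one with construct_archive_summarizer and inject it, alongside the pre-fetched attachments. A minimal sketch of the new entry point, mirroring the response.py change above (the enqueue_archive wrapper and its argument names are assumed for illustration):

    from assistant_extensions.chat_context_toolkit.archive import (
        ArchiveTaskQueues,
        construct_archive_summarizer,
    )

    archive_task_queues = ArchiveTaskQueues()


    async def enqueue_archive(context, attachments, service_config, request_config):
        # The summarizer and attachments are injected by the caller; the
        # queue no longer needs service/request configs to build them itself.
        # archive_task_config is omitted here, falling back to its default.
        await archive_task_queues.enqueue_run(
            context=context,
            attachments=attachments,
            archive_summarizer=construct_archive_summarizer(
                service_config=service_config,
                request_config=request_config,
            ),
        )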
assistant_extensions/chat_context_toolkit/message_history/__init__.py
@@ -2,8 +2,9 @@
 Provides a message history provider for the chat context toolkit's history management.
 """
 
-from ._history import chat_context_toolkit_message_provider_for
+from ._history import chat_context_toolkit_message_provider_for, construct_attachment_summarizer
 
 __all__ = [
     "chat_context_toolkit_message_provider_for",
+    "construct_attachment_summarizer",
 ]
assistant_extensions/chat_context_toolkit/message_history/_history.py
@@ -22,7 +22,7 @@
 )
 from semantic_workbench_assistant.assistant_app import ConversationContext
 
-from assistant_extensions.attachments._attachments import get_attachments
+from assistant_extensions.attachments._model import Attachment
 from assistant_extensions.attachments._summarizer import LLMConfig, LLMFileSummarizer
 
 from ._message import conversation_message_to_chat_message_param
@@ -132,11 +132,23 @@ class CompositeMessageProtocol(HistoryMessageProtocol, ArchiveMessageProtocol, P
     ...
 
 
-def chat_context_toolkit_message_provider_for(
-    context: ConversationContext,
-    tool_abbreviations: ToolAbbreviations,
+def construct_attachment_summarizer(
     service_config: ServiceConfig,
     request_config: OpenAIRequestConfig,
+) -> LLMFileSummarizer:
+    return LLMFileSummarizer(
+        llm_config=LLMConfig(
+            client_factory=lambda: create_client(service_config),
+            model=request_config.model,
+            max_response_tokens=request_config.response_tokens,
+        )
+    )
+
+
+def chat_context_toolkit_message_provider_for(
+    context: ConversationContext,
+    attachments: list[Attachment],
+    tool_abbreviations: ToolAbbreviations = ToolAbbreviations(),
 ) -> CompositeMessageProvider:
     """
     Create a composite message provider for the given workbench conversation context.
@@ -146,9 +158,8 @@ async def provider(after_id: str | None = None) -> Sequence[CompositeMessageProt
         history = await _get_history_manager_messages(
             context,
             tool_abbreviations=tool_abbreviations,
-            service_config=service_config,
-            request_config=request_config,
             after_id=after_id,
+            attachments=attachments,
         )
 
         return history
@@ -159,8 +170,7 @@ async def provider(after_id: str | None = None) -> Sequence[CompositeMessageProt
 async def _get_history_manager_messages(
     context: ConversationContext,
     tool_abbreviations: ToolAbbreviations,
-    service_config: ServiceConfig,
-    request_config: OpenAIRequestConfig,
+    attachments: list[Attachment],
     after_id: str | None = None,
 ) -> list[HistoryMessageWithAbbreviation]:
     """
@@ -175,21 +185,6 @@ async def _get_history_manager_messages(
     batch_size = 100
     before_message_id = None
 
-    attachments = list(
-        await get_attachments(
-            context=context,
-            include_filenames=None,
-            exclude_filenames=[],
-            summarizer=LLMFileSummarizer(
-                llm_config=LLMConfig(
-                    client_factory=lambda: create_client(service_config),
-                    model=request_config.model,
-                    max_response_tokens=request_config.response_tokens,
-                )
-            ),
-        )
-    )
-
     # each call to get_messages will return a maximum of `batch_size` messages
     # so we need to loop until all messages are retrieved
     while True:
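After this change, _get_history_manager_messages no longer fetches and summarizes attachments on every call; they arrive pre-resolved from the caller, and tool_abbreviations gains a default so archive callers can omit it. A minimal sketch of wiring a provider under the new signature (the build_history_provider wrapper and its argument names are assumed, not part of this PR):

    from assistant_extensions.attachments import get_attachments
    from assistant_extensions.chat_context_toolkit.message_history import (
        chat_context_toolkit_message_provider_for,
        construct_attachment_summarizer,
    )


    async def build_history_provider(context, service_config, request_config):
        # Resolve attachments once, up front, and hand them to the provider;
        # the provider no longer fetches or summarizes them internally.
        attachments = list(
            await get_attachments(
                context,
                summarizer=construct_attachment_summarizer(
                    service_config=service_config,
                    request_config=request_config,
                ),
            )
        )
        # tool_abbreviations now defaults to ToolAbbreviations(), so callers
        # that don't abbreviate tool messages can omit it.
        return chat_context_toolkit_message_provider_for(
            context=context,
            attachments=attachments,
        )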