diff --git a/CLAUDE.md b/CLAUDE.md index 0e2b54b..25fac9e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,6 +8,7 @@ - Run commands should be prefixed with `uv`: `uv run ...` - Use `asyncio` features, if such is needed - Prefer early returns +- Private methods always go below public methods - Absolutely no useless comments! Every class and method does not need to be documented (unless it is legitimetly complex or "lib-ish") - Imports belong at the top of files, not inside functions (unless needed to avoid circular imports) diff --git a/src/agent_chat_cli/components/chat_history.py b/src/agent_chat_cli/components/chat_history.py index c90969c..b4397b7 100644 --- a/src/agent_chat_cli/components/chat_history.py +++ b/src/agent_chat_cli/components/chat_history.py @@ -13,8 +13,8 @@ class ChatHistory(Container): def add_message(self, message: Message) -> None: - widget = self._create_message(message) - self.mount(widget) + message_item = self._create_message(message) + self.mount(message_item) def _create_message( self, message: Message diff --git a/src/agent_chat_cli/components/model_selection_menu.py b/src/agent_chat_cli/components/model_selection_menu.py new file mode 100644 index 0000000..e88d0ed --- /dev/null +++ b/src/agent_chat_cli/components/model_selection_menu.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from textual.widget import Widget +from textual.app import ComposeResult +from textual.containers import VerticalScroll +from textual.widgets import OptionList +from textual.widgets.option_list import Option + +if TYPE_CHECKING: + from agent_chat_cli.core.actions import Actions + +MODELS = [ + {"id": "sonnet", "label": "Sonnet"}, + {"id": "haiku", "label": "Haiku"}, + {"id": "opus", "label": "Opus"}, +] + + +class ModelSelectionMenu(Widget): + def __init__(self, actions: Actions) -> None: + super().__init__() + self.actions = actions + + def compose(self) -> ComposeResult: + yield OptionList(*[Option(model["label"], id=model["id"]) for model in MODELS]) + + def show(self) -> None: + self.add_class("visible") + + scroll_containers = self.app.query(VerticalScroll) + if scroll_containers: + scroll_containers.first().scroll_end(animate=False) + + option_list = self.query_one(OptionList) + option_list.highlighted = 0 + option_list.focus() + + def hide(self) -> None: + self.remove_class("visible") + + @property + def is_visible(self) -> bool: + return self.has_class("visible") + + async def on_option_list_option_selected( + self, event: OptionList.OptionSelected + ) -> None: + self.hide() + + if event.option_id: + await self.actions.change_model(event.option_id) diff --git a/src/agent_chat_cli/components/slash_command_menu.py b/src/agent_chat_cli/components/slash_command_menu.py index 6f663fa..9cda1c3 100644 --- a/src/agent_chat_cli/components/slash_command_menu.py +++ b/src/agent_chat_cli/components/slash_command_menu.py @@ -11,6 +11,7 @@ COMMANDS = [ {"id": "new", "label": "/new - Start new conversation"}, {"id": "clear", "label": "/clear - Clear chat history"}, + {"id": "model", "label": "/model - Change model"}, {"id": "save", "label": "/save - Save conversation to markdown"}, {"id": "exit", "label": "/exit - Exit"}, ] @@ -84,5 +85,7 @@ async def on_option_list_option_selected( await self.actions.clear() case "new": await self.actions.new() + case "model": + self.actions.show_model_menu() case "save": await self.actions.save() diff --git a/src/agent_chat_cli/components/user_input.py b/src/agent_chat_cli/components/user_input.py index b9edb0a..4c3994b 100644 --- a/src/agent_chat_cli/components/user_input.py +++ b/src/agent_chat_cli/components/user_input.py @@ -7,6 +7,7 @@ from agent_chat_cli.components.caret import Caret from agent_chat_cli.components.flex import Flex from agent_chat_cli.components.slash_command_menu import SlashCommandMenu +from agent_chat_cli.components.model_selection_menu import ModelSelectionMenu from agent_chat_cli.core.actions import Actions from agent_chat_cli.utils.enums import Key @@ -35,6 +36,7 @@ def compose(self) -> ComposeResult: yield SlashCommandMenu( actions=self.actions, on_filter_change=self._on_filter_change ) + yield ModelSelectionMenu(actions=self.actions) def _on_filter_change(self, char: str) -> None: text_area = self.query_one(TextArea) @@ -51,13 +53,15 @@ def on_descendant_blur(self, event: DescendantBlur) -> None: if not self.display: return - menu = self.query_one(SlashCommandMenu) + menu = self._get_visible_menu() - if isinstance(event.widget, TextArea) and not menu.is_visible: + if isinstance(event.widget, TextArea) and not menu: event.widget.focus(scroll_visible=False) - elif isinstance(event.widget, OptionList) and menu.is_visible: - menu.hide() - self.query_one(TextArea).focus(scroll_visible=False) + elif isinstance(event.widget, OptionList) and menu: + menu_option_list = menu.query_one(OptionList) + if event.widget == menu_option_list: + menu.hide() + self.query_one(TextArea).focus(scroll_visible=False) def on_text_area_changed(self, event: TextArea.Changed) -> None: menu = self.query_one(SlashCommandMenu) @@ -68,10 +72,10 @@ def on_text_area_changed(self, event: TextArea.Changed) -> None: menu.show() async def on_key(self, event) -> None: - menu = self.query_one(SlashCommandMenu) + menu = self._get_visible_menu() - if menu.is_visible: - self._close_menu(event) + if menu: + self._close_menu(event, menu) return if event.key == "up": @@ -92,9 +96,7 @@ def _insert_newline(self, event) -> None: input_widget = self.query_one(TextArea) input_widget.insert("\n") - def _close_menu(self, event) -> None: - menu = self.query_one(SlashCommandMenu) - + def _close_menu(self, event, menu: SlashCommandMenu | ModelSelectionMenu) -> None: if event.key == Key.ESCAPE.value: event.stop() event.prevent_default() @@ -104,7 +106,10 @@ def _close_menu(self, event) -> None: input_widget.focus() return - if event.key in (Key.BACKSPACE.value, Key.DELETE.value): + if isinstance(menu, SlashCommandMenu) and event.key in ( + Key.BACKSPACE.value, + Key.DELETE.value, + ): if menu.filter_text: menu.filter_text = menu.filter_text[:-1] menu._refresh_options() @@ -147,10 +152,21 @@ async def _navigate_history(self, event, direction: int) -> None: input_widget.text = self.message_history[self.history_index] input_widget.move_cursor_relative(rows=999, columns=999) + def _get_visible_menu(self) -> SlashCommandMenu | ModelSelectionMenu | None: + slash_menu = self.query_one(SlashCommandMenu) + if slash_menu.is_visible: + return slash_menu + + model_menu = self.query_one(ModelSelectionMenu) + if model_menu.is_visible: + return model_menu + + return None + async def action_submit(self) -> None: - menu = self.query_one(SlashCommandMenu) + menu = self._get_visible_menu() - if menu.is_visible: + if menu: option_list = menu.query_one(OptionList) option_list.action_select() input_widget = self.query_one(TextArea) diff --git a/src/agent_chat_cli/core/actions.py b/src/agent_chat_cli/core/actions.py index 3bf9d78..fe6ecd9 100644 --- a/src/agent_chat_cli/core/actions.py +++ b/src/agent_chat_cli/core/actions.py @@ -4,6 +4,7 @@ from agent_chat_cli.components.messages import RoleType from agent_chat_cli.components.chat_history import ChatHistory from agent_chat_cli.components.tool_permission_prompt import ToolPermissionPrompt +from agent_chat_cli.components.model_selection_menu import ModelSelectionMenu from agent_chat_cli.utils.logger import log_json from agent_chat_cli.utils.save_conversation import save_conversation @@ -72,5 +73,13 @@ async def save(self) -> None: f"Conversation saved to {file_path}", thinking=False ) + def show_model_menu(self) -> None: + model_menu = self.app.query_one(ModelSelectionMenu) + model_menu.show() + + async def change_model(self, model: str) -> None: + await self.app.agent_loop.change_model(model) + await self.post_system_message(f"Switched to {model}", thinking=False) + async def _query(self, user_input: str) -> None: await self.app.agent_loop.query_queue.put(user_input) diff --git a/src/agent_chat_cli/core/agent_loop.py b/src/agent_chat_cli/core/agent_loop.py index b62ea35..dcaae00 100644 --- a/src/agent_chat_cli/core/agent_loop.py +++ b/src/agent_chat_cli/core/agent_loop.py @@ -24,7 +24,12 @@ get_available_servers, get_sdk_config, ) -from agent_chat_cli.utils.enums import AppEventType, ContentType, ControlCommand +from agent_chat_cli.utils.enums import ( + AppEventType, + ContentType, + ControlCommand, + ModelChangeCommand, +) from agent_chat_cli.utils.logger import log_json from agent_chat_cli.utils.mcp_server_status import MCPServerStatus @@ -51,36 +56,33 @@ def __init__( self.client: ClaudeSDKClient - self.query_queue: asyncio.Queue[str | ControlCommand] = asyncio.Queue() + self.query_queue: asyncio.Queue[str | ControlCommand | ModelChangeCommand] = ( + asyncio.Queue() + ) self.permission_response_queue: asyncio.Queue[str] = asyncio.Queue() self.permission_lock = asyncio.Lock() self._running = False async def start(self) -> None: - mcp_servers = { - name: config.model_dump() for name, config in self.available_servers.items() - } - - await self._initialize_client(mcp_servers=mcp_servers) + await self._initialize_client() self._running = True while self._running: user_input = await self.query_queue.get() + if isinstance(user_input, ModelChangeCommand): + self.config.model = user_input.model + await self.client.disconnect() + await self._initialize_client() + continue + if isinstance(user_input, ControlCommand): if user_input == ControlCommand.NEW_CONVERSATION: await self.client.disconnect() - self.session_id = None - - mcp_servers = { - name: config.model_dump() - for name, config in self.available_servers.items() - } - - await self._initialize_client(mcp_servers=mcp_servers) + await self._initialize_client() continue self.app.ui_state.set_interrupting(False) @@ -97,7 +99,17 @@ async def start(self) -> None: AppEvent(type=AppEventType.RESULT, data=None) ) - async def _initialize_client(self, mcp_servers: dict) -> None: + async def change_model(self, model: str) -> None: + await self.query_queue.put( + ModelChangeCommand(ControlCommand.CHANGE_MODEL, model) + ) + + async def _initialize_client(self, mcp_servers: dict | None = None) -> None: + if mcp_servers is None: + mcp_servers = { + name: config.model_dump() + for name, config in self.available_servers.items() + } sdk_config = get_sdk_config(self.config) sdk_config["mcp_servers"] = mcp_servers diff --git a/src/agent_chat_cli/core/styles.tcss b/src/agent_chat_cli/core/styles.tcss index 852127d..3f63183 100644 --- a/src/agent_chat_cli/core/styles.tcss +++ b/src/agent_chat_cli/core/styles.tcss @@ -108,17 +108,17 @@ TextArea .text-area--cursor { padding-left: 2; } -SlashCommandMenu { +SlashCommandMenu, ModelSelectionMenu { height: auto; max-height: 10; display: none; } -SlashCommandMenu.visible { +SlashCommandMenu.visible, ModelSelectionMenu.visible { display: block; } -SlashCommandMenu OptionList { +SlashCommandMenu OptionList, ModelSelectionMenu OptionList { height: auto; max-height: 10; border: solid $primary; diff --git a/src/agent_chat_cli/utils/enums.py b/src/agent_chat_cli/utils/enums.py index f8cd799..2ee4779 100644 --- a/src/agent_chat_cli/utils/enums.py +++ b/src/agent_chat_cli/utils/enums.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import NamedTuple class AppEventType(Enum): @@ -20,10 +21,16 @@ class ContentType(Enum): class ControlCommand(Enum): NEW_CONVERSATION = "new_conversation" + CHANGE_MODEL = "change_model" EXIT = "exit" CLEAR = "clear" +class ModelChangeCommand(NamedTuple): + command: ControlCommand + model: str + + class Key(Enum): ENTER = "enter" ESCAPE = "escape" diff --git a/tests/components/test_model_selection_menu.py b/tests/components/test_model_selection_menu.py new file mode 100644 index 0000000..f31910a --- /dev/null +++ b/tests/components/test_model_selection_menu.py @@ -0,0 +1,97 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock + +from textual.app import App, ComposeResult +from textual.widgets import OptionList + +from agent_chat_cli.components.model_selection_menu import ModelSelectionMenu + + +class ModelSelectionMenuApp(App): + def __init__(self): + super().__init__() + self.mock_actions = MagicMock() + self.mock_actions.change_model = AsyncMock() + + def compose(self) -> ComposeResult: + yield ModelSelectionMenu(actions=self.mock_actions) + + +class TestModelSelectionMenuVisibility: + @pytest.fixture + def app(self): + return ModelSelectionMenuApp() + + async def test_hidden_by_default(self, app): + async with app.run_test(): + menu = app.query_one(ModelSelectionMenu) + + assert menu.is_visible is False + + async def test_show_makes_visible(self, app): + async with app.run_test(): + menu = app.query_one(ModelSelectionMenu) + menu.show() + + assert menu.is_visible is True + + async def test_hide_makes_invisible(self, app): + async with app.run_test(): + menu = app.query_one(ModelSelectionMenu) + menu.show() + menu.hide() + + assert menu.is_visible is False + + async def test_show_highlights_first_option(self, app): + async with app.run_test(): + menu = app.query_one(ModelSelectionMenu) + menu.show() + + option_list = menu.query_one(OptionList) + assert option_list.highlighted == 0 + + +class TestModelSelectionMenuSelection: + @pytest.fixture + def app(self): + return ModelSelectionMenuApp() + + async def test_sonnet_command_calls_change_model(self, app): + async with app.run_test() as pilot: + menu = app.query_one(ModelSelectionMenu) + menu.show() + + await pilot.press("enter") + + app.mock_actions.change_model.assert_called_once_with("sonnet") + + async def test_haiku_command_calls_change_model(self, app): + async with app.run_test() as pilot: + menu = app.query_one(ModelSelectionMenu) + menu.show() + + await pilot.press("down") + await pilot.press("enter") + + app.mock_actions.change_model.assert_called_once_with("haiku") + + async def test_opus_command_calls_change_model(self, app): + async with app.run_test() as pilot: + menu = app.query_one(ModelSelectionMenu) + menu.show() + + await pilot.press("down") + await pilot.press("down") + await pilot.press("enter") + + app.mock_actions.change_model.assert_called_once_with("opus") + + async def test_selection_hides_menu(self, app): + async with app.run_test() as pilot: + menu = app.query_one(ModelSelectionMenu) + menu.show() + + await pilot.press("enter") + + assert menu.is_visible is False diff --git a/tests/components/test_slash_command_menu.py b/tests/components/test_slash_command_menu.py index ce06245..f66150c 100644 --- a/tests/components/test_slash_command_menu.py +++ b/tests/components/test_slash_command_menu.py @@ -15,6 +15,7 @@ def __init__(self): self.mock_actions.clear = AsyncMock() self.mock_actions.new = AsyncMock() self.mock_actions.save = AsyncMock() + self.mock_actions.show_model_menu = MagicMock() def compose(self) -> ComposeResult: yield SlashCommandMenu(actions=self.mock_actions) @@ -79,6 +80,17 @@ async def test_clear_command_calls_clear(self, app): app.mock_actions.clear.assert_called_once() + async def test_model_command_calls_show_model_menu(self, app): + async with app.run_test() as pilot: + menu = app.query_one(SlashCommandMenu) + menu.show() + + await pilot.press("down") + await pilot.press("down") + await pilot.press("enter") + + app.mock_actions.show_model_menu.assert_called_once() + async def test_exit_command_calls_quit(self, app): async with app.run_test() as pilot: menu = app.query_one(SlashCommandMenu) @@ -87,6 +99,7 @@ async def test_exit_command_calls_quit(self, app): await pilot.press("down") await pilot.press("down") await pilot.press("down") + await pilot.press("down") await pilot.press("enter") app.mock_actions.quit.assert_called_once()