From d401d0a5e2414112303df0a76cc51cf1fe59782f Mon Sep 17 00:00:00 2001 From: eyalz Date: Mon, 11 Aug 2025 17:33:03 +0300 Subject: [PATCH 1/6] Save work - before tests --- CLAUDE.md | 2 +- commands/create_spec.py | 12 +- commands/mcp.py | 2 +- cxk.py | 2 +- engine/__init__.py | 38 ++++++- engine/globals.py | 47 ++++++++ mcp_client/__init__.py | 0 mcp_client/client_session_provider.py | 130 ++++++++++++++++++++++ {util/mcp => mcp_client}/config.py | 0 prompt/__init__.py | 153 +++++++++++++++++++++++++- pyproject.toml | 1 + state.py | 2 +- tests/README.md | 11 ++ tests/e2e/test_e2e.py | 2 +- tests/templates/spec2.md | 13 +++ tests/test_runner.py | 7 +- 16 files changed, 402 insertions(+), 20 deletions(-) create mode 100644 engine/globals.py create mode 100644 mcp_client/__init__.py create mode 100644 mcp_client/client_session_provider.py rename {util/mcp => mcp_client}/config.py (100%) create mode 100644 tests/templates/spec2.md diff --git a/CLAUDE.md b/CLAUDE.md index 00d42ea..35c28b3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -15,7 +15,7 @@ ContextKit is a CLI tool and MCP (Model Context Protocol) client for creating sp - `mcp.py` - MCP server management (add-sse, add-stdio, add-http) - `create_spec.py` - Template rendering with variable collection - **Template engine**: `engine/` - Jinja2-based template processing -- **MCP configuration**: `util/mcp/config.py` - Pydantic models for MCP server configs +- **MCP configuration**: `mcp_client/config.py` - Pydantic models for MCP server configs - **User prompts**: `prompt/` - Interactive variable collection ## Core Concepts diff --git a/commands/create_spec.py b/commands/create_spec.py index 5adb5f1..8972d5e 100644 --- a/commands/create_spec.py +++ b/commands/create_spec.py @@ -4,11 +4,13 @@ import sys from engine import TemplateEngine, TemplateParseError -from prompt import collect_var_value +from prompt import PromptHelper +from state import State async def handle_create_spec( spec_template: str | None, + state: State, output_file: str | None = None, var_overrides: list[str] | None = None, verbose: bool = False, @@ -16,6 +18,8 @@ async def handle_create_spec( log_level = logging.DEBUG if verbose else logging.WARNING logging.basicConfig(level=log_level, format="%(message)s", force=True) + prompt_helper = PromptHelper(state) + # Detect piped input (stdin not a TTY) and ensure there's data before using it stdin_piped = not sys.stdin.isatty() @@ -30,7 +34,7 @@ async def handle_create_spec( logging.error(f"Error: Template file '{spec_template}' not found") sys.exit(1) - template_engine = TemplateEngine.from_file(template_path) + template_engine = TemplateEngine.from_file(template_path, state, prompt_helper) elif stdin_piped: try: template_str = sys.stdin.read() @@ -41,7 +45,7 @@ async def handle_create_spec( logging.error("Error: No data received on stdin for template") sys.exit(1) - template_engine = TemplateEngine.from_string(template_str) + template_engine = TemplateEngine.from_string(template_str, state, prompt_helper) else: logging.error("Error: Missing spec_template argument (or provide template via stdin)") sys.exit(1) @@ -68,7 +72,7 @@ async def handle_create_spec( raw_value = provided_vars[var] else: - raw_value = await collect_var_value(var) + raw_value = await prompt_helper.collect_var_value(var) logging.info(f" {var}: {raw_value}") # Try to parse as JSON if it looks like JSON diff --git a/commands/mcp.py b/commands/mcp.py index 19c2067..7ca6361 100644 --- a/commands/mcp.py +++ b/commands/mcp.py @@ -1,7 +1,7 @@ from pydantic import BaseModel from state import State -from util.mcp.config import SSEServerConfig, StdioServerConfig +from mcp_client.config import SSEServerConfig, StdioServerConfig class MCPAddSSEContext(BaseModel): diff --git a/cxk.py b/cxk.py index 2104247..ff41e4b 100644 --- a/cxk.py +++ b/cxk.py @@ -61,7 +61,7 @@ async def main(): await handle_init(state) elif args.command == "create-spec": - await handle_create_spec(args.spec_template, args.output, args.var, args.verbose) + await handle_create_spec(args.spec_template, state, args.output, args.var, args.verbose) elif args.command == "mcp": if not args.mcp_command: diff --git a/engine/__init__.py b/engine/__init__.py index ffa5604..025b703 100644 --- a/engine/__init__.py +++ b/engine/__init__.py @@ -3,6 +3,10 @@ from jinja2 import Environment, FileSystemLoader, Template, meta, select_autoescape +from engine.globals import create_mcp_tool_function +from prompt import PromptHelper +from state import State + class TemplateParseError(Exception): """Raised when template parsing fails""" @@ -14,20 +18,32 @@ class TemplateEngine: """Abstract away the jinja2 template engine with clean factory methods""" def __init__( - self, env: Environment, template: Template, source_path: Path | None = None, source_string: str | None = None + self, + env: Environment, + template: Template, + state: State, + prompt_helper: PromptHelper, + source_path: Path | None = None, + source_string: str | None = None, ): """Private constructor - use from_file() or from_string() instead""" self.env = env self.template = template self._source_path = source_path self._source_string = source_string + self._state = state + self._prompt_helper = prompt_helper + + # Add global functions to env + self.env.globals["mcp"] = create_mcp_tool_function(self._state, self._prompt_helper) @classmethod - def from_file(cls, path: str | Path) -> "TemplateEngine": + def from_file(cls, path: str | Path, state: State, prompt_helper: PromptHelper) -> "TemplateEngine": """Create a TemplateEngine from a template file. Args: path: Path to the template file + state: State object containing project configuration Returns: TemplateEngine instance @@ -55,14 +71,19 @@ def from_file(cls, path: str | Path) -> "TemplateEngine": except Exception as e: raise TemplateParseError(f"Failed to load template from {path}: {e}") from e - return cls(env=env, template=template, source_path=path, source_string=None) + return cls( + env=env, template=template, state=state, source_path=path, source_string=None, prompt_helper=prompt_helper + ) @classmethod - def from_string(cls, template_string: str, name: str = "") -> "TemplateEngine": + def from_string( + cls, template_string: str, state: State, prompt_helper: PromptHelper, name: str = "" + ) -> "TemplateEngine": """Create a TemplateEngine from a template string. Args: template_string: The template content as a string + state: State object containing project configuration name: Optional name for the template (for debugging) Returns: @@ -85,7 +106,14 @@ def from_string(cls, template_string: str, name: str = "") -> "TemplateEn # Store the name in the template for better error messages template.name = name - return cls(env=env, template=template, source_path=None, source_string=template_string) + return cls( + env=env, + template=template, + state=state, + source_path=None, + source_string=template_string, + prompt_helper=prompt_helper, + ) @property def source(self) -> str: diff --git a/engine/globals.py b/engine/globals.py new file mode 100644 index 0000000..c754110 --- /dev/null +++ b/engine/globals.py @@ -0,0 +1,47 @@ +import json +from typing import Any + +from mcp import types +from mcp.shared.metadata_utils import get_display_name + +from mcp_client.client_session_provider import get_client_session_by_server +from prompt import PromptHelper +from state import State + + +def create_mcp_tool_function(state: State, prompt_helper: PromptHelper): + """Create a call_mcp_tool function with state bound.""" + + async def call_mcp_tool(server: str, tool_name: str, args: dict) -> str | dict[str, Any]: + print(f"Calling MCP tool: {tool_name} on server: {server} with args: {args}") + async with get_client_session_by_server(server, state.mcp_config) as session: + # Initialize the connection + await session.initialize() + + tools = await session.list_tools() + # Call the tool with collected input + try: + full_arguments = await prompt_helper.get_full_args(tools, tool_name, args) + + print(f"Full arguments for tool {tool_name}: {full_arguments}") + result = await session.call_tool(tool_name, arguments=full_arguments) + result_unstructured = result.content[0] + if isinstance(result_unstructured, types.TextContent): + # Try to parse the result as JSON + if result_unstructured.text.startswith("{") or result_unstructured.text.startswith("["): + try: + return json.loads(result_unstructured.text) + except json.JSONDecodeError: + print("Failed to parse result as JSON, returning raw text.") + + return result_unstructured.text + else: + return "" + except Exception as e: + print(f"Error calling tool {tool_name}: {e}") + print("Available tools:") + for tool in tools.tools: + print(f" - {tool.name}: {get_display_name(tool)}") + return "" + + return call_mcp_tool diff --git a/mcp_client/__init__.py b/mcp_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcp_client/client_session_provider.py b/mcp_client/client_session_provider.py new file mode 100644 index 0000000..de5594a --- /dev/null +++ b/mcp_client/client_session_provider.py @@ -0,0 +1,130 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from urllib.parse import parse_qs, urlparse + +from .config import MCPServersConfig, SSEServerConfig, StdioServerConfig +from mcp import ClientSession, StdioServerParameters +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from pydantic import AnyUrl + + +class InMemoryTokenStorage(TokenStorage): + """Demo In-memory token storage implementation.""" + + def __init__(self): + self.tokens: OAuthToken | None = None + self.client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + """Get stored tokens.""" + return self.tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + """Store tokens.""" + self.tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + """Get stored client information.""" + return self.client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + """Store client information.""" + self.client_info = client_info + + +async def handle_redirect(auth_url: str) -> None: + print(f"Visit: {auth_url}") + + +async def handle_callback() -> tuple[str, str | None]: + callback_url = input("Paste callback URL: ") + params = parse_qs(urlparse(callback_url).query) + return params["code"][0], params.get("state", [None])[0] + + +@asynccontextmanager +async def get_stdio_session(server_params: StdioServerParameters): + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + yield session + + +@asynccontextmanager +async def get_streamablehttp_session(server_url: str): + oauth_auth = OAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="Example MCP Client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scope="user", + ), + storage=InMemoryTokenStorage(), + redirect_handler=handle_redirect, + callback_handler=handle_callback, + ) + # Connect to a streamable HTTP server + async with streamablehttp_client(server_url, auth=oauth_auth) as ( + read_stream, + write_stream, + _, + ): + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + yield session + + +@asynccontextmanager +async def get_sse_session(server_url: str): + oauth_auth = OAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="Example MCP Client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scope="user", + ), + storage=InMemoryTokenStorage(), + redirect_handler=handle_redirect, + callback_handler=handle_callback, + ) + + # Connect to a Server-Sent Events (SSE) server + async with sse_client( + url=server_url, + auth=oauth_auth, + timeout=60, + ) as (read_stream, write_stream): + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + yield session + + +@asynccontextmanager +async def get_client_session_by_server( + server_name: str, mcp_servers_config: MCPServersConfig +) -> AsyncGenerator[ClientSession, None]: + # Find the server configuration by name + server_config = mcp_servers_config.mcpServers.get(server_name) + if not server_config: + raise ValueError(f"Server '{server_name}' not found in configuration.") + if isinstance(server_config, StdioServerConfig): + async with get_stdio_session( + StdioServerParameters( + command=server_config.command, + args=server_config.args or [], + env=server_config.env, + ) + ) as session: + yield session + elif isinstance(server_config, SSEServerConfig): + async with get_sse_session(server_config.url) as session: + yield session + else: + raise ValueError(f"Unsupported server type for '{server_name}': {type(server_config)}") diff --git a/util/mcp/config.py b/mcp_client/config.py similarity index 100% rename from util/mcp/config.py rename to mcp_client/config.py diff --git a/prompt/__init__.py b/prompt/__init__.py index f937680..46e3a95 100644 --- a/prompt/__init__.py +++ b/prompt/__init__.py @@ -1,10 +1,157 @@ +from typing import TYPE_CHECKING + import questionary +from state import State # Add interactive prompt helpers using questionary that will collect values for # unspecified template variables. -async def collect_var_value(var_name: str) -> str: - """Collect a value for a variable using questionary.""" +class PromptHelper: + + def __init__( + self, + state: State, + ): + self._state = state + + async def collect_var_value(self, var_name: str) -> str: + """Collect a value for a variable using questionary.""" + + return await questionary.text(f"Please provide a value for '{var_name}':").ask_async() + + async def get_full_args(self, tools, tool_name, args): + """ + Get full arguments for the tool call, using collect_tool_input() to collect missing required fields. + + :param tools: List of available tools + :param tool_name: Name of the tool to call + :param args: Arguments provided by the user + :return: Full arguments including defaults and required fields + """ + selected_tool = next((t for t in tools.tools if t.name == tool_name), None) + if not selected_tool: + raise ValueError(f"Tool '{tool_name}' not found.") + + input_schema = selected_tool.inputSchema + if not input_schema: + return args + + # Use collect_tool_input to fill in missing required fields + full_args = await self.collect_tool_input(input_schema, args, include_optional=False) + + return full_args + + async def collect_tool_input(self, input_schema, existing_args=None, include_optional=True): + """ + Collect user input based on a schema using Questionary. + Skip fields that are already provided in existing_args. + + Args: + input_schema (dict): Schema with 'properties' and 'required' keys + existing_args (dict, optional): Already collected values + + Returns: + dict: Dictionary with field names as keys and collected values + + Example: + schema = { + 'properties': { + 'a': {'title': 'A', 'type': 'integer'}, + 'b': {'title': 'B', 'type': 'integer'} + }, + 'required': ['a', 'b'] + } + result = collect_input(schema) # Returns something like {"a": 5, "b": 3} + """ + properties = input_schema["properties"] + required = input_schema.get("required", []) + + if existing_args is None: + existing_args = {} + + values = existing_args.copy() + + for field_name, field_info in properties.items(): + # Skip if field is already provided in existing_args + if field_name in existing_args: + continue + + print(f"Collecting input for field: {field_info}") + title = field_info["title"] if "title" in field_info else field_name + field_type = field_info["type"] if "type" in field_info else "string" + field_desc = field_info.get("description", None) + is_required = field_name in required + + if not include_optional and not is_required: + # Skip optional fields if not requested + continue + + # Create prompt text + prompt_text = title + if is_required: + prompt_text += " (required)" + else: + prompt_text += " (optional)" + prompt_text += ":" + + if field_type == "integer": + # Create validator for integer fields + def make_integer_validator(required): + def validate_integer(value): + # Allow empty for optional fields + if not value.strip(): + if required: + return "This field is required" + return True + # Validate integer format + try: + int(value) + return True + except ValueError: + return "Please enter a valid integer" + + return validate_integer + + result = await questionary.text(prompt_text, validate=make_integer_validator(is_required)).ask_async() + + # Process the result + if result and result.strip(): + values[field_name] = int(result) + elif is_required: + # This shouldn't happen due to validation + raise ValueError(f"Required field {field_name} is missing") + # Optional fields that are empty are not included in output + + elif field_type == "string": + # Create validator for string fields + def make_string_validator(required): + def validate_string(value): + if not value.strip() and required: + return "This field is required" + return True + + return validate_string + + result = await questionary.text( + prompt_text, + validate=make_string_validator(is_required), + instruction=field_desc, + ).ask_async() + + # Process the result + if result and result.strip(): + values[field_name] = result + elif is_required: + # This shouldn't happen due to validation + raise ValueError(f"Required field {field_name} is missing") + # Optional fields that are empty are not included in output + + elif field_type == "boolean": + # Create a yes/no prompt for boolean fields + result = await questionary.confirm(prompt_text, default=False, qmark="?").ask_async() + + # Store the boolean value + values[field_name] = result - return await questionary.text(f"Please provide a value for '{var_name}':").ask_async() + return values diff --git a/pyproject.toml b/pyproject.toml index 5b306cb..05e00e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "jinja2>=3.1.6", + "mcp[cli]>=1.12.4", "pydantic>=2.11.7", "questionary>=2.1.0", ] diff --git a/state.py b/state.py index bfc9ca5..8856227 100644 --- a/state.py +++ b/state.py @@ -1,7 +1,7 @@ import json from pathlib import Path -from util.mcp.config import MCPServersConfig +from mcp_client.config import MCPServersConfig class State: diff --git a/tests/README.md b/tests/README.md index 3555fb3..2ea768e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -7,6 +7,12 @@ uv run pytest ## Manual Testing +### Initialize the environment and adding mcp configuration +``` +uv run cxk.py init +uv run cxk.py mcp add-stdio server-name2 --env KEY=value -- python server.py +``` + ``` uv run cxk.py create-spec tests/templates/spec1.md ``` @@ -27,4 +33,9 @@ uv run cxk.py create-spec tests/templates/spec1.md --verbose --var additional_co ### Piped ``` cat tests/templates/spec1.md | uv run cxk.py create-spec --verbose --var ticket='{"id":1}' --var additional_context=2 +``` + +### With MCP functions +``` +uv run cxk.py create-spec tests/templates/spec2.md ``` \ No newline at end of file diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index c575967..a199888 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -610,7 +610,7 @@ def test_create_spec_pipe_mode(self, temp_non_git_dir): "create-spec", "--verbose", "--var", - "ticket={\"id\":1}", + 'ticket={"id":1}', "--var", "additional_context=test context", ], diff --git a/tests/templates/spec2.md b/tests/templates/spec2.md new file mode 100644 index 0000000..7a81b18 --- /dev/null +++ b/tests/templates/spec2.md @@ -0,0 +1,13 @@ +# Task Template + +## Ticket description + +{% set ticket = mcp('jira', 'getJiraIssue', {'issueKey': 'ACME-4432'}) %} + +{{ ticket.id }} + +{{ ticket.description }} + +## Additional context + +{{ additional_context }} \ No newline at end of file diff --git a/tests/test_runner.py b/tests/test_runner.py index 050b7c8..0933ee5 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Test runner script that patches collect_var_value for e2e testing.""" + import asyncio import sys from pathlib import Path @@ -19,12 +20,12 @@ async def mock_collect_var_value(var_name: str) -> str: "city": "New York", "weather": '{"condition": "sunny", "temp": "75F"}', "username": "testuser", - "user": "test_user" + "user": "test_user", } return mock_values.get(var_name, f"mock_value_{var_name}") if __name__ == "__main__": # Patch collect_var_value before running main - with patch('commands.create_spec.collect_var_value', side_effect=mock_collect_var_value): - asyncio.run(main()) \ No newline at end of file + with patch("prompt.PromptHelper.collect_var_value", side_effect=mock_collect_var_value): + asyncio.run(main()) From b81c6d7b913112485035b63a133b394281bf929f Mon Sep 17 00:00:00 2001 From: eyalz Date: Tue, 12 Aug 2025 12:59:21 +0300 Subject: [PATCH 2/6] manual test pass, before updating e2e --- commands/create_spec.py | 4 ---- cxk.py | 7 ++++-- engine/globals.py | 15 ++++++------ mcp_client/client_session_provider.py | 31 +++++++++++++++--------- mcp_client/mcp_logger.py | 32 +++++++++++++++++++++++++ prompt/__init__.py | 6 ++--- tests/README.md | 14 +++++++++-- tests/mcp_test_server.py | 34 +++++++++++++++++++++++++++ tests/templates/spec2.md | 3 ++- tests/templates/spec3.md | 18 ++++++++++++++ 10 files changed, 134 insertions(+), 30 deletions(-) create mode 100644 mcp_client/mcp_logger.py create mode 100644 tests/mcp_test_server.py create mode 100644 tests/templates/spec3.md diff --git a/commands/create_spec.py b/commands/create_spec.py index 8972d5e..1f7c069 100644 --- a/commands/create_spec.py +++ b/commands/create_spec.py @@ -13,11 +13,7 @@ async def handle_create_spec( state: State, output_file: str | None = None, var_overrides: list[str] | None = None, - verbose: bool = False, ): - log_level = logging.DEBUG if verbose else logging.WARNING - logging.basicConfig(level=log_level, format="%(message)s", force=True) - prompt_helper = PromptHelper(state) # Detect piped input (stdin not a TTY) and ensure there's data before using it diff --git a/cxk.py b/cxk.py index ff41e4b..84fa18f 100644 --- a/cxk.py +++ b/cxk.py @@ -1,5 +1,6 @@ import argparse import asyncio +import logging import sys from commands.create_spec import handle_create_spec @@ -61,7 +62,9 @@ async def main(): await handle_init(state) elif args.command == "create-spec": - await handle_create_spec(args.spec_template, state, args.output, args.var, args.verbose) + log_level = logging.DEBUG if args.verbose else logging.WARNING + logging.basicConfig(level=log_level, format="%(message)s", force=True) + await handle_create_spec(args.spec_template, state, args.output, args.var) elif args.command == "mcp": if not args.mcp_command: @@ -94,7 +97,7 @@ async def main(): await handle_mcp(state, mcp_context) except Exception as e: - print(f"Error: {e}", file=sys.stderr) + logging.exception(f"Error: {e}") sys.exit(1) diff --git a/engine/globals.py b/engine/globals.py index c754110..e0bc21b 100644 --- a/engine/globals.py +++ b/engine/globals.py @@ -1,4 +1,5 @@ import json +import logging from typing import Any from mcp import types @@ -13,8 +14,8 @@ def create_mcp_tool_function(state: State, prompt_helper: PromptHelper): """Create a call_mcp_tool function with state bound.""" async def call_mcp_tool(server: str, tool_name: str, args: dict) -> str | dict[str, Any]: - print(f"Calling MCP tool: {tool_name} on server: {server} with args: {args}") - async with get_client_session_by_server(server, state.mcp_config) as session: + logging.info(f"Calling MCP tool: {tool_name} on server: {server} with args: {args}") + async with get_client_session_by_server(server, state) as session: # Initialize the connection await session.initialize() @@ -23,7 +24,7 @@ async def call_mcp_tool(server: str, tool_name: str, args: dict) -> str | dict[s try: full_arguments = await prompt_helper.get_full_args(tools, tool_name, args) - print(f"Full arguments for tool {tool_name}: {full_arguments}") + logging.debug(f"Full arguments for tool {tool_name}: {full_arguments}") result = await session.call_tool(tool_name, arguments=full_arguments) result_unstructured = result.content[0] if isinstance(result_unstructured, types.TextContent): @@ -32,16 +33,16 @@ async def call_mcp_tool(server: str, tool_name: str, args: dict) -> str | dict[s try: return json.loads(result_unstructured.text) except json.JSONDecodeError: - print("Failed to parse result as JSON, returning raw text.") + logging.error("Failed to parse result as JSON, returning raw text.") return result_unstructured.text else: return "" except Exception as e: - print(f"Error calling tool {tool_name}: {e}") - print("Available tools:") + logging.error(f"Error calling tool {tool_name}: {e}") + logging.error("Available tools:") for tool in tools.tools: - print(f" - {tool.name}: {get_display_name(tool)}") + logging.error(f" - {tool.name}: {get_display_name(tool)}") return "" return call_mcp_tool diff --git a/mcp_client/client_session_provider.py b/mcp_client/client_session_provider.py index de5594a..44468a0 100644 --- a/mcp_client/client_session_provider.py +++ b/mcp_client/client_session_provider.py @@ -1,8 +1,11 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from pathlib import Path + +# Import State type for type hints +from typing import TYPE_CHECKING from urllib.parse import parse_qs, urlparse -from .config import MCPServersConfig, SSEServerConfig, StdioServerConfig from mcp import ClientSession, StdioServerParameters from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.sse import sse_client @@ -11,6 +14,12 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from pydantic import AnyUrl +from .config import SSEServerConfig, StdioServerConfig +from .mcp_logger import get_mcp_log_file + +if TYPE_CHECKING: + from state import State + class InMemoryTokenStorage(TokenStorage): """Demo In-memory token storage implementation.""" @@ -47,10 +56,11 @@ async def handle_callback() -> tuple[str, str | None]: @asynccontextmanager -async def get_stdio_session(server_params: StdioServerParameters): - async with stdio_client(server_params) as (read, write): - async with ClientSession(read, write) as session: - yield session +async def get_stdio_session(server_params: StdioServerParameters, config_dir: Path | None = None): + with get_mcp_log_file(config_dir) as errlog: + async with stdio_client(server_params, errlog=errlog) as (read, write): + async with ClientSession(read, write) as session: + yield session @asynccontextmanager @@ -85,7 +95,7 @@ async def get_sse_session(server_url: str): server_url=server_url, client_metadata=OAuthClientMetadata( client_name="Example MCP Client", - redirect_uris=[AnyUrl("http://localhost:3000/callback")], + redirect_uris=[AnyUrl("http://localhost:41008/callback")], grant_types=["authorization_code", "refresh_token"], response_types=["code"], scope="user", @@ -107,11 +117,9 @@ async def get_sse_session(server_url: str): @asynccontextmanager -async def get_client_session_by_server( - server_name: str, mcp_servers_config: MCPServersConfig -) -> AsyncGenerator[ClientSession, None]: +async def get_client_session_by_server(server_name: str, state: "State") -> AsyncGenerator[ClientSession, None]: # Find the server configuration by name - server_config = mcp_servers_config.mcpServers.get(server_name) + server_config = state.mcp_config.mcpServers.get(server_name) if not server_config: raise ValueError(f"Server '{server_name}' not found in configuration.") if isinstance(server_config, StdioServerConfig): @@ -120,7 +128,8 @@ async def get_client_session_by_server( command=server_config.command, args=server_config.args or [], env=server_config.env, - ) + ), + config_dir=state.config_dir, ) as session: yield session elif isinstance(server_config, SSEServerConfig): diff --git a/mcp_client/mcp_logger.py b/mcp_client/mcp_logger.py new file mode 100644 index 0000000..0932126 --- /dev/null +++ b/mcp_client/mcp_logger.py @@ -0,0 +1,32 @@ +from collections.abc import Generator +from contextlib import contextmanager +from pathlib import Path +from typing import TextIO + + +@contextmanager +def get_mcp_log_file(config_dir: Path | None) -> Generator[TextIO, None, None]: + """ + Context manager that provides a log file handle for MCP stderr output. + + Args: + config_dir: The configuration directory (typically .cxk) where logs should be stored. + If None, falls back to current directory. + + Yields: + TextIO: File handle for writing MCP stderr logs + """ + if config_dir is None: + log_dir = Path.cwd() / ".cxk" + else: + log_dir = config_dir + + # Ensure log directory exists + log_dir.mkdir(exist_ok=True) + + # Set up log file path + log_file_path = log_dir / "mcp.log" + + # Open log file for stderr output + with open(log_file_path, "a", encoding="utf-8") as errlog: + yield errlog diff --git a/prompt/__init__.py b/prompt/__init__.py index 46e3a95..af8dfe1 100644 --- a/prompt/__init__.py +++ b/prompt/__init__.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING +import logging import questionary + from state import State # Add interactive prompt helpers using questionary that will collect values for @@ -8,7 +9,6 @@ class PromptHelper: - def __init__( self, state: State, @@ -77,7 +77,7 @@ async def collect_tool_input(self, input_schema, existing_args=None, include_opt if field_name in existing_args: continue - print(f"Collecting input for field: {field_info}") + logging.info(f"Collecting input for field: {field_info}") title = field_info["title"] if "title" in field_info else field_name field_type = field_info["type"] if "type" in field_info else "string" field_desc = field_info.get("description", None) diff --git a/tests/README.md b/tests/README.md index 2ea768e..8cf9b15 100644 --- a/tests/README.md +++ b/tests/README.md @@ -11,6 +11,7 @@ uv run pytest ``` uv run cxk.py init uv run cxk.py mcp add-stdio server-name2 --env KEY=value -- python server.py +uv run cxk.py mcp add-stdio test-mcp -- uv run mcp run tests/mcp_test_server.py ``` ``` @@ -35,7 +36,16 @@ uv run cxk.py create-spec tests/templates/spec1.md --verbose --var additional_co cat tests/templates/spec1.md | uv run cxk.py create-spec --verbose --var ticket='{"id":1}' --var additional_context=2 ``` -### With MCP functions +### With MCP function (fully specified) ``` -uv run cxk.py create-spec tests/templates/spec2.md +uv run cxk.py create-spec tests/templates/spec2.md --var additional_context=aa +``` + +``` +uv run cxk.py create-spec tests/templates/spec2.md --var additional_context=aa --verbose --output res.md +``` + +### With MCP function (partially specified) +``` +uv run cxk.py create-spec tests/templates/spec3.md --var additional_context=aa ``` \ No newline at end of file diff --git a/tests/mcp_test_server.py b/tests/mcp_test_server.py new file mode 100644 index 0000000..648bfe5 --- /dev/null +++ b/tests/mcp_test_server.py @@ -0,0 +1,34 @@ +""" +FastMCP reference MCP server for testing and demonstration purposes. +""" + +from typing import Any + +from mcp.server.fastmcp import FastMCP + +# Create an MCP server +mcp = FastMCP("Test-Server", "1.0.0") + + +# Add an addition tool +@mcp.tool() +def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + +@mcp.tool() +def jsonTest(cloudId: str, ticketId: str, optional_other: str | None = None) -> dict[str, Any]: + """Mock JOSN test tool""" + return { + "id": f"{cloudId} - {ticketId}", + "summary": f"Summary for {ticketId}", + "description": "This is a mock Jira ticket description.", + } + + +# Add a dynamic greeting resource +@mcp.resource("greeting://{name}") +def get_greeting(name: str) -> str: + """Get a personalized greeting""" + return f"Hello, {name}!" diff --git a/tests/templates/spec2.md b/tests/templates/spec2.md index 7a81b18..8b15073 100644 --- a/tests/templates/spec2.md +++ b/tests/templates/spec2.md @@ -2,10 +2,11 @@ ## Ticket description -{% set ticket = mcp('jira', 'getJiraIssue', {'issueKey': 'ACME-4432'}) %} +{% set ticket = mcp('test-mcp', 'jsonTest', {'cloudId': '1234', 'ticketId': 'ACME-123'}) %} {{ ticket.id }} +### Description {{ ticket.description }} ## Additional context diff --git a/tests/templates/spec3.md b/tests/templates/spec3.md new file mode 100644 index 0000000..93778c9 --- /dev/null +++ b/tests/templates/spec3.md @@ -0,0 +1,18 @@ +# Task Template + +## Ticket description + +{% set ticket = mcp('test-mcp', 'jsonTest', {'cloudId': '1234'}) %} + +{{ ticket.id }} + +### Description +{{ ticket.description }} + +## Additional context + +{{ additional_context }} + +## Some math... + +{{ mcp('test-mcp', 'add', {'a': 5}) }} From bc2eeb1e5d676bc99eb628f490786b259a5f76be Mon Sep 17 00:00:00 2001 From: eyalz Date: Tue, 12 Aug 2025 13:03:15 +0300 Subject: [PATCH 3/6] Update CLAUDE.md --- CLAUDE.md | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 35c28b3..78cad78 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -14,8 +14,11 @@ ContextKit is a CLI tool and MCP (Model Context Protocol) client for creating sp - `init.py` - Project initialization (creates `.cxk/` config directory) - `mcp.py` - MCP server management (add-sse, add-stdio, add-http) - `create_spec.py` - Template rendering with variable collection -- **Template engine**: `engine/` - Jinja2-based template processing -- **MCP configuration**: `mcp_client/config.py` - Pydantic models for MCP server configs +- **Template engine**: `engine/` - Jinja2-based template processing with async support and MCP tool integration + - `engine/globals.py` - Global Jinja2 functions including `mcp()` for calling MCP tools from templates +- **MCP client**: `mcp_client/` - MCP protocol client implementation + - `mcp_client/config.py` - Pydantic models for MCP server configs (stdio, SSE, HTTP) + - `mcp_client/client_session_provider.py` - Connection management for MCP servers - **User prompts**: `prompt/` - Interactive variable collection ## Core Concepts @@ -24,6 +27,7 @@ ContextKit is a CLI tool and MCP (Model Context Protocol) client for creating sp 2. **MCP servers**: Configured via CLI commands, stored in `.cxk/mcp.json` 3. **Spec templates**: Jinja2 templates with variables that get filled from MCP resources 4. **Context variables**: Can be automatic MCP resources or user-provided values +5. **MCP tool functions**: Templates can call MCP tools directly using `{{ mcp('server', 'tool', args) }}` syntax ## Common Commands @@ -35,6 +39,9 @@ uv sync # Run tests uv run pytest +# Run specific test +uv run pytest tests/test_specific.py + # Linting and formatting uv run ruff check @@ -54,6 +61,12 @@ python cxk.py mcp add-http server-name http://localhost:8000 # Create spec from template python cxk.py create-spec path/to/template.md uv run cxk.py create-spec tests/templates/spec1.md --var additional_context=aa --var ticket='{"id":1}' + +# Create spec with output file +uv run cxk.py create-spec tests/templates/spec1.md --output result.md + +# Pipe template content +cat tests/templates/spec1.md | uv run cxk.py create-spec --var ticket='{"id":1}' ``` ## Key Files From 158611bfb4930cfd27301074d50cfd43b7d73f57 Mon Sep 17 00:00:00 2001 From: eyalz Date: Tue, 12 Aug 2025 14:13:59 +0300 Subject: [PATCH 4/6] Add e2e tests for full and partial mcp calls --- mcp_client/client_session_provider.py | 2 - tests/e2e/test_e2e.py | 105 ++++++++++++++++++++++++++ tests/mcp_test_server.py | 2 +- tests/test_runner.py | 45 ++++++++++- 4 files changed, 149 insertions(+), 5 deletions(-) diff --git a/mcp_client/client_session_provider.py b/mcp_client/client_session_provider.py index 44468a0..485723d 100644 --- a/mcp_client/client_session_provider.py +++ b/mcp_client/client_session_provider.py @@ -1,8 +1,6 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path - -# Import State type for type hints from typing import TYPE_CHECKING from urllib.parse import parse_qs, urlparse diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index a199888..09efa9d 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -656,3 +656,108 @@ def test_create_spec_pipe_mode_with_output_file(self, temp_non_git_dir): assert output_file.exists() content = output_file.read_text() assert "Piped template: Hello from pipe!" in content + + def test_create_spec_with_mcp_call(self, temp_git_repo): + """Test create-spec with template that includes MCP tool calls.""" + # Initialize project first + init_result = self.run_cli(["init"], cwd=temp_git_repo) + assert init_result.returncode == 0 + + # Add test MCP server + server_path = Path(__file__).parent.parent / "mcp_test_server.py" + add_server_result = self.run_cli( + ["mcp", "add-stdio", "test-mcp", "--", "uv", "run", "mcp", "run", str(server_path)], + cwd=temp_git_repo, + ) + assert add_server_result.returncode == 0 + + # Use the existing spec2.md template that has MCP calls + template_path = Path(__file__).parent.parent / "templates" / "spec2.md" + + # Run create-spec command with test runner and additional_context variable + result = self.run_cli( + [ + "create-spec", + "--verbose", + str(template_path), + "--var", + "additional_context=This is test context", + ], + cwd=temp_git_repo, + use_test_runner=True, + ) + + assert result.returncode == 0 + + # Verify that MCP tool was called and returned expected data + # The template uses: mcp('test-mcp', 'jsonTest', {'cloudId': '1234', 'ticketId': 'ACME-123'}) + # The jsonTest tool returns: {"id": "1234 - ACME-123", "summary": "Summary for ACME-123", + # "description": "This is a mock Jira ticket description."} + + # Check that the rendered template contains the expected MCP tool output + assert "1234 - ACME-123" in result.stdout # ticket.id from MCP call + assert "This is a mock Jira ticket description." in result.stdout # ticket.description from MCP call + assert "This is test context" in result.stdout # additional_context variable + + # Verify template structure is preserved + assert "# Task Template" in result.stdout + assert "## Ticket description" in result.stdout + assert "### Description" in result.stdout + assert "## Additional context" in result.stdout + + def test_create_spec_with_partial_mcp_call(self, temp_git_repo): + """Test create-spec with template that includes partial MCP tool calls requiring user input.""" + # Initialize project first + init_result = self.run_cli(["init"], cwd=temp_git_repo) + assert init_result.returncode == 0 + + # Add test MCP server + server_path = Path(__file__).parent.parent / "mcp_test_server.py" + add_server_result = self.run_cli( + ["mcp", "add-stdio", "test-mcp", "--", "uv", "run", "mcp", "run", str(server_path)], + cwd=temp_git_repo, + ) + assert add_server_result.returncode == 0 + + # Use the existing spec3.md template that has partial MCP calls + template_path = Path(__file__).parent.parent / "templates" / "spec3.md" + + # Run create-spec command with test runner and additional_context variable + # The template has: + # - mcp('test-mcp', 'jsonTest', {'cloudId': '1234'}) - missing 'ticketId' parameter + # - mcp('test-mcp', 'add', {'a': 5}) - missing 'b' parameter + result = self.run_cli( + [ + "create-spec", + "--verbose", + str(template_path), + "--var", + "additional_context=This is test context for partial MCP", + ], + cwd=temp_git_repo, + use_test_runner=True, + ) + + assert result.returncode == 0 + + # Verify that MCP tools were called with both provided and collected parameters + # For jsonTest: cloudId='1234' + ticketId from mock (should be 'mock_value_ticketId') + # Expected output: "1234 - mock_value_ticketId" + assert "1234 - mock_value_ticketId" in result.stdout + + # Verify the description from jsonTest call + assert "This is a mock Jira ticket description." in result.stdout + + # For add: a=5 + b from mock (should be 10), so result should be 15 + assert "15" in result.stdout + + # Verify additional_context variable + assert "This is test context for partial MCP" in result.stdout + + # Verify template structure is preserved + assert "# Task Template" in result.stdout + assert "## Ticket description" in result.stdout + assert "### Description" in result.stdout + assert "## Additional context" in result.stdout + assert "## Some math..." in result.stdout + diff --git a/tests/mcp_test_server.py b/tests/mcp_test_server.py index 648bfe5..987be3e 100644 --- a/tests/mcp_test_server.py +++ b/tests/mcp_test_server.py @@ -19,7 +19,7 @@ def add(a: int, b: int) -> int: @mcp.tool() def jsonTest(cloudId: str, ticketId: str, optional_other: str | None = None) -> dict[str, Any]: - """Mock JOSN test tool""" + """Mock JSON test tool""" return { "id": f"{cloudId} - {ticketId}", "summary": f"Summary for {ticketId}", diff --git a/tests/test_runner.py b/tests/test_runner.py index 0933ee5..1ca25d1 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -25,7 +25,48 @@ async def mock_collect_var_value(var_name: str) -> str: return mock_values.get(var_name, f"mock_value_{var_name}") +async def mock_collect_tool_input(input_schema, existing_args=None, include_optional=True): + """Mock implementation for collecting MCP tool input parameters.""" + if existing_args is None: + existing_args = {} + + properties = input_schema["properties"] + required = input_schema.get("required", []) + values = existing_args.copy() + + # Mock values for common MCP tool parameters + mock_tool_values = { + "ticketId": "mock_value_ticketId", + "b": 10, # For integer type parameters like 'b' in add function + "cloudId": "mock_cloudId", + } + + for field_name, field_info in properties.items(): + # Skip if field is already provided in existing_args + if field_name in existing_args: + continue + + is_required = field_name in required + field_type = field_info.get("type", "string") + + if not include_optional and not is_required: + continue + + # Provide mock value based on field type and name + if field_name in mock_tool_values: + values[field_name] = mock_tool_values[field_name] + elif field_type == "integer": + values[field_name] = 42 + elif field_type == "boolean": + values[field_name] = True + else: # string or other types + values[field_name] = f"mock_value_{field_name}" + + return values + + if __name__ == "__main__": - # Patch collect_var_value before running main - with patch("prompt.PromptHelper.collect_var_value", side_effect=mock_collect_var_value): + # Patch both collect_var_value and collect_tool_input before running main + with patch("prompt.PromptHelper.collect_var_value", side_effect=mock_collect_var_value), \ + patch("prompt.PromptHelper.collect_tool_input", side_effect=mock_collect_tool_input): asyncio.run(main()) From 940bb16844178fd5ec66dfcb4698397f010ca05e Mon Sep 17 00:00:00 2001 From: eyalz Date: Tue, 12 Aug 2025 14:43:49 +0300 Subject: [PATCH 5/6] CR fix --- commands/create_spec.py | 13 +++---------- commands/mcp.py | 2 +- engine/globals.py | 11 ++--------- util/parse.py | 20 ++++++++++++++++++++ 4 files changed, 26 insertions(+), 20 deletions(-) create mode 100644 util/parse.py diff --git a/commands/create_spec.py b/commands/create_spec.py index 1f7c069..f2c8c07 100644 --- a/commands/create_spec.py +++ b/commands/create_spec.py @@ -1,4 +1,3 @@ -import json import logging import os import sys @@ -6,6 +5,7 @@ from engine import TemplateEngine, TemplateParseError from prompt import PromptHelper from state import State +from util.parse import parse_input_string async def handle_create_spec( @@ -71,15 +71,8 @@ async def handle_create_spec( raw_value = await prompt_helper.collect_var_value(var) logging.info(f" {var}: {raw_value}") - # Try to parse as JSON if it looks like JSON - if raw_value and (raw_value.strip().startswith("{") or raw_value.strip().startswith("[")): - try: - collected_vars[var] = json.loads(raw_value) - except json.JSONDecodeError: - # If it's not valid JSON, use as string - collected_vars[var] = raw_value - else: - collected_vars[var] = raw_value + collected_vars[var] = parse_input_string(raw_value) + else: logging.info("No variables found in template") diff --git a/commands/mcp.py b/commands/mcp.py index 7ca6361..e1e7314 100644 --- a/commands/mcp.py +++ b/commands/mcp.py @@ -1,7 +1,7 @@ from pydantic import BaseModel -from state import State from mcp_client.config import SSEServerConfig, StdioServerConfig +from state import State class MCPAddSSEContext(BaseModel): diff --git a/engine/globals.py b/engine/globals.py index e0bc21b..3ffc3f9 100644 --- a/engine/globals.py +++ b/engine/globals.py @@ -1,4 +1,3 @@ -import json import logging from typing import Any @@ -8,6 +7,7 @@ from mcp_client.client_session_provider import get_client_session_by_server from prompt import PromptHelper from state import State +from util.parse import parse_input_string def create_mcp_tool_function(state: State, prompt_helper: PromptHelper): @@ -28,14 +28,7 @@ async def call_mcp_tool(server: str, tool_name: str, args: dict) -> str | dict[s result = await session.call_tool(tool_name, arguments=full_arguments) result_unstructured = result.content[0] if isinstance(result_unstructured, types.TextContent): - # Try to parse the result as JSON - if result_unstructured.text.startswith("{") or result_unstructured.text.startswith("["): - try: - return json.loads(result_unstructured.text) - except json.JSONDecodeError: - logging.error("Failed to parse result as JSON, returning raw text.") - - return result_unstructured.text + return parse_input_string(result_unstructured.text) else: return "" except Exception as e: diff --git a/util/parse.py b/util/parse.py new file mode 100644 index 0000000..831925e --- /dev/null +++ b/util/parse.py @@ -0,0 +1,20 @@ +import json +from typing import Any + + +def parse_input_string(value: str) -> str | dict[str, Any]: + """Parse a string input that may be JSON or a simple string. + + Args: + value (str): The input string to parse. + + Returns: + str | dict[str, Any]: Parsed value, either as a string or a dictionary. + """ + value = value.strip() + if value.startswith("{") or value.startswith("["): + try: + return json.loads(value) + except json.JSONDecodeError: + return value + return value From 1674b8f6143e68f03b8b544e9757808634db4611 Mon Sep 17 00:00:00 2001 From: eyalz Date: Tue, 12 Aug 2025 14:51:43 +0300 Subject: [PATCH 6/6] CR fix --- tests/test_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_runner.py b/tests/test_runner.py index 1ca25d1..63c8218 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -66,7 +66,8 @@ async def mock_collect_tool_input(input_schema, existing_args=None, include_opti if __name__ == "__main__": - # Patch both collect_var_value and collect_tool_input before running main - with patch("prompt.PromptHelper.collect_var_value", side_effect=mock_collect_var_value), \ - patch("prompt.PromptHelper.collect_tool_input", side_effect=mock_collect_tool_input): + with ( + patch("prompt.PromptHelper.collect_var_value", side_effect=mock_collect_var_value), + patch("prompt.PromptHelper.collect_tool_input", side_effect=mock_collect_tool_input), + ): asyncio.run(main())