diff --git a/CLAUDE.md b/CLAUDE.md index 00d42ea..78cad78 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -14,8 +14,11 @@ ContextKit is a CLI tool and MCP (Model Context Protocol) client for creating sp - `init.py` - Project initialization (creates `.cxk/` config directory) - `mcp.py` - MCP server management (add-sse, add-stdio, add-http) - `create_spec.py` - Template rendering with variable collection -- **Template engine**: `engine/` - Jinja2-based template processing -- **MCP configuration**: `util/mcp/config.py` - Pydantic models for MCP server configs +- **Template engine**: `engine/` - Jinja2-based template processing with async support and MCP tool integration + - `engine/globals.py` - Global Jinja2 functions including `mcp()` for calling MCP tools from templates +- **MCP client**: `mcp_client/` - MCP protocol client implementation + - `mcp_client/config.py` - Pydantic models for MCP server configs (stdio, SSE, HTTP) + - `mcp_client/client_session_provider.py` - Connection management for MCP servers - **User prompts**: `prompt/` - Interactive variable collection ## Core Concepts @@ -24,6 +27,7 @@ ContextKit is a CLI tool and MCP (Model Context Protocol) client for creating sp 2. **MCP servers**: Configured via CLI commands, stored in `.cxk/mcp.json` 3. **Spec templates**: Jinja2 templates with variables that get filled from MCP resources 4. **Context variables**: Can be automatic MCP resources or user-provided values +5. **MCP tool functions**: Templates can call MCP tools directly using `{{ mcp('server', 'tool', args) }}` syntax ## Common Commands @@ -35,6 +39,9 @@ uv sync # Run tests uv run pytest +# Run specific test +uv run pytest tests/test_specific.py + # Linting and formatting uv run ruff check @@ -54,6 +61,12 @@ python cxk.py mcp add-http server-name http://localhost:8000 # Create spec from template python cxk.py create-spec path/to/template.md uv run cxk.py create-spec tests/templates/spec1.md --var additional_context=aa --var ticket='{"id":1}' + +# Create spec with output file +uv run cxk.py create-spec tests/templates/spec1.md --output result.md + +# Pipe template content +cat tests/templates/spec1.md | uv run cxk.py create-spec --var ticket='{"id":1}' ``` ## Key Files diff --git a/commands/create_spec.py b/commands/create_spec.py index 5adb5f1..f2c8c07 100644 --- a/commands/create_spec.py +++ b/commands/create_spec.py @@ -1,20 +1,20 @@ -import json import logging import os import sys from engine import TemplateEngine, TemplateParseError -from prompt import collect_var_value +from prompt import PromptHelper +from state import State +from util.parse import parse_input_string async def handle_create_spec( spec_template: str | None, + state: State, output_file: str | None = None, var_overrides: list[str] | None = None, - verbose: bool = False, ): - log_level = logging.DEBUG if verbose else logging.WARNING - logging.basicConfig(level=log_level, format="%(message)s", force=True) + prompt_helper = PromptHelper(state) # Detect piped input (stdin not a TTY) and ensure there's data before using it stdin_piped = not sys.stdin.isatty() @@ -30,7 +30,7 @@ async def handle_create_spec( logging.error(f"Error: Template file '{spec_template}' not found") sys.exit(1) - template_engine = TemplateEngine.from_file(template_path) + template_engine = TemplateEngine.from_file(template_path, state, prompt_helper) elif stdin_piped: try: template_str = sys.stdin.read() @@ -41,7 +41,7 @@ async def handle_create_spec( logging.error("Error: No data received on stdin for template") sys.exit(1) - template_engine = TemplateEngine.from_string(template_str) + template_engine = TemplateEngine.from_string(template_str, state, prompt_helper) else: logging.error("Error: Missing spec_template argument (or provide template via stdin)") sys.exit(1) @@ -68,18 +68,11 @@ async def handle_create_spec( raw_value = provided_vars[var] else: - raw_value = await collect_var_value(var) + raw_value = await prompt_helper.collect_var_value(var) logging.info(f" {var}: {raw_value}") - # Try to parse as JSON if it looks like JSON - if raw_value and (raw_value.strip().startswith("{") or raw_value.strip().startswith("[")): - try: - collected_vars[var] = json.loads(raw_value) - except json.JSONDecodeError: - # If it's not valid JSON, use as string - collected_vars[var] = raw_value - else: - collected_vars[var] = raw_value + collected_vars[var] = parse_input_string(raw_value) + else: logging.info("No variables found in template") diff --git a/commands/mcp.py b/commands/mcp.py index 19c2067..e1e7314 100644 --- a/commands/mcp.py +++ b/commands/mcp.py @@ -1,7 +1,7 @@ from pydantic import BaseModel +from mcp_client.config import SSEServerConfig, StdioServerConfig from state import State -from util.mcp.config import SSEServerConfig, StdioServerConfig class MCPAddSSEContext(BaseModel): diff --git a/cxk.py b/cxk.py index 2104247..84fa18f 100644 --- a/cxk.py +++ b/cxk.py @@ -1,5 +1,6 @@ import argparse import asyncio +import logging import sys from commands.create_spec import handle_create_spec @@ -61,7 +62,9 @@ async def main(): await handle_init(state) elif args.command == "create-spec": - await handle_create_spec(args.spec_template, args.output, args.var, args.verbose) + log_level = logging.DEBUG if args.verbose else logging.WARNING + logging.basicConfig(level=log_level, format="%(message)s", force=True) + await handle_create_spec(args.spec_template, state, args.output, args.var) elif args.command == "mcp": if not args.mcp_command: @@ -94,7 +97,7 @@ async def main(): await handle_mcp(state, mcp_context) except Exception as e: - print(f"Error: {e}", file=sys.stderr) + logging.exception(f"Error: {e}") sys.exit(1) diff --git a/engine/__init__.py b/engine/__init__.py index ffa5604..025b703 100644 --- a/engine/__init__.py +++ b/engine/__init__.py @@ -3,6 +3,10 @@ from jinja2 import Environment, FileSystemLoader, Template, meta, select_autoescape +from engine.globals import create_mcp_tool_function +from prompt import PromptHelper +from state import State + class TemplateParseError(Exception): """Raised when template parsing fails""" @@ -14,20 +18,32 @@ class TemplateEngine: """Abstract away the jinja2 template engine with clean factory methods""" def __init__( - self, env: Environment, template: Template, source_path: Path | None = None, source_string: str | None = None + self, + env: Environment, + template: Template, + state: State, + prompt_helper: PromptHelper, + source_path: Path | None = None, + source_string: str | None = None, ): """Private constructor - use from_file() or from_string() instead""" self.env = env self.template = template self._source_path = source_path self._source_string = source_string + self._state = state + self._prompt_helper = prompt_helper + + # Add global functions to env + self.env.globals["mcp"] = create_mcp_tool_function(self._state, self._prompt_helper) @classmethod - def from_file(cls, path: str | Path) -> "TemplateEngine": + def from_file(cls, path: str | Path, state: State, prompt_helper: PromptHelper) -> "TemplateEngine": """Create a TemplateEngine from a template file. Args: path: Path to the template file + state: State object containing project configuration Returns: TemplateEngine instance @@ -55,14 +71,19 @@ def from_file(cls, path: str | Path) -> "TemplateEngine": except Exception as e: raise TemplateParseError(f"Failed to load template from {path}: {e}") from e - return cls(env=env, template=template, source_path=path, source_string=None) + return cls( + env=env, template=template, state=state, source_path=path, source_string=None, prompt_helper=prompt_helper + ) @classmethod - def from_string(cls, template_string: str, name: str = "") -> "TemplateEngine": + def from_string( + cls, template_string: str, state: State, prompt_helper: PromptHelper, name: str = "" + ) -> "TemplateEngine": """Create a TemplateEngine from a template string. Args: template_string: The template content as a string + state: State object containing project configuration name: Optional name for the template (for debugging) Returns: @@ -85,7 +106,14 @@ def from_string(cls, template_string: str, name: str = "") -> "TemplateEn # Store the name in the template for better error messages template.name = name - return cls(env=env, template=template, source_path=None, source_string=template_string) + return cls( + env=env, + template=template, + state=state, + source_path=None, + source_string=template_string, + prompt_helper=prompt_helper, + ) @property def source(self) -> str: diff --git a/engine/globals.py b/engine/globals.py new file mode 100644 index 0000000..3ffc3f9 --- /dev/null +++ b/engine/globals.py @@ -0,0 +1,41 @@ +import logging +from typing import Any + +from mcp import types +from mcp.shared.metadata_utils import get_display_name + +from mcp_client.client_session_provider import get_client_session_by_server +from prompt import PromptHelper +from state import State +from util.parse import parse_input_string + + +def create_mcp_tool_function(state: State, prompt_helper: PromptHelper): + """Create a call_mcp_tool function with state bound.""" + + async def call_mcp_tool(server: str, tool_name: str, args: dict) -> str | dict[str, Any]: + logging.info(f"Calling MCP tool: {tool_name} on server: {server} with args: {args}") + async with get_client_session_by_server(server, state) as session: + # Initialize the connection + await session.initialize() + + tools = await session.list_tools() + # Call the tool with collected input + try: + full_arguments = await prompt_helper.get_full_args(tools, tool_name, args) + + logging.debug(f"Full arguments for tool {tool_name}: {full_arguments}") + result = await session.call_tool(tool_name, arguments=full_arguments) + result_unstructured = result.content[0] + if isinstance(result_unstructured, types.TextContent): + return parse_input_string(result_unstructured.text) + else: + return "" + except Exception as e: + logging.error(f"Error calling tool {tool_name}: {e}") + logging.error("Available tools:") + for tool in tools.tools: + logging.error(f" - {tool.name}: {get_display_name(tool)}") + return "" + + return call_mcp_tool diff --git a/mcp_client/__init__.py b/mcp_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcp_client/client_session_provider.py b/mcp_client/client_session_provider.py new file mode 100644 index 0000000..485723d --- /dev/null +++ b/mcp_client/client_session_provider.py @@ -0,0 +1,137 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from pathlib import Path +from typing import TYPE_CHECKING +from urllib.parse import parse_qs, urlparse + +from mcp import ClientSession, StdioServerParameters +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from pydantic import AnyUrl + +from .config import SSEServerConfig, StdioServerConfig +from .mcp_logger import get_mcp_log_file + +if TYPE_CHECKING: + from state import State + + +class InMemoryTokenStorage(TokenStorage): + """Demo In-memory token storage implementation.""" + + def __init__(self): + self.tokens: OAuthToken | None = None + self.client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + """Get stored tokens.""" + return self.tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + """Store tokens.""" + self.tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + """Get stored client information.""" + return self.client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + """Store client information.""" + self.client_info = client_info + + +async def handle_redirect(auth_url: str) -> None: + print(f"Visit: {auth_url}") + + +async def handle_callback() -> tuple[str, str | None]: + callback_url = input("Paste callback URL: ") + params = parse_qs(urlparse(callback_url).query) + return params["code"][0], params.get("state", [None])[0] + + +@asynccontextmanager +async def get_stdio_session(server_params: StdioServerParameters, config_dir: Path | None = None): + with get_mcp_log_file(config_dir) as errlog: + async with stdio_client(server_params, errlog=errlog) as (read, write): + async with ClientSession(read, write) as session: + yield session + + +@asynccontextmanager +async def get_streamablehttp_session(server_url: str): + oauth_auth = OAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="Example MCP Client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scope="user", + ), + storage=InMemoryTokenStorage(), + redirect_handler=handle_redirect, + callback_handler=handle_callback, + ) + # Connect to a streamable HTTP server + async with streamablehttp_client(server_url, auth=oauth_auth) as ( + read_stream, + write_stream, + _, + ): + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + yield session + + +@asynccontextmanager +async def get_sse_session(server_url: str): + oauth_auth = OAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="Example MCP Client", + redirect_uris=[AnyUrl("http://localhost:41008/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scope="user", + ), + storage=InMemoryTokenStorage(), + redirect_handler=handle_redirect, + callback_handler=handle_callback, + ) + + # Connect to a Server-Sent Events (SSE) server + async with sse_client( + url=server_url, + auth=oauth_auth, + timeout=60, + ) as (read_stream, write_stream): + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + yield session + + +@asynccontextmanager +async def get_client_session_by_server(server_name: str, state: "State") -> AsyncGenerator[ClientSession, None]: + # Find the server configuration by name + server_config = state.mcp_config.mcpServers.get(server_name) + if not server_config: + raise ValueError(f"Server '{server_name}' not found in configuration.") + if isinstance(server_config, StdioServerConfig): + async with get_stdio_session( + StdioServerParameters( + command=server_config.command, + args=server_config.args or [], + env=server_config.env, + ), + config_dir=state.config_dir, + ) as session: + yield session + elif isinstance(server_config, SSEServerConfig): + async with get_sse_session(server_config.url) as session: + yield session + else: + raise ValueError(f"Unsupported server type for '{server_name}': {type(server_config)}") diff --git a/util/mcp/config.py b/mcp_client/config.py similarity index 100% rename from util/mcp/config.py rename to mcp_client/config.py diff --git a/mcp_client/mcp_logger.py b/mcp_client/mcp_logger.py new file mode 100644 index 0000000..0932126 --- /dev/null +++ b/mcp_client/mcp_logger.py @@ -0,0 +1,32 @@ +from collections.abc import Generator +from contextlib import contextmanager +from pathlib import Path +from typing import TextIO + + +@contextmanager +def get_mcp_log_file(config_dir: Path | None) -> Generator[TextIO, None, None]: + """ + Context manager that provides a log file handle for MCP stderr output. + + Args: + config_dir: The configuration directory (typically .cxk) where logs should be stored. + If None, falls back to current directory. + + Yields: + TextIO: File handle for writing MCP stderr logs + """ + if config_dir is None: + log_dir = Path.cwd() / ".cxk" + else: + log_dir = config_dir + + # Ensure log directory exists + log_dir.mkdir(exist_ok=True) + + # Set up log file path + log_file_path = log_dir / "mcp.log" + + # Open log file for stderr output + with open(log_file_path, "a", encoding="utf-8") as errlog: + yield errlog diff --git a/prompt/__init__.py b/prompt/__init__.py index f937680..af8dfe1 100644 --- a/prompt/__init__.py +++ b/prompt/__init__.py @@ -1,10 +1,157 @@ +import logging + import questionary +from state import State + # Add interactive prompt helpers using questionary that will collect values for # unspecified template variables. -async def collect_var_value(var_name: str) -> str: - """Collect a value for a variable using questionary.""" +class PromptHelper: + def __init__( + self, + state: State, + ): + self._state = state + + async def collect_var_value(self, var_name: str) -> str: + """Collect a value for a variable using questionary.""" + + return await questionary.text(f"Please provide a value for '{var_name}':").ask_async() + + async def get_full_args(self, tools, tool_name, args): + """ + Get full arguments for the tool call, using collect_tool_input() to collect missing required fields. + + :param tools: List of available tools + :param tool_name: Name of the tool to call + :param args: Arguments provided by the user + :return: Full arguments including defaults and required fields + """ + selected_tool = next((t for t in tools.tools if t.name == tool_name), None) + if not selected_tool: + raise ValueError(f"Tool '{tool_name}' not found.") + + input_schema = selected_tool.inputSchema + if not input_schema: + return args + + # Use collect_tool_input to fill in missing required fields + full_args = await self.collect_tool_input(input_schema, args, include_optional=False) + + return full_args + + async def collect_tool_input(self, input_schema, existing_args=None, include_optional=True): + """ + Collect user input based on a schema using Questionary. + Skip fields that are already provided in existing_args. + + Args: + input_schema (dict): Schema with 'properties' and 'required' keys + existing_args (dict, optional): Already collected values + + Returns: + dict: Dictionary with field names as keys and collected values + + Example: + schema = { + 'properties': { + 'a': {'title': 'A', 'type': 'integer'}, + 'b': {'title': 'B', 'type': 'integer'} + }, + 'required': ['a', 'b'] + } + result = collect_input(schema) # Returns something like {"a": 5, "b": 3} + """ + properties = input_schema["properties"] + required = input_schema.get("required", []) + + if existing_args is None: + existing_args = {} + + values = existing_args.copy() + + for field_name, field_info in properties.items(): + # Skip if field is already provided in existing_args + if field_name in existing_args: + continue + + logging.info(f"Collecting input for field: {field_info}") + title = field_info["title"] if "title" in field_info else field_name + field_type = field_info["type"] if "type" in field_info else "string" + field_desc = field_info.get("description", None) + is_required = field_name in required + + if not include_optional and not is_required: + # Skip optional fields if not requested + continue + + # Create prompt text + prompt_text = title + if is_required: + prompt_text += " (required)" + else: + prompt_text += " (optional)" + prompt_text += ":" + + if field_type == "integer": + # Create validator for integer fields + def make_integer_validator(required): + def validate_integer(value): + # Allow empty for optional fields + if not value.strip(): + if required: + return "This field is required" + return True + # Validate integer format + try: + int(value) + return True + except ValueError: + return "Please enter a valid integer" + + return validate_integer + + result = await questionary.text(prompt_text, validate=make_integer_validator(is_required)).ask_async() + + # Process the result + if result and result.strip(): + values[field_name] = int(result) + elif is_required: + # This shouldn't happen due to validation + raise ValueError(f"Required field {field_name} is missing") + # Optional fields that are empty are not included in output + + elif field_type == "string": + # Create validator for string fields + def make_string_validator(required): + def validate_string(value): + if not value.strip() and required: + return "This field is required" + return True + + return validate_string + + result = await questionary.text( + prompt_text, + validate=make_string_validator(is_required), + instruction=field_desc, + ).ask_async() + + # Process the result + if result and result.strip(): + values[field_name] = result + elif is_required: + # This shouldn't happen due to validation + raise ValueError(f"Required field {field_name} is missing") + # Optional fields that are empty are not included in output + + elif field_type == "boolean": + # Create a yes/no prompt for boolean fields + result = await questionary.confirm(prompt_text, default=False, qmark="?").ask_async() + + # Store the boolean value + values[field_name] = result - return await questionary.text(f"Please provide a value for '{var_name}':").ask_async() + return values diff --git a/pyproject.toml b/pyproject.toml index 5b306cb..05e00e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "jinja2>=3.1.6", + "mcp[cli]>=1.12.4", "pydantic>=2.11.7", "questionary>=2.1.0", ] diff --git a/state.py b/state.py index bfc9ca5..8856227 100644 --- a/state.py +++ b/state.py @@ -1,7 +1,7 @@ import json from pathlib import Path -from util.mcp.config import MCPServersConfig +from mcp_client.config import MCPServersConfig class State: diff --git a/tests/README.md b/tests/README.md index 3555fb3..8cf9b15 100644 --- a/tests/README.md +++ b/tests/README.md @@ -7,6 +7,13 @@ uv run pytest ## Manual Testing +### Initialize the environment and adding mcp configuration +``` +uv run cxk.py init +uv run cxk.py mcp add-stdio server-name2 --env KEY=value -- python server.py +uv run cxk.py mcp add-stdio test-mcp -- uv run mcp run tests/mcp_test_server.py +``` + ``` uv run cxk.py create-spec tests/templates/spec1.md ``` @@ -27,4 +34,18 @@ uv run cxk.py create-spec tests/templates/spec1.md --verbose --var additional_co ### Piped ``` cat tests/templates/spec1.md | uv run cxk.py create-spec --verbose --var ticket='{"id":1}' --var additional_context=2 +``` + +### With MCP function (fully specified) +``` +uv run cxk.py create-spec tests/templates/spec2.md --var additional_context=aa +``` + +``` +uv run cxk.py create-spec tests/templates/spec2.md --var additional_context=aa --verbose --output res.md +``` + +### With MCP function (partially specified) +``` +uv run cxk.py create-spec tests/templates/spec3.md --var additional_context=aa ``` \ No newline at end of file diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index c575967..09efa9d 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -610,7 +610,7 @@ def test_create_spec_pipe_mode(self, temp_non_git_dir): "create-spec", "--verbose", "--var", - "ticket={\"id\":1}", + 'ticket={"id":1}', "--var", "additional_context=test context", ], @@ -656,3 +656,108 @@ def test_create_spec_pipe_mode_with_output_file(self, temp_non_git_dir): assert output_file.exists() content = output_file.read_text() assert "Piped template: Hello from pipe!" in content + + def test_create_spec_with_mcp_call(self, temp_git_repo): + """Test create-spec with template that includes MCP tool calls.""" + # Initialize project first + init_result = self.run_cli(["init"], cwd=temp_git_repo) + assert init_result.returncode == 0 + + # Add test MCP server + server_path = Path(__file__).parent.parent / "mcp_test_server.py" + add_server_result = self.run_cli( + ["mcp", "add-stdio", "test-mcp", "--", "uv", "run", "mcp", "run", str(server_path)], + cwd=temp_git_repo, + ) + assert add_server_result.returncode == 0 + + # Use the existing spec2.md template that has MCP calls + template_path = Path(__file__).parent.parent / "templates" / "spec2.md" + + # Run create-spec command with test runner and additional_context variable + result = self.run_cli( + [ + "create-spec", + "--verbose", + str(template_path), + "--var", + "additional_context=This is test context", + ], + cwd=temp_git_repo, + use_test_runner=True, + ) + + assert result.returncode == 0 + + # Verify that MCP tool was called and returned expected data + # The template uses: mcp('test-mcp', 'jsonTest', {'cloudId': '1234', 'ticketId': 'ACME-123'}) + # The jsonTest tool returns: {"id": "1234 - ACME-123", "summary": "Summary for ACME-123", + # "description": "This is a mock Jira ticket description."} + + # Check that the rendered template contains the expected MCP tool output + assert "1234 - ACME-123" in result.stdout # ticket.id from MCP call + assert "This is a mock Jira ticket description." in result.stdout # ticket.description from MCP call + assert "This is test context" in result.stdout # additional_context variable + + # Verify template structure is preserved + assert "# Task Template" in result.stdout + assert "## Ticket description" in result.stdout + assert "### Description" in result.stdout + assert "## Additional context" in result.stdout + + def test_create_spec_with_partial_mcp_call(self, temp_git_repo): + """Test create-spec with template that includes partial MCP tool calls requiring user input.""" + # Initialize project first + init_result = self.run_cli(["init"], cwd=temp_git_repo) + assert init_result.returncode == 0 + + # Add test MCP server + server_path = Path(__file__).parent.parent / "mcp_test_server.py" + add_server_result = self.run_cli( + ["mcp", "add-stdio", "test-mcp", "--", "uv", "run", "mcp", "run", str(server_path)], + cwd=temp_git_repo, + ) + assert add_server_result.returncode == 0 + + # Use the existing spec3.md template that has partial MCP calls + template_path = Path(__file__).parent.parent / "templates" / "spec3.md" + + # Run create-spec command with test runner and additional_context variable + # The template has: + # - mcp('test-mcp', 'jsonTest', {'cloudId': '1234'}) - missing 'ticketId' parameter + # - mcp('test-mcp', 'add', {'a': 5}) - missing 'b' parameter + result = self.run_cli( + [ + "create-spec", + "--verbose", + str(template_path), + "--var", + "additional_context=This is test context for partial MCP", + ], + cwd=temp_git_repo, + use_test_runner=True, + ) + + assert result.returncode == 0 + + # Verify that MCP tools were called with both provided and collected parameters + # For jsonTest: cloudId='1234' + ticketId from mock (should be 'mock_value_ticketId') + # Expected output: "1234 - mock_value_ticketId" + assert "1234 - mock_value_ticketId" in result.stdout + + # Verify the description from jsonTest call + assert "This is a mock Jira ticket description." in result.stdout + + # For add: a=5 + b from mock (should be 10), so result should be 15 + assert "15" in result.stdout + + # Verify additional_context variable + assert "This is test context for partial MCP" in result.stdout + + # Verify template structure is preserved + assert "# Task Template" in result.stdout + assert "## Ticket description" in result.stdout + assert "### Description" in result.stdout + assert "## Additional context" in result.stdout + assert "## Some math..." in result.stdout + diff --git a/tests/mcp_test_server.py b/tests/mcp_test_server.py new file mode 100644 index 0000000..987be3e --- /dev/null +++ b/tests/mcp_test_server.py @@ -0,0 +1,34 @@ +""" +FastMCP reference MCP server for testing and demonstration purposes. +""" + +from typing import Any + +from mcp.server.fastmcp import FastMCP + +# Create an MCP server +mcp = FastMCP("Test-Server", "1.0.0") + + +# Add an addition tool +@mcp.tool() +def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + +@mcp.tool() +def jsonTest(cloudId: str, ticketId: str, optional_other: str | None = None) -> dict[str, Any]: + """Mock JSON test tool""" + return { + "id": f"{cloudId} - {ticketId}", + "summary": f"Summary for {ticketId}", + "description": "This is a mock Jira ticket description.", + } + + +# Add a dynamic greeting resource +@mcp.resource("greeting://{name}") +def get_greeting(name: str) -> str: + """Get a personalized greeting""" + return f"Hello, {name}!" diff --git a/tests/templates/spec2.md b/tests/templates/spec2.md new file mode 100644 index 0000000..8b15073 --- /dev/null +++ b/tests/templates/spec2.md @@ -0,0 +1,14 @@ +# Task Template + +## Ticket description + +{% set ticket = mcp('test-mcp', 'jsonTest', {'cloudId': '1234', 'ticketId': 'ACME-123'}) %} + +{{ ticket.id }} + +### Description +{{ ticket.description }} + +## Additional context + +{{ additional_context }} \ No newline at end of file diff --git a/tests/templates/spec3.md b/tests/templates/spec3.md new file mode 100644 index 0000000..93778c9 --- /dev/null +++ b/tests/templates/spec3.md @@ -0,0 +1,18 @@ +# Task Template + +## Ticket description + +{% set ticket = mcp('test-mcp', 'jsonTest', {'cloudId': '1234'}) %} + +{{ ticket.id }} + +### Description +{{ ticket.description }} + +## Additional context + +{{ additional_context }} + +## Some math... + +{{ mcp('test-mcp', 'add', {'a': 5}) }} diff --git a/tests/test_runner.py b/tests/test_runner.py index 050b7c8..63c8218 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Test runner script that patches collect_var_value for e2e testing.""" + import asyncio import sys from pathlib import Path @@ -19,12 +20,54 @@ async def mock_collect_var_value(var_name: str) -> str: "city": "New York", "weather": '{"condition": "sunny", "temp": "75F"}', "username": "testuser", - "user": "test_user" + "user": "test_user", } return mock_values.get(var_name, f"mock_value_{var_name}") +async def mock_collect_tool_input(input_schema, existing_args=None, include_optional=True): + """Mock implementation for collecting MCP tool input parameters.""" + if existing_args is None: + existing_args = {} + + properties = input_schema["properties"] + required = input_schema.get("required", []) + values = existing_args.copy() + + # Mock values for common MCP tool parameters + mock_tool_values = { + "ticketId": "mock_value_ticketId", + "b": 10, # For integer type parameters like 'b' in add function + "cloudId": "mock_cloudId", + } + + for field_name, field_info in properties.items(): + # Skip if field is already provided in existing_args + if field_name in existing_args: + continue + + is_required = field_name in required + field_type = field_info.get("type", "string") + + if not include_optional and not is_required: + continue + + # Provide mock value based on field type and name + if field_name in mock_tool_values: + values[field_name] = mock_tool_values[field_name] + elif field_type == "integer": + values[field_name] = 42 + elif field_type == "boolean": + values[field_name] = True + else: # string or other types + values[field_name] = f"mock_value_{field_name}" + + return values + + if __name__ == "__main__": - # Patch collect_var_value before running main - with patch('commands.create_spec.collect_var_value', side_effect=mock_collect_var_value): - asyncio.run(main()) \ No newline at end of file + with ( + patch("prompt.PromptHelper.collect_var_value", side_effect=mock_collect_var_value), + patch("prompt.PromptHelper.collect_tool_input", side_effect=mock_collect_tool_input), + ): + asyncio.run(main()) diff --git a/util/parse.py b/util/parse.py new file mode 100644 index 0000000..831925e --- /dev/null +++ b/util/parse.py @@ -0,0 +1,20 @@ +import json +from typing import Any + + +def parse_input_string(value: str) -> str | dict[str, Any]: + """Parse a string input that may be JSON or a simple string. + + Args: + value (str): The input string to parse. + + Returns: + str | dict[str, Any]: Parsed value, either as a string or a dictionary. + """ + value = value.strip() + if value.startswith("{") or value.startswith("["): + try: + return json.loads(value) + except json.JSONDecodeError: + return value + return value