diff --git a/.gitignore b/.gitignore index 42d959a..d9cfd26 100644 --- a/.gitignore +++ b/.gitignore @@ -126,16 +126,6 @@ ENV/ env.bak/ venv.bak/ -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - # mypy .mypy_cache/ .dmypy.json @@ -189,4 +179,5 @@ marimo/_lsp/ __marimo__/ # ContextKit -.cxk \ No newline at end of file +.cxk +task.md \ No newline at end of file diff --git a/auth_server/__init__.py b/auth_server/__init__.py new file mode 100644 index 0000000..d6d15ec --- /dev/null +++ b/auth_server/__init__.py @@ -0,0 +1,5 @@ +"""OAuth2 auth server for handling MCP authorization callbacks.""" + +from .auth_server import AuthServer + +__all__ = ["AuthServer"] diff --git a/auth_server/auth_server.py b/auth_server/auth_server.py new file mode 100644 index 0000000..501eaeb --- /dev/null +++ b/auth_server/auth_server.py @@ -0,0 +1,185 @@ +"""OAuth2 authentication server for handling MCP client callbacks.""" + +import asyncio +import logging +import threading + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import HTMLResponse + + +class AuthServer: + """FastAPI-based OAuth2 callback server for MCP authentication.""" + + def __init__(self, host: str = "localhost", port: int = 41008): + """Initialize the auth server.""" + self.host = host + self.port = port + self.app = FastAPI() + self.server: uvicorn.Server | None = None + self.server_thread: threading.Thread | None = None + + # Storage for the callback response + self._callback_future: asyncio.Future | None = None + self._callback_code: str | None = None + self._callback_state: str | None = None + self._main_loop: asyncio.AbstractEventLoop | None = None + + # Setup routes + self._setup_routes() + + def _setup_routes(self): + """Setup FastAPI routes.""" + + @self.app.get("/callback") + async def callback_endpoint(request: Request): + """Handle OAuth2 callback.""" + query_params = dict(request.query_params) + + # Extract code and state from query parameters + code = query_params.get("code") + state = query_params.get("state") + error = query_params.get("error") + + if error: + error_description = query_params.get("error_description", "Unknown error") + if self._callback_future and not self._callback_future.done(): + # Schedule the exception in the main event loop + if self._main_loop: + self._main_loop.call_soon_threadsafe( + self._callback_future.set_exception, + Exception(f"OAuth error: {error} - {error_description}"), + ) + return HTMLResponse( + content=f"

Authentication Error

{error}: {error_description}

", # noqa: E501 + status_code=400, + ) + + if not code: + if self._callback_future and not self._callback_future.done(): + # Schedule the exception in the main event loop + if self._main_loop: + self._main_loop.call_soon_threadsafe( + self._callback_future.set_exception, Exception("No authorization code received") + ) + return HTMLResponse( + content="

Error

No authorization code received

", + status_code=400, + ) + + # Store the callback data + self._callback_code = code + self._callback_state = state + + # Signal that callback was received + if self._callback_future and not self._callback_future.done(): + # Schedule the future resolution in the main event loop + if self._main_loop: + self._main_loop.call_soon_threadsafe(self._callback_future.set_result, (code, state)) + logging.info("OAuth2 callback handled successfully") + else: + logging.error("Main event loop is not set; cannot set callback future result") + else: + logging.error("Callback future is not set or already done; cannot set result") + + return HTMLResponse( + content=""" + + +

Authorization Successful

+

You have successfully authorized the MCP client. You can now close this window.

+ + + """ + ) + + async def start(self): + """Start the auth server in a background thread.""" + logging.info(f"Starting auth server at http://{self.host}:{self.port}") + if self.server_thread and self.server_thread.is_alive(): + logging.info("Auth server is already running") + return # Already running + + config = uvicorn.Config( + app=self.app, + host=self.host, + port=self.port, + log_level="warning", # Reduce log noise + access_log=False, + ) + self.server = uvicorn.Server(config) + + # Start server in a separate thread to avoid blocking + self.server_thread = threading.Thread(target=self._run_server, daemon=True) + self.server_thread.start() + + # Wait a bit for the server to start + await asyncio.sleep(0.1) + + def _run_server(self): + """Run the server in the thread.""" + if self.server is not None: + asyncio.run(self.server.serve()) + + async def handle_callback(self) -> tuple[str, str | None]: + """ + Wait for and handle the OAuth2 callback. + + Returns: + tuple: (code, state) from the OAuth2 callback + + Raises: + Exception: If there's an error in the OAuth2 flow + """ + # Capture the current event loop + self._main_loop = asyncio.get_running_loop() + + # Create a future to wait for the callback + self._callback_future = asyncio.Future() + + try: + # Wait for the callback to be received + logging.info("Waiting for OAuth2 callback...") + code, state = await self._callback_future + logging.info(f"Received callback with code: {code}, state: {state}") + return code, state + finally: + # Clean up + logging.info("Cleaning up auth server state after callback") + self._callback_future = None + self._main_loop = None + + async def stop(self): + """Stop the auth server and clean up state.""" + if self.server: + self.server.should_exit = True + + if self.server_thread and self.server_thread.is_alive(): + # Wait a bit for graceful shutdown + self.server_thread.join(timeout=1.0) + + # Clear state + self._callback_code = None + self._callback_state = None + self._callback_future = None + self._main_loop = None + self.server = None + self.server_thread = None + + logging.info("Auth server stopped") + + @property + def callback_url(self) -> str: + """Get the callback URL for this auth server.""" + return f"http://{self.host}:{self.port}/callback" + + async def __aenter__(self): + """Async context manager entry.""" + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.stop() + return None diff --git a/commands/create_spec.py b/commands/create_spec.py index f2c8c07..13fe738 100644 --- a/commands/create_spec.py +++ b/commands/create_spec.py @@ -3,6 +3,7 @@ import sys from engine import TemplateEngine, TemplateParseError +from mcp_client import get_session_manager from prompt import PromptHelper from state import State from util.parse import parse_input_string @@ -14,84 +15,92 @@ async def handle_create_spec( output_file: str | None = None, var_overrides: list[str] | None = None, ): - prompt_helper = PromptHelper(state) - - # Detect piped input (stdin not a TTY) and ensure there's data before using it - stdin_piped = not sys.stdin.isatty() - - # Determine template source - template_engine: TemplateEngine - if spec_template: - # Resolve relative paths against current working directory - template_path = os.path.abspath(spec_template) - - # Check if template file exists - if not os.path.exists(template_path): - logging.error(f"Error: Template file '{spec_template}' not found") + # Pre-initialize all MCP client sessions + session_manager = get_session_manager() + try: + await session_manager.initialize_all_sessions(state) + prompt_helper = PromptHelper(state) + + # Detect piped input (stdin not a TTY) and ensure there's data before using it + stdin_piped = not sys.stdin.isatty() + + # Determine template source + template_engine: TemplateEngine + if spec_template: + # Resolve relative paths against current working directory + template_path = os.path.abspath(spec_template) + + # Check if template file exists + if not os.path.exists(template_path): + logging.error(f"Error: Template file '{spec_template}' not found") + sys.exit(1) + + template_engine = TemplateEngine.from_file(template_path, state, prompt_helper) + elif stdin_piped: + try: + template_str = sys.stdin.read() + except Exception as e: + logging.error(f"Error: Failed to read from stdin: {e}") + sys.exit(1) + if not template_str: + logging.error("Error: No data received on stdin for template") + sys.exit(1) + + template_engine = TemplateEngine.from_string(template_str, state, prompt_helper) + else: + logging.error("Error: Missing spec_template argument (or provide template via stdin)") sys.exit(1) - template_engine = TemplateEngine.from_file(template_path, state, prompt_helper) - elif stdin_piped: try: - template_str = sys.stdin.read() - except Exception as e: - logging.error(f"Error: Failed to read from stdin: {e}") + variables = template_engine.get_variables() + + # Parse var_overrides into a dictionary + provided_vars = {} + if var_overrides: + for var_override in var_overrides: + if "=" not in var_override: + logging.error(f"Error: Invalid variable format '{var_override}'. Use KEY=VALUE format.") + sys.exit(1) + key, value = var_override.split("=", 1) + provided_vars[key] = value + + # Collect values for each variable + collected_vars = {} + if variables: + logging.info("Collecting values for template variables:") + for var in sorted(variables): + if var in provided_vars: + raw_value = provided_vars[var] + + else: + raw_value = await prompt_helper.collect_var_value(var) + logging.info(f" {var}: {raw_value}") + + collected_vars[var] = parse_input_string(raw_value) + + else: + logging.info("No variables found in template") + + # Render the template with collected variables + rendered_content = await template_engine.render_async(**collected_vars) + + # Output to file or stdout + if output_file: + output_path = os.path.abspath(output_file) + with open(output_path, "w") as f: + f.write(rendered_content) + logging.info(f"Rendered template saved to: {output_path}") + else: + logging.debug("\nRendered template:") + print(rendered_content) + + except TemplateParseError as e: + logging.error(f"Error: {e}") sys.exit(1) - if not template_str: - logging.error("Error: No data received on stdin for template") + except Exception as e: + logging.error(f"Error: Failed to process template: {e}") sys.exit(1) - template_engine = TemplateEngine.from_string(template_str, state, prompt_helper) - else: - logging.error("Error: Missing spec_template argument (or provide template via stdin)") - sys.exit(1) - - try: - variables = template_engine.get_variables() - - # Parse var_overrides into a dictionary - provided_vars = {} - if var_overrides: - for var_override in var_overrides: - if "=" not in var_override: - logging.error(f"Error: Invalid variable format '{var_override}'. Use KEY=VALUE format.") - sys.exit(1) - key, value = var_override.split("=", 1) - provided_vars[key] = value - - # Collect values for each variable - collected_vars = {} - if variables: - logging.info("Collecting values for template variables:") - for var in sorted(variables): - if var in provided_vars: - raw_value = provided_vars[var] - - else: - raw_value = await prompt_helper.collect_var_value(var) - logging.info(f" {var}: {raw_value}") - - collected_vars[var] = parse_input_string(raw_value) - - else: - logging.info("No variables found in template") - - # Render the template with collected variables - rendered_content = await template_engine.render_async(**collected_vars) - - # Output to file or stdout - if output_file: - output_path = os.path.abspath(output_file) - with open(output_path, "w") as f: - f.write(rendered_content) - logging.info(f"Rendered template saved to: {output_path}") - else: - logging.debug("\nRendered template:") - print(rendered_content) - - except TemplateParseError as e: - logging.error(f"Error: {e}") - sys.exit(1) - except Exception as e: - logging.error(f"Error: Failed to process template: {e}") - sys.exit(1) + finally: + # Clean up all MCP sessions + await session_manager.cleanup() diff --git a/commands/mcp.py b/commands/mcp.py index e1e7314..deafcd4 100644 --- a/commands/mcp.py +++ b/commands/mcp.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from mcp_client.config import SSEServerConfig, StdioServerConfig +from mcp_client import SSEServerConfig, StdioServerConfig from state import State @@ -42,9 +42,7 @@ async def handle_mcp(state: State, context: MCPCommandContext): key, value = env_var.split("=", 1) env_dict[key] = value else: - raise ValueError( - f"Invalid environment variable format: {env_var}. Use KEY=VALUE format." - ) + raise ValueError(f"Invalid environment variable format: {env_var}. Use KEY=VALUE format.") await handle_add_stdio( state, @@ -92,6 +90,4 @@ async def handle_add_stdio( async def handle_add_http(state: State, server_name: str, url: str): - print( - f"HTTP server support not implemented yet. Would add '{server_name}' with URL: {url}" - ) + print(f"HTTP server support not implemented yet. Would add '{server_name}' with URL: {url}") diff --git a/engine/__init__.py b/engine/__init__.py index 6273f49..98cece2 100644 --- a/engine/__init__.py +++ b/engine/__init__.py @@ -1,168 +1,5 @@ -from pathlib import Path -from typing import Any +"""Jinja2-based template processing with async support and MCP tool integration.""" -from jinja2 import Environment, FileSystemLoader, Template, meta, select_autoescape +from .template_engine import TemplateEngine, TemplateParseError -from engine.globals import create_mcp_resource_function, create_mcp_tool_function -from prompt import PromptHelper -from state import State - - -class TemplateParseError(Exception): - """Raised when template parsing fails""" - - pass - - -class TemplateEngine: - """Abstract away the jinja2 template engine with clean factory methods""" - - def __init__( - self, - env: Environment, - template: Template, - state: State, - prompt_helper: PromptHelper, - source_path: Path | None = None, - source_string: str | None = None, - ): - """Private constructor - use from_file() or from_string() instead""" - self.env = env - self.template = template - self._source_path = source_path - self._source_string = source_string - self._state = state - self._prompt_helper = prompt_helper - - # Add global functions to env to support MCP tools and resources - self.env.globals["call_tool"] = create_mcp_tool_function(self._state, self._prompt_helper) - self.env.globals["get_resource"] = create_mcp_resource_function(self._state) - - @classmethod - def from_file(cls, path: str | Path, state: State, prompt_helper: PromptHelper) -> "TemplateEngine": - """Create a TemplateEngine from a template file. - - Args: - path: Path to the template file - state: State object containing project configuration - - Returns: - TemplateEngine instance - - Raises: - FileNotFoundError: If template file doesn't exist - TemplateParseError: If template parsing fails - """ - path = Path(path) - - if not path.exists(): - raise FileNotFoundError(f"Template file not found: {path}") - - template_dir = path.parent - template_name = path.name - - env = Environment( - loader=FileSystemLoader(str(template_dir)), - autoescape=select_autoescape(), - enable_async=True, - ) - - try: - template = env.get_template(template_name) - except Exception as e: - raise TemplateParseError(f"Failed to load template from {path}: {e}") from e - - return cls( - env=env, template=template, state=state, source_path=path, source_string=None, prompt_helper=prompt_helper - ) - - @classmethod - def from_string( - cls, template_string: str, state: State, prompt_helper: PromptHelper, name: str = "" - ) -> "TemplateEngine": - """Create a TemplateEngine from a template string. - - Args: - template_string: The template content as a string - state: State object containing project configuration - name: Optional name for the template (for debugging) - - Returns: - TemplateEngine instance - - Raises: - TemplateParseError: If template parsing fails - """ - - env = Environment( - autoescape=select_autoescape(), - enable_async=True, - ) - - try: - template = env.from_string(template_string) - except Exception as e: - raise TemplateParseError(f"Failed to parse template string: {e}") from e - - # Store the name in the template for better error messages - template.name = name - - return cls( - env=env, - template=template, - state=state, - source_path=None, - source_string=template_string, - prompt_helper=prompt_helper, - ) - - @property - def source(self) -> str: - """Get the template source content.""" - if self._source_string is not None: - return self._source_string - - if self._source_path is not None: - with open(self._source_path, encoding="utf-8") as f: - return f.read() - - # Should not reach here with proper factory method usage - raise AssertionError("No template source available") - - @property - def path(self) -> Path | None: - """Get the template file path if loaded from file.""" - return self._source_path - - @property - def is_from_file(self) -> bool: - """Check if template was loaded from a file.""" - return self._source_path is not None - - def get_variables(self) -> set[str]: - """Get the free (undeclared) variables in the template. - - Returns: - Set of variable names that are referenced but not defined in the template - - Raises: - TemplateParseError: If template parsing fails - """ - try: - ast = self.env.parse(self.source) - except Exception as e: - raise TemplateParseError(f"Failed to parse template: {e}") from e - - # Find all undeclared variables - variables = meta.find_undeclared_variables(ast) - return variables - - async def render_async(self, *args: Any, **kwargs: Any) -> str: - return await self.template.render_async(*args, **kwargs) - - def __repr__(self) -> str: - if self._source_path: - return f"TemplateEngine(from_file={self._source_path})" - else: - name = self.template.name if hasattr(self.template, "name") else "" - return f"TemplateEngine(from_string, name={name})" +__all__ = ["TemplateEngine", "TemplateParseError"] diff --git a/engine/globals.py b/engine/globals.py index 1ab5243..cf4ce30 100644 --- a/engine/globals.py +++ b/engine/globals.py @@ -5,23 +5,20 @@ from mcp.shared.metadata_utils import get_display_name from pydantic import AnyUrl -from mcp_client.client_session_provider import get_client_session_by_server +from mcp_client import get_client_session_by_server from prompt import PromptHelper -from state import State from util.parse import parse_input_string -def create_mcp_tool_function(state: State, prompt_helper: PromptHelper): +def create_mcp_tool_function(prompt_helper: PromptHelper): """Create a call_mcp_tool function with state bound.""" async def call_mcp_tool(server: str, tool_name: str, args: dict) -> str | dict[str, Any]: logging.info(f"Calling MCP tool: {tool_name} on server: {server} with args: {args}") - async with get_client_session_by_server(server, state) as session: - await session.initialize() - + async with get_client_session_by_server(server) as session: tools = await session.list_tools() - # Call the tool with collected input try: + # Collect missing required arguments interactively full_arguments = await prompt_helper.get_full_args(tools, tool_name, args) logging.debug(f"Full arguments for tool {tool_name}: {full_arguments}") @@ -41,14 +38,12 @@ async def call_mcp_tool(server: str, tool_name: str, args: dict) -> str | dict[s return call_mcp_tool -def create_mcp_resource_function(state: State): +def create_mcp_resource_function(): """Create a get_mcp_resource function with state bound.""" async def get_mcp_resource(server: str, resource_uri: str) -> str | dict[str, Any]: logging.info(f"Getting MCP resource: {resource_uri} in server: {server}") - async with get_client_session_by_server(server, state) as session: - await session.initialize() - + async with get_client_session_by_server(server) as session: # Fetch the resource (first content item for now) try: resource = await session.read_resource(AnyUrl(resource_uri)) diff --git a/engine/template_engine.py b/engine/template_engine.py new file mode 100644 index 0000000..f5d2e4e --- /dev/null +++ b/engine/template_engine.py @@ -0,0 +1,164 @@ +from pathlib import Path +from typing import Any + +from jinja2 import Environment, FileSystemLoader, Template, meta, select_autoescape + +from engine.globals import create_mcp_resource_function, create_mcp_tool_function +from prompt import PromptHelper +from state import State + + +class TemplateParseError(Exception): + pass + + +class TemplateEngine: + def __init__( + self, + env: Environment, + template: Template, + state: State, + prompt_helper: PromptHelper, + source_path: Path | None = None, + source_string: str | None = None, + ): + """Private constructor - use from_file() or from_string() instead""" + self.env = env + self.template = template + self._source_path = source_path + self._source_string = source_string + self._state = state + self._prompt_helper = prompt_helper + + # Add global functions to env to support MCP tools and resources + self.env.globals["call_tool"] = create_mcp_tool_function(self._prompt_helper) + self.env.globals["get_resource"] = create_mcp_resource_function() + + @classmethod + def from_file(cls, path: str | Path, state: State, prompt_helper: PromptHelper) -> "TemplateEngine": + """Create a TemplateEngine from a template file. + + Args: + path: Path to the template file + state: State object containing project configuration + + Returns: + TemplateEngine instance + + Raises: + FileNotFoundError: If template file doesn't exist + TemplateParseError: If template parsing fails + """ + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"Template file not found: {path}") + + template_dir = path.parent + template_name = path.name + + env = Environment( + loader=FileSystemLoader(str(template_dir)), + autoescape=select_autoescape(), + enable_async=True, + ) + + try: + template = env.get_template(template_name) + except Exception as e: + raise TemplateParseError(f"Failed to load template from {path}: {e}") from e + + return cls( + env=env, template=template, state=state, source_path=path, source_string=None, prompt_helper=prompt_helper + ) + + @classmethod + def from_string( + cls, template_string: str, state: State, prompt_helper: PromptHelper, name: str = "" + ) -> "TemplateEngine": + """Create a TemplateEngine from a template string. + + Args: + template_string: The template content as a string + state: State object containing project configuration + name: Optional name for the template (for debugging) + + Returns: + TemplateEngine instance + + Raises: + TemplateParseError: If template parsing fails + """ + + env = Environment( + autoescape=select_autoescape(), + enable_async=True, + ) + + try: + template = env.from_string(template_string) + except Exception as e: + raise TemplateParseError(f"Failed to parse template string: {e}") from e + + # Store the name in the template for better error messages + template.name = name + + return cls( + env=env, + template=template, + state=state, + source_path=None, + source_string=template_string, + prompt_helper=prompt_helper, + ) + + @property + def source(self) -> str: + """Get the template source content.""" + if self._source_string is not None: + return self._source_string + + if self._source_path is not None: + with open(self._source_path, encoding="utf-8") as f: + return f.read() + + # Should not reach here with proper factory method usage + raise AssertionError("No template source available") + + @property + def path(self) -> Path | None: + """Get the template file path if loaded from file.""" + return self._source_path + + @property + def is_from_file(self) -> bool: + """Check if template was loaded from a file.""" + return self._source_path is not None + + def get_variables(self) -> set[str]: + """Get the free (undeclared) variables in the template. + + Returns: + Set of variable names that are referenced but not defined in the template + + Raises: + TemplateParseError: If template parsing fails + """ + try: + ast = self.env.parse(self.source) + except Exception as e: + raise TemplateParseError(f"Failed to parse template: {e}") from e + + # Find all undeclared variables + variables = meta.find_undeclared_variables(ast) + return variables + + async def render_async(self, *args: Any, **kwargs: Any) -> str: + return await self.template.render_async(*args, **kwargs) + + def __repr__(self) -> str: + if self._source_path: + return f"TemplateEngine(from_file={self._source_path})" + else: + name = self.template.name if hasattr(self.template, "name") else "" + return f"TemplateEngine(from_string, name={name})" diff --git a/mcp_client/__init__.py b/mcp_client/__init__.py index e69de29..f7e7c86 100644 --- a/mcp_client/__init__.py +++ b/mcp_client/__init__.py @@ -0,0 +1,16 @@ +"""MCP (Model Context Protocol) client implementation.""" + +from .client_session_provider import get_client_session_by_server +from .config import MCPServersConfig, SSEServerConfig, StdioServerConfig +from .session_manager import MCPSessionManager, get_session_manager +from .token_storage import KeychainTokenStorageWithFallback + +__all__ = [ + "MCPServersConfig", + "SSEServerConfig", + "StdioServerConfig", + "MCPSessionManager", + "KeychainTokenStorageWithFallback", + "get_session_manager", + "get_client_session_by_server", +] diff --git a/mcp_client/client_session_provider.py b/mcp_client/client_session_provider.py index 485723d..2c629f0 100644 --- a/mcp_client/client_session_provider.py +++ b/mcp_client/client_session_provider.py @@ -1,3 +1,4 @@ +import webbrowser from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path @@ -5,46 +6,28 @@ from urllib.parse import parse_qs, urlparse from mcp import ClientSession, StdioServerParameters -from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.auth import OAuthClientProvider from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client -from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from mcp.shared.auth import OAuthClientMetadata from pydantic import AnyUrl -from .config import SSEServerConfig, StdioServerConfig +from auth_server import AuthServer +from util.terminal import display_hyperlink + from .mcp_logger import get_mcp_log_file +from .session_manager import get_session_manager if TYPE_CHECKING: from state import State -class InMemoryTokenStorage(TokenStorage): - """Demo In-memory token storage implementation.""" - - def __init__(self): - self.tokens: OAuthToken | None = None - self.client_info: OAuthClientInformationFull | None = None - - async def get_tokens(self) -> OAuthToken | None: - """Get stored tokens.""" - return self.tokens - - async def set_tokens(self, tokens: OAuthToken) -> None: - """Store tokens.""" - self.tokens = tokens - - async def get_client_info(self) -> OAuthClientInformationFull | None: - """Get stored client information.""" - return self.client_info - - async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: - """Store client information.""" - self.client_info = client_info - - async def handle_redirect(auth_url: str) -> None: - print(f"Visit: {auth_url}") + print(f"ContextKit requires authorization, opening {display_hyperlink(auth_url)}") + opened = webbrowser.open(auth_url) + if not opened: + print("Failed to open browser automatically. Please open the URL manually in your browser.") async def handle_callback() -> tuple[str, str | None]: @@ -62,17 +45,18 @@ async def get_stdio_session(server_params: StdioServerParameters, config_dir: Pa @asynccontextmanager -async def get_streamablehttp_session(server_url: str): +async def get_streamablehttp_session(server_url: str, server_name: str, state: "State"): + token_storage = state.get_token_storage(server_name) oauth_auth = OAuthClientProvider( server_url=server_url, client_metadata=OAuthClientMetadata( - client_name="Example MCP Client", + client_name="ContextKit MCP Client", redirect_uris=[AnyUrl("http://localhost:3000/callback")], grant_types=["authorization_code", "refresh_token"], response_types=["code"], scope="user", ), - storage=InMemoryTokenStorage(), + storage=token_storage, redirect_handler=handle_redirect, callback_handler=handle_callback, ) @@ -88,19 +72,22 @@ async def get_streamablehttp_session(server_url: str): @asynccontextmanager -async def get_sse_session(server_url: str): +async def get_sse_session(server_url: str, server_name: str, state: "State", auth_server: AuthServer | None = None): + if auth_server is None: + raise ValueError("AuthServer must be provided for SSE sessions") + + token_storage = state.get_token_storage(server_name) oauth_auth = OAuthClientProvider( server_url=server_url, client_metadata=OAuthClientMetadata( - client_name="Example MCP Client", - redirect_uris=[AnyUrl("http://localhost:41008/callback")], + client_name="ContextKit MCP Client", + redirect_uris=[AnyUrl(auth_server.callback_url)], grant_types=["authorization_code", "refresh_token"], response_types=["code"], - scope="user", ), - storage=InMemoryTokenStorage(), + storage=token_storage, redirect_handler=handle_redirect, - callback_handler=handle_callback, + callback_handler=auth_server.handle_callback, ) # Connect to a Server-Sent Events (SSE) server @@ -115,23 +102,9 @@ async def get_sse_session(server_url: str): @asynccontextmanager -async def get_client_session_by_server(server_name: str, state: "State") -> AsyncGenerator[ClientSession, None]: - # Find the server configuration by name - server_config = state.mcp_config.mcpServers.get(server_name) - if not server_config: - raise ValueError(f"Server '{server_name}' not found in configuration.") - if isinstance(server_config, StdioServerConfig): - async with get_stdio_session( - StdioServerParameters( - command=server_config.command, - args=server_config.args or [], - env=server_config.env, - ), - config_dir=state.config_dir, - ) as session: - yield session - elif isinstance(server_config, SSEServerConfig): - async with get_sse_session(server_config.url) as session: - yield session - else: - raise ValueError(f"Unsupported server type for '{server_name}': {type(server_config)}") +async def get_client_session_by_server(server_name: str) -> AsyncGenerator[ClientSession, None]: + session_manager = get_session_manager() + + if session_manager.is_initialized: + session = session_manager.get_session(server_name) + yield session diff --git a/mcp_client/config.py b/mcp_client/config.py index 02db027..cabf088 100644 --- a/mcp_client/config.py +++ b/mcp_client/config.py @@ -20,9 +20,7 @@ class StdioServerConfig(BaseServerConfig): type: str | None = Field(default="stdio", description="Server transport type") command: str = Field(..., description="Command to execute the server") args: list[str] | None = Field(default=None, description="Command line arguments") - env: dict[str, str] | None = Field( - default=None, description="Environment variables" - ) + env: dict[str, str] | None = Field(default=None, description="Environment variables") @field_validator("type") @classmethod @@ -59,9 +57,7 @@ class MCPServersConfig(BaseModel): def validate_server_names(cls, v): for server_name in v.keys(): if not server_name or len(server_name) > 250: - raise ValueError( - "Server name must not be empty and must not be longer than 250 characters" - ) + raise ValueError("Server name must not be empty and must not be longer than 250 characters") return v @model_validator(mode="before") diff --git a/mcp_client/session_manager.py b/mcp_client/session_manager.py new file mode 100644 index 0000000..76280cf --- /dev/null +++ b/mcp_client/session_manager.py @@ -0,0 +1,121 @@ +"""Session manager for pre-initialized MCP client sessions.""" + +import logging +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING + +from mcp import ClientSession + +from .config import SSEServerConfig, StdioServerConfig + +if TYPE_CHECKING: + from state import State + + +class MCPSessionManager: + """Manages pre-initialized MCP client sessions.""" + + def __init__(self): + self._sessions: dict[str, ClientSession] = {} + self._exit_stack: AsyncExitStack | None = None + self._initialized = False + + async def initialize_all_sessions(self, state: "State") -> None: + """Pre-initialize all MCP server sessions from config.""" + if self._initialized: + return + + if not state.mcp_config or not state.mcp_config.mcpServers: + logging.info("No MCP servers configured, skipping session initialization") + return + + self._exit_stack = AsyncExitStack() + + logging.info(f"Initializing {len(state.mcp_config.mcpServers)} MCP server sessions...") + + # Use shared auth server for all SSE connections during initialization + from auth_server import AuthServer + async with AuthServer() as auth_server: + for server_name, server_config in state.mcp_config.mcpServers.items(): + try: + logging.debug(f"Initializing session for server: {server_name}") + + if isinstance(server_config, StdioServerConfig): + from mcp import StdioServerParameters + + from .client_session_provider import get_stdio_session + + session_cm = get_stdio_session( + StdioServerParameters( + command=server_config.command, + args=server_config.args or [], + env=server_config.env, + ), + config_dir=state.config_dir, + ) + elif isinstance(server_config, SSEServerConfig): + from .client_session_provider import get_sse_session + + session_cm = get_sse_session(server_config.url, server_name, state, auth_server) + else: + raise ValueError(f"Unsupported server type for '{server_name}': {type(server_config)}") + + # Enter the session context manager and store the session + session = await self._exit_stack.enter_async_context(session_cm) + await session.initialize() + self._sessions[server_name] = session + + logging.debug(f"Successfully initialized session for server: {server_name}") + + except Exception as e: + logging.error(f"Failed to initialize MCP server '{server_name}': {e}") + await self.cleanup() + raise RuntimeError(f"Failed to initialize MCP server '{server_name}': {e}") from e + + self._initialized = True + logging.info(f"Successfully initialized {len(self._sessions)} MCP server sessions") + + def get_session(self, server_name: str) -> ClientSession: + """Get a pre-initialized session by server name.""" + if not self._initialized: + raise RuntimeError("Session manager not initialized. Call initialize_all_sessions() first.") + + session = self._sessions.get(server_name) + if not session: + available_servers = list(self._sessions.keys()) + raise ValueError(f"Server '{server_name}' not found. Available servers: {available_servers}") + + return session + + async def cleanup(self) -> None: + """Clean up all sessions and resources.""" + if self._exit_stack: + try: + await self._exit_stack.aclose() + except Exception as e: + logging.error(f"Error during session cleanup: {e}") + finally: + self._exit_stack = None + + self._sessions.clear() + self._initialized = False + logging.debug("MCP session manager cleaned up") + + @property + def is_initialized(self) -> bool: + """Check if the session manager is initialized.""" + return self._initialized + + @property + def server_names(self) -> list[str]: + """Get list of initialized server names.""" + return list(self._sessions.keys()) + + +# Global session manager instance +_session_manager = MCPSessionManager() + + +def get_session_manager() -> MCPSessionManager: + """Get the global session manager instance.""" + return _session_manager diff --git a/mcp_client/token_storage.py b/mcp_client/token_storage.py new file mode 100644 index 0000000..5844d51 --- /dev/null +++ b/mcp_client/token_storage.py @@ -0,0 +1,193 @@ +import json +import logging +import os +import stat +from pathlib import Path +from typing import Any + +import keyring +from mcp.client.auth import TokenStorage +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +logger = logging.getLogger(__name__) + +KEYCHAIN_STORAGE_VERSION = 2 + + +class KeychainTokenStorageWithFallback(TokenStorage): + """ + Dual storage strategy for OAuth tokens: + 1. Primary: System keychain (macOS Keychain, Windows Credential Manager, Linux Secret Service) + 2. Fallback: Secured file storage if keychain unavailable + """ + + def __init__(self, state, server_name: str): + """Initialize with state object containing configuration.""" + self.state = state + self.server_name = server_name + self.service_name = f"contextkit_{KEYCHAIN_STORAGE_VERSION}" + self.username = f"secrets_{server_name}" + self.keychain_enabled = True + + # Determine fallback storage location + self.fallback_dir = self._get_fallback_dir() + self.fallback_file = self.fallback_dir / f"secrets_{server_name}.json" + + # Test keychain availability on init + self._test_keychain_availability() + + def _get_fallback_dir(self) -> Path: + """Get platform-specific application data directory.""" + + if os.name == "nt": # Windows + base = Path(os.environ.get("APPDATA", Path.home() / "AppData" / "Roaming")) + elif os.name == "posix": + if os.environ.get("XDG_CONFIG_HOME"): + base = Path(os.environ["XDG_CONFIG_HOME"]) + else: + base = Path.home() / ".config" + else: + base = Path.home() / ".config" + + return base / "contextkit" + + def _test_keychain_availability(self): + """Test if keychain is available and working.""" + try: + # Try a simple operation to test keychain availability + keyring.get_password(self.service_name, "test") + logger.debug("Keychain access is available") + except Exception as e: + logger.warning(f"Keychain access failed, will use file fallback: {e}") + self.keychain_enabled = False + + async def get_tokens(self) -> OAuthToken | None: + """Get stored tokens from keychain or fallback storage.""" + # Try keychain first + if self.keychain_enabled: + try: + token_data = keyring.get_password(self.service_name, f"{self.username}_tokens") + if token_data: + token_dict = json.loads(token_data) + return OAuthToken(**token_dict) + except Exception as e: + logger.warning(f"Failed to get tokens from keychain: {e}") + self.keychain_enabled = False + + # Fallback to file storage + return await self._get_tokens_from_file() + + async def set_tokens(self, tokens: OAuthToken) -> None: + """Store tokens in keychain or fallback storage.""" + token_data = json.dumps(tokens.model_dump()) + + # Try keychain first + if self.keychain_enabled: + try: + keyring.set_password(self.service_name, f"{self.username}_tokens", token_data) + logger.debug("Tokens stored in keychain") + return + except Exception as e: + logger.warning(f"Failed to store tokens in keychain: {e}") + self.keychain_enabled = False + + # Fallback to file storage + await self._set_tokens_to_file(tokens) + + async def get_client_info(self) -> OAuthClientInformationFull | None: + """Get stored client information from keychain or fallback storage.""" + # Try keychain first + if self.keychain_enabled: + try: + client_data = keyring.get_password(self.service_name, f"{self.username}_client_info") + if client_data: + client_dict = json.loads(client_data) + return OAuthClientInformationFull(**client_dict) + except Exception as e: + logger.warning(f"Failed to get client info from keychain: {e}") + self.keychain_enabled = False + + # Fallback to file storage + return await self._get_client_info_from_file() + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + """Store client information in keychain or fallback storage.""" + client_data = json.dumps(client_info.model_dump(mode="json")) + + # Try keychain first + if self.keychain_enabled: + try: + keyring.set_password(self.service_name, f"{self.username}_client_info", client_data) + logger.debug("Client info stored in keychain") + return + except Exception as e: + logger.warning(f"Failed to store client info in keychain: {e}") + self.keychain_enabled = False + + # Fallback to file storage + await self._set_client_info_to_file(client_info) + + async def _get_tokens_from_file(self) -> OAuthToken | None: + """Get tokens from fallback file storage.""" + try: + if not self.fallback_file.exists(): + return None + + with open(self.fallback_file) as f: + data = json.load(f) + token_data = data.get("tokens") + if token_data: + return OAuthToken(**token_data) + except Exception as e: + logger.error(f"Failed to read tokens from file: {e}") + + return None + + async def _set_tokens_to_file(self, tokens: OAuthToken) -> None: + """Store tokens in fallback file storage.""" + await self._update_file_data({"tokens": tokens.model_dump()}) + + async def _get_client_info_from_file(self) -> OAuthClientInformationFull | None: + """Get client info from fallback file storage.""" + try: + if not self.fallback_file.exists(): + return None + + with open(self.fallback_file) as f: + data = json.load(f) + client_data = data.get("client_info") + if client_data: + return OAuthClientInformationFull(**client_data) + except Exception as e: + logger.error(f"Failed to read client info from file: {e}") + + return None + + async def _set_client_info_to_file(self, client_info: OAuthClientInformationFull) -> None: + """Store client info in fallback file storage.""" + await self._update_file_data({"client_info": client_info.model_dump(mode="json")}) + + async def _update_file_data(self, update_data: dict[str, Any]) -> None: + """Update the fallback file with new data.""" + try: + # Ensure directory exists + self.fallback_dir.mkdir(parents=True, exist_ok=True) + + # Read existing data + existing_data = {} + if self.fallback_file.exists(): + with open(self.fallback_file) as f: + existing_data = json.load(f) + + # Update with new data + existing_data.update(update_data) + + # Write back to file with restricted permissions from creation + fd = os.open(self.fallback_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IRUSR | stat.S_IWUSR) + with os.fdopen(fd, "w") as f: + json.dump(existing_data, f, indent=2) + logger.debug(f"Data stored in fallback file: {self.fallback_file}") + + except Exception as e: + logger.error(f"Failed to update fallback file: {e}") + raise diff --git a/prompt/__init__.py b/prompt/__init__.py index af8dfe1..d75bf89 100644 --- a/prompt/__init__.py +++ b/prompt/__init__.py @@ -1,157 +1,5 @@ -import logging +"""Interactive prompt helpers using questionary for collecting template variables.""" -import questionary +from .prompt_helper import PromptHelper -from state import State - -# Add interactive prompt helpers using questionary that will collect values for -# unspecified template variables. - - -class PromptHelper: - def __init__( - self, - state: State, - ): - self._state = state - - async def collect_var_value(self, var_name: str) -> str: - """Collect a value for a variable using questionary.""" - - return await questionary.text(f"Please provide a value for '{var_name}':").ask_async() - - async def get_full_args(self, tools, tool_name, args): - """ - Get full arguments for the tool call, using collect_tool_input() to collect missing required fields. - - :param tools: List of available tools - :param tool_name: Name of the tool to call - :param args: Arguments provided by the user - :return: Full arguments including defaults and required fields - """ - selected_tool = next((t for t in tools.tools if t.name == tool_name), None) - if not selected_tool: - raise ValueError(f"Tool '{tool_name}' not found.") - - input_schema = selected_tool.inputSchema - if not input_schema: - return args - - # Use collect_tool_input to fill in missing required fields - full_args = await self.collect_tool_input(input_schema, args, include_optional=False) - - return full_args - - async def collect_tool_input(self, input_schema, existing_args=None, include_optional=True): - """ - Collect user input based on a schema using Questionary. - Skip fields that are already provided in existing_args. - - Args: - input_schema (dict): Schema with 'properties' and 'required' keys - existing_args (dict, optional): Already collected values - - Returns: - dict: Dictionary with field names as keys and collected values - - Example: - schema = { - 'properties': { - 'a': {'title': 'A', 'type': 'integer'}, - 'b': {'title': 'B', 'type': 'integer'} - }, - 'required': ['a', 'b'] - } - result = collect_input(schema) # Returns something like {"a": 5, "b": 3} - """ - properties = input_schema["properties"] - required = input_schema.get("required", []) - - if existing_args is None: - existing_args = {} - - values = existing_args.copy() - - for field_name, field_info in properties.items(): - # Skip if field is already provided in existing_args - if field_name in existing_args: - continue - - logging.info(f"Collecting input for field: {field_info}") - title = field_info["title"] if "title" in field_info else field_name - field_type = field_info["type"] if "type" in field_info else "string" - field_desc = field_info.get("description", None) - is_required = field_name in required - - if not include_optional and not is_required: - # Skip optional fields if not requested - continue - - # Create prompt text - prompt_text = title - if is_required: - prompt_text += " (required)" - else: - prompt_text += " (optional)" - prompt_text += ":" - - if field_type == "integer": - # Create validator for integer fields - def make_integer_validator(required): - def validate_integer(value): - # Allow empty for optional fields - if not value.strip(): - if required: - return "This field is required" - return True - # Validate integer format - try: - int(value) - return True - except ValueError: - return "Please enter a valid integer" - - return validate_integer - - result = await questionary.text(prompt_text, validate=make_integer_validator(is_required)).ask_async() - - # Process the result - if result and result.strip(): - values[field_name] = int(result) - elif is_required: - # This shouldn't happen due to validation - raise ValueError(f"Required field {field_name} is missing") - # Optional fields that are empty are not included in output - - elif field_type == "string": - # Create validator for string fields - def make_string_validator(required): - def validate_string(value): - if not value.strip() and required: - return "This field is required" - return True - - return validate_string - - result = await questionary.text( - prompt_text, - validate=make_string_validator(is_required), - instruction=field_desc, - ).ask_async() - - # Process the result - if result and result.strip(): - values[field_name] = result - elif is_required: - # This shouldn't happen due to validation - raise ValueError(f"Required field {field_name} is missing") - # Optional fields that are empty are not included in output - - elif field_type == "boolean": - # Create a yes/no prompt for boolean fields - result = await questionary.confirm(prompt_text, default=False, qmark="?").ask_async() - - # Store the boolean value - values[field_name] = result - - return values +__all__ = ["PromptHelper"] diff --git a/prompt/prompt_helper.py b/prompt/prompt_helper.py new file mode 100644 index 0000000..d798312 --- /dev/null +++ b/prompt/prompt_helper.py @@ -0,0 +1,154 @@ +import logging + +import questionary + +from state import State + + +class PromptHelper: + def __init__( + self, + state: State, + ): + self._state = state + + async def collect_var_value(self, var_name: str) -> str: + """Collect a value for a variable using questionary.""" + + return await questionary.text(f"Please provide a value for '{var_name}':").ask_async() + + async def get_full_args(self, tools, tool_name, args): + """ + Get full arguments for the tool call, using collect_tool_input() to collect missing required fields. + + :param tools: List of available tools + :param tool_name: Name of the tool to call + :param args: Arguments provided by the user + :return: Full arguments including defaults and required fields + """ + selected_tool = next((t for t in tools.tools if t.name == tool_name), None) + if not selected_tool: + raise ValueError(f"Tool '{tool_name}' not found.") + + input_schema = selected_tool.inputSchema + if not input_schema: + return args + + # Use collect_tool_input to fill in missing required fields + full_args = await self.collect_tool_input(input_schema, args, include_optional=False) + + return full_args + + async def collect_tool_input(self, input_schema, existing_args=None, include_optional=True): + """ + Collect user input based on a schema using Questionary. + Skip fields that are already provided in existing_args. + + Args: + input_schema (dict): Schema with 'properties' and 'required' keys + existing_args (dict, optional): Already collected values + + Returns: + dict: Dictionary with field names as keys and collected values + + Example: + schema = { + 'properties': { + 'a': {'title': 'A', 'type': 'integer'}, + 'b': {'title': 'B', 'type': 'integer'} + }, + 'required': ['a', 'b'] + } + result = collect_input(schema) # Returns something like {"a": 5, "b": 3} + """ + properties = input_schema["properties"] + required = input_schema.get("required", []) + + if existing_args is None: + existing_args = {} + + values = existing_args.copy() + + for field_name, field_info in properties.items(): + # Skip if field is already provided in existing_args + if field_name in existing_args: + continue + + logging.info(f"Collecting input for field: {field_info}") + title = field_info["title"] if "title" in field_info else field_name + field_type = field_info["type"] if "type" in field_info else "string" + field_desc = field_info.get("description", None) + is_required = field_name in required + + if not include_optional and not is_required: + # Skip optional fields if not requested + continue + + # Create prompt text + prompt_text = title + if is_required: + prompt_text += " (required)" + else: + prompt_text += " (optional)" + prompt_text += ":" + + if field_type == "integer": + # Create validator for integer fields + def make_integer_validator(required): + def validate_integer(value): + # Allow empty for optional fields + if not value.strip(): + if required: + return "This field is required" + return True + # Validate integer format + try: + int(value) + return True + except ValueError: + return "Please enter a valid integer" + + return validate_integer + + result = await questionary.text(prompt_text, validate=make_integer_validator(is_required)).ask_async() + + # Process the result + if result and result.strip(): + values[field_name] = int(result) + elif is_required: + # This shouldn't happen due to validation + raise ValueError(f"Required field {field_name} is missing") + # Optional fields that are empty are not included in output + + elif field_type == "string": + # Create validator for string fields + def make_string_validator(required): + def validate_string(value): + if not value.strip() and required: + return "This field is required" + return True + + return validate_string + + result = await questionary.text( + prompt_text, + validate=make_string_validator(is_required), + instruction=field_desc, + ).ask_async() + + # Process the result + if result and result.strip(): + values[field_name] = result + elif is_required: + # This shouldn't happen due to validation + raise ValueError(f"Required field {field_name} is missing") + # Optional fields that are empty are not included in output + + elif field_type == "boolean": + # Create a yes/no prompt for boolean fields + result = await questionary.confirm(prompt_text, default=False, qmark="?").ask_async() + + # Store the boolean value + values[field_name] = result + + return values diff --git a/pyproject.toml b/pyproject.toml index 05e00e8..5d48e6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,10 +5,13 @@ description = "A CLI tool and MCP client, used to create spec files for AI codin readme = "README.md" requires-python = ">=3.12" dependencies = [ + "fastapi>=0.104.0", "jinja2>=3.1.6", - "mcp[cli]>=1.12.4", + "keyring>=24.0.0", + "mcp[cli]>=1.13.1", "pydantic>=2.11.7", "questionary>=2.1.0", + "uvicorn>=0.24.0", ] [project.optional-dependencies] @@ -19,6 +22,9 @@ test = [ [dependency-groups] dev = [ + "debugpy>=1.8.16", + "pytest>=8.4.1", + "pytest-asyncio>=1.1.0", "ruff>=0.12.7", ] diff --git a/state.py b/state.py index 8856227..15c37c0 100644 --- a/state.py +++ b/state.py @@ -1,7 +1,7 @@ import json from pathlib import Path -from mcp_client.config import MCPServersConfig +from mcp_client import KeychainTokenStorageWithFallback, MCPServersConfig class State: @@ -10,6 +10,7 @@ def __init__(self): self.config_dir = self.project_root / ".cxk" if self.project_root else None self.config_file = self.config_dir / "mcp.json" if self.config_dir else None self._mcp_config: MCPServersConfig | None = None + self._token_storages: dict[str, KeychainTokenStorageWithFallback] = {} def _find_git_root(self) -> Path | None: current = Path.cwd() @@ -67,3 +68,9 @@ def initialize_project(self): self.config_dir.mkdir(exist_ok=True) if self.config_file and not self.config_file.exists(): self.save_mcp_config() + + def get_token_storage(self, server_name: str) -> KeychainTokenStorageWithFallback: + """Get or create token storage instance for a specific server.""" + if server_name not in self._token_storages: + self._token_storages[server_name] = KeychainTokenStorageWithFallback(self, server_name) + return self._token_storages[server_name] diff --git a/task.md b/task.md new file mode 100644 index 0000000..2841db8 --- /dev/null +++ b/task.md @@ -0,0 +1,37 @@ + +## Storing OAuth Token for reuse + +### Purpose: +OAuth access tokens are currently stored in memory (see InMemoryTokenStorage in mcp_client/client_session_provider.py). The purpose of this task is +to store them in a secure storage for reuse. + +### High level solution: + +Dual Storage Strategy: + 1. Primary: System keychain (macOS Keychain, Windows Credential Manager, Linux Secret Service) + 2. Fallback: Secured file storage if keychain unavailable + + Key Features: + - Automatic fallback: Falls back to file storage if keychain access fails + - Configurable security: Can disable keychain and use file-only storage + - Cross-platform support: Works across different operating systems + + Storage Locations + + System Keychain + + - Service Name: "contextkit" + - Username: "secrets" + - Storage: Platform-specific secure storage (Keychain Access on macOS, etc.) + + File Fallback + + - Location: ~/.config/contextkit/secrets.yaml (or platform-specific app data directory) + - Format: YAML-serialized credential data + - Permissions: Restricted file permissions for security + + +### Scope of task: +- Create a new class KeychainTokenStorageWithFallback under mcp_client/token_storage.py folder to manage the OAuth tokens (replaces InMemoryTokenStorage and impl TokenStorage). +- The new class is initialized with the State class in state.py which will be passed. The state can contain any configuration needed for the storage like path names, names of service names in keychain, etc. +- Implement the dual storage strategy as described above. \ No newline at end of file diff --git a/tests/README.md b/tests/README.md index 010a9cd..43bf747 100644 --- a/tests/README.md +++ b/tests/README.md @@ -61,4 +61,15 @@ uv run cxk.py create-spec tests/templates/spec5.md ``` uv run cxk.py create-spec tests/templates/spec6.md --var name=MrBean -``` \ No newline at end of file +``` + +### With SSE-based MCP resources: +``` +uv run cxk.py create-spec tests/templates/spec7.md +uv run -m debugpy --listen 5678 --wait-for-client cxk.py create-spec tests/templates/spec7.md +``` + +### With multiple MCP resources and tools: +``` +uv run cxk.py create-spec tests/templates/spec8.md +``` \ No newline at end of file diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 23520a6..ae2b8d2 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -1,4 +1,5 @@ import json +import shutil import subprocess import tempfile from pathlib import Path @@ -7,7 +8,7 @@ class TestCLI: - @pytest.fixture + @pytest.fixture(scope="session") def temp_git_repo(self): """Create a temporary git repository for testing.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -29,6 +30,16 @@ def temp_git_repo(self): yield repo_path + @pytest.fixture(autouse=True) + def cleanup_cxk_folder(self, temp_git_repo): + """Clean .cxk folder after each test to ensure test isolation.""" + yield # Run the test + + # Clean up .cxk folder after test + cxk_dir = temp_git_repo / ".cxk" + if cxk_dir.exists(): + shutil.rmtree(cxk_dir) + @pytest.fixture def temp_non_git_dir(self): """Create a temporary directory that is NOT a git repository.""" @@ -282,7 +293,7 @@ def test_cli_help(self): assert "mcp" in result.stdout assert "create-spec" in result.stdout - def test_create_spec_with_variables(self, temp_non_git_dir): + def test_create_spec_with_variables(self, temp_git_repo): """Test create-spec with a template containing variables.""" # Create a test template with variables template_content = """ @@ -290,11 +301,11 @@ def test_create_spec_with_variables(self, temp_non_git_dir): Your age is {{ age }} and you live in {{ city }}. Today's weather is {{ weather.condition }} with temperature {{ weather.temp }}. """ - template_file = temp_non_git_dir / "test_template.j2" + template_file = temp_git_repo / "test_template.j2" template_file.write_text(template_content) # Run create-spec command with test runner to patch collect_var_value - result = self.run_cli(["create-spec", "--verbose", str(template_file)], use_test_runner=True) + result = self.run_cli(["create-spec", "--verbose", str(template_file)], use_test_runner=True, cwd=temp_git_repo) assert result.returncode == 0 @@ -303,15 +314,15 @@ def test_create_spec_with_variables(self, temp_non_git_dir): assert "Your age is 25 and you live in New York." in result.stdout assert "Today's weather is sunny with temperature 75F." in result.stdout - def test_create_spec_no_variables(self, temp_non_git_dir): + def test_create_spec_no_variables(self, temp_git_repo): """Test create-spec with a template containing no variables.""" # Create a test template without variables template_content = "This is a static template with no variables." - template_file = temp_non_git_dir / "static_template.j2" + template_file = temp_git_repo / "static_template.j2" template_file.write_text(template_content) # Run create-spec command with test runner (patching won't affect this case) - result = self.run_cli(["create-spec", "--verbose", str(template_file)], use_test_runner=True) + result = self.run_cli(["create-spec", "--verbose", str(template_file)], cwd=temp_git_repo, use_test_runner=True) assert result.returncode == 0 assert "No variables found in template" in result.stderr @@ -319,16 +330,16 @@ def test_create_spec_no_variables(self, temp_non_git_dir): # Verify rendered template output for static template assert "This is a static template with no variables." in result.stdout - def test_create_spec_relative_path(self, temp_non_git_dir): + def test_create_spec_relative_path(self, temp_git_repo): """Test create-spec with a relative path (filename only).""" # Create a test template template_content = "Hello {{ username }}!" - template_file = temp_non_git_dir / "relative_template.j2" + template_file = temp_git_repo / "relative_template.j2" template_file.write_text(template_content) # Run create-spec command with just the filename (relative path) using test runner result = self.run_cli( - ["create-spec", "--verbose", "relative_template.j2"], cwd=temp_non_git_dir, use_test_runner=True + ["create-spec", "--verbose", "relative_template.j2"], cwd=temp_git_repo, use_test_runner=True ) assert result.returncode == 0 @@ -337,39 +348,41 @@ def test_create_spec_relative_path(self, temp_non_git_dir): # Verify rendered template output assert "Hello testuser!" in result.stdout - def test_create_spec_file_not_found(self, temp_non_git_dir): + def test_create_spec_file_not_found(self, temp_git_repo): """Test create-spec with non-existent template file.""" - result = self.run_cli(["create-spec", "non_existent.j2"], cwd=temp_non_git_dir, use_test_runner=True) + result = self.run_cli(["create-spec", "non_existent.j2"], cwd=temp_git_repo, use_test_runner=True) assert result.returncode != 0 assert "Error: Template file 'non_existent.j2' not found" in result.stderr - def test_create_spec_invalid_template(self, temp_non_git_dir): + def test_create_spec_invalid_template(self, temp_git_repo): """Test create-spec with invalid template syntax.""" # Create a template with invalid Jinja2 syntax template_content = "Hello {{ name with invalid syntax!" - template_file = temp_non_git_dir / "invalid_template.j2" + template_file = temp_git_repo / "invalid_template.j2" template_file.write_text(template_content) # Run create-spec command with test runner - result = self.run_cli(["create-spec", str(template_file)], use_test_runner=True) + result = self.run_cli(["create-spec", str(template_file)], cwd=temp_git_repo, use_test_runner=True) assert result.returncode != 0 - def test_create_spec_with_output_file(self, temp_non_git_dir): + def test_create_spec_with_output_file(self, temp_git_repo): """Test create-spec with --output flag saves to file.""" # Create a test template with variables template_content = """Hello {{ name }}! Your age is {{ age }} and you live in {{ city }}.""" - template_file = temp_non_git_dir / "output_test_template.j2" + template_file = temp_git_repo / "output_test_template.j2" template_file.write_text(template_content) # Define output file - output_file = temp_non_git_dir / "rendered_spec.md" + output_file = temp_git_repo / "rendered_spec.md" # Run create-spec command with --output flag result = self.run_cli( - ["create-spec", "--verbose", str(template_file), "--output", str(output_file)], use_test_runner=True + ["create-spec", "--verbose", str(template_file), "--output", str(output_file)], + cwd=temp_git_repo, + use_test_runner=True, ) assert result.returncode == 0 @@ -380,36 +393,38 @@ def test_create_spec_with_output_file(self, temp_non_git_dir): assert "Hello John!" in content assert "Your age is 25 and you live in New York." in content - def test_create_spec_output_file_relative_path(self, temp_non_git_dir): + def test_create_spec_output_file_relative_path(self, temp_git_repo): """Test create-spec with --output using relative path.""" # Create a test template template_content = "Template for {{ username }}" - template_file = temp_non_git_dir / "relative_output_template.j2" + template_file = temp_git_repo / "relative_output_template.j2" template_file.write_text(template_content) # Run create-spec command with relative output path result = self.run_cli( ["create-spec", "--verbose", str(template_file), "--output", "output.md"], - cwd=temp_non_git_dir, + cwd=temp_git_repo, use_test_runner=True, ) assert result.returncode == 0 # Verify file was created with absolute path in message - output_file = temp_non_git_dir / "output.md" + output_file = temp_git_repo / "output.md" assert output_file.exists() assert "Template for testuser" in output_file.read_text() - def test_create_spec_stdout_vs_file_output(self, temp_non_git_dir): + def test_create_spec_stdout_vs_file_output(self, temp_git_repo): """Test that stdout and file output contain the same content.""" # Create a test template template_content = "Hello {{ name }}! You are {{ age }} years old." - template_file = temp_non_git_dir / "comparison_template.j2" + template_file = temp_git_repo / "comparison_template.j2" template_file.write_text(template_content) # Run without --output (stdout) - result_stdout = self.run_cli(["create-spec", "--verbose", str(template_file)], use_test_runner=True) + result_stdout = self.run_cli( + ["create-spec", "--verbose", str(template_file)], cwd=temp_git_repo, use_test_runner=True + ) # Extract rendered content from stdout stdout_lines = result_stdout.stdout.split("\n") @@ -436,9 +451,11 @@ def test_create_spec_stdout_vs_file_output(self, temp_non_git_dir): stdout_rendered = "\n".join(template_lines).strip() # Run with --output (file) - output_file = temp_non_git_dir / "comparison_output.md" + output_file = temp_git_repo / "comparison_output.md" result_file = self.run_cli( - ["create-spec", "--verbose", str(template_file), "--output", str(output_file)], use_test_runner=True + ["create-spec", "--verbose", str(template_file), "--output", str(output_file)], + cwd=temp_git_repo, + use_test_runner=True, ) assert result_stdout.returncode == 0 @@ -449,16 +466,18 @@ def test_create_spec_stdout_vs_file_output(self, temp_non_git_dir): assert stdout_rendered == file_content assert "Hello John! You are 25 years old." in file_content - def test_create_spec_with_var_override_single(self, temp_non_git_dir): + def test_create_spec_with_var_override_single(self, temp_git_repo): """Test create-spec with single --var override.""" # Create a test template with variables template_content = "Hello {{ name }}! You are {{ age }} years old." - template_file = temp_non_git_dir / "var_override_template.j2" + template_file = temp_git_repo / "var_override_template.j2" template_file.write_text(template_content) # Run create-spec command with --var override result = self.run_cli( - ["create-spec", "--verbose", str(template_file), "--var", "name=Alice"], use_test_runner=True + ["create-spec", "--verbose", str(template_file), "--var", "name=Alice"], + cwd=temp_git_repo, + use_test_runner=True, ) assert result.returncode == 0 @@ -466,16 +485,17 @@ def test_create_spec_with_var_override_single(self, temp_non_git_dir): # Verify rendered template output assert "Hello Alice! You are 25 years old." in result.stdout - def test_create_spec_with_var_override_multiple(self, temp_non_git_dir): + def test_create_spec_with_var_override_multiple(self, temp_git_repo): """Test create-spec with multiple --var overrides.""" # Create a test template with variables template_content = "Hello {{ name }}! You are {{ age }} years old and live in {{ city }}." - template_file = temp_non_git_dir / "multiple_var_override_template.j2" + template_file = temp_git_repo / "multiple_var_override_template.j2" template_file.write_text(template_content) # Run create-spec command with multiple --var overrides result = self.run_cli( ["create-spec", "--verbose", str(template_file), "--var", "name=Bob", "--var", "city=Boston"], + cwd=temp_git_repo, use_test_runner=True, ) @@ -484,11 +504,11 @@ def test_create_spec_with_var_override_multiple(self, temp_non_git_dir): # Verify rendered template output assert "Hello Bob! You are 25 years old and live in Boston." in result.stdout - def test_create_spec_with_var_override_all_variables(self, temp_non_git_dir): + def test_create_spec_with_var_override_all_variables(self, temp_git_repo): """Test create-spec with all variables provided via --var (no interactive prompts).""" # Create a test template with variables template_content = "{{ greeting }} {{ name }}! Your score is {{ score }}." - template_file = temp_non_git_dir / "all_var_override_template.j2" + template_file = temp_git_repo / "all_var_override_template.j2" template_file.write_text(template_content) # Run create-spec command with all variables provided @@ -504,6 +524,7 @@ def test_create_spec_with_var_override_all_variables(self, temp_non_git_dir): "--var", "score=100", ], + cwd=temp_git_repo, use_test_runner=True, ) @@ -512,17 +533,19 @@ def test_create_spec_with_var_override_all_variables(self, temp_non_git_dir): # Verify rendered template output assert "Hi Charlie! Your score is 100." in result.stdout - def test_create_spec_with_var_override_json_value(self, temp_non_git_dir): + def test_create_spec_with_var_override_json_value(self, temp_git_repo): """Test create-spec with --var containing JSON value.""" # Create a test template with JSON variable template_content = "User: {{ user.name }}, Email: {{ user.email }}" - template_file = temp_non_git_dir / "json_var_override_template.j2" + template_file = temp_git_repo / "json_var_override_template.j2" template_file.write_text(template_content) # Run create-spec command with JSON --var json_value = '{"name": "Dave", "email": "dave@example.com"}' result = self.run_cli( - ["create-spec", "--verbose", str(template_file), "--var", f"user={json_value}"], use_test_runner=True + ["create-spec", "--verbose", str(template_file), "--var", f"user={json_value}"], + cwd=temp_git_repo, + use_test_runner=True, ) assert result.returncode == 0 @@ -530,31 +553,35 @@ def test_create_spec_with_var_override_json_value(self, temp_non_git_dir): # Verify rendered template output assert "User: Dave, Email: dave@example.com" in result.stdout - def test_create_spec_with_var_invalid_format(self, temp_non_git_dir): + def test_create_spec_with_var_invalid_format(self, temp_git_repo): """Test create-spec with invalid --var format.""" # Create a test template template_content = "Hello {{ name }}!" - template_file = temp_non_git_dir / "invalid_var_format_template.j2" + template_file = temp_git_repo / "invalid_var_format_template.j2" template_file.write_text(template_content) # Run create-spec command with invalid --var format (missing =) result = self.run_cli( - ["create-spec", "--verbose", str(template_file), "--var", "invalid_format"], use_test_runner=True + ["create-spec", "--verbose", str(template_file), "--var", "invalid_format"], + cwd=temp_git_repo, + use_test_runner=True, ) assert result.returncode != 0 assert "Error: Invalid variable format 'invalid_format'. Use KEY=VALUE format." in result.stderr - def test_create_spec_with_var_equals_in_value(self, temp_non_git_dir): + def test_create_spec_with_var_equals_in_value(self, temp_git_repo): """Test create-spec with --var value containing equals sign.""" # Create a test template template_content = "Equation: {{ equation }}" - template_file = temp_non_git_dir / "equals_in_value_template.j2" + template_file = temp_git_repo / "equals_in_value_template.j2" template_file.write_text(template_content) # Run create-spec command with --var value containing equals result = self.run_cli( - ["create-spec", "--verbose", str(template_file), "--var", "equation=x=y+z"], use_test_runner=True + ["create-spec", "--verbose", str(template_file), "--var", "equation=x=y+z"], + cwd=temp_git_repo, + use_test_runner=True, ) assert result.returncode == 0 @@ -562,15 +589,15 @@ def test_create_spec_with_var_equals_in_value(self, temp_non_git_dir): # Verify rendered template output assert "Equation: x=y+z" in result.stdout - def test_create_spec_with_var_and_output_file(self, temp_non_git_dir): + def test_create_spec_with_var_and_output_file(self, temp_git_repo): """Test create-spec with --var and --output together.""" # Create a test template template_content = "Project: {{ project }}, Version: {{ version }}" - template_file = temp_non_git_dir / "var_and_output_template.j2" + template_file = temp_git_repo / "var_and_output_template.j2" template_file.write_text(template_content) # Define output file - output_file = temp_non_git_dir / "var_output.md" + output_file = temp_git_repo / "var_output.md" # Run create-spec command with both --var and --output result = self.run_cli( @@ -585,6 +612,7 @@ def test_create_spec_with_var_and_output_file(self, temp_non_git_dir): "--output", str(output_file), ], + cwd=temp_git_repo, use_test_runner=True, ) @@ -596,7 +624,7 @@ def test_create_spec_with_var_and_output_file(self, temp_non_git_dir): content = output_file.read_text() assert "Project: MyApp, Version: 1.0.0" in content - def test_create_spec_pipe_mode(self, temp_non_git_dir): + def test_create_spec_pipe_mode(self, temp_git_repo): """Test create-spec with stdin pipe mode (no template file argument).""" # Template content to pipe via stdin template_content = ( @@ -614,7 +642,7 @@ def test_create_spec_pipe_mode(self, temp_non_git_dir): "--var", "additional_context=test context", ], - cwd=temp_non_git_dir, + cwd=temp_git_repo, input=template_content, ) @@ -627,13 +655,13 @@ def test_create_spec_pipe_mode(self, temp_non_git_dir): assert "## Additional context" in result.stdout assert "test context" in result.stdout - def test_create_spec_pipe_mode_with_output_file(self, temp_non_git_dir): + def test_create_spec_pipe_mode_with_output_file(self, temp_git_repo): """Test create-spec with stdin pipe mode and --output flag.""" # Template content to pipe via stdin template_content = "Piped template: {{ message }}" # Define output file - output_file = temp_non_git_dir / "piped_output.md" + output_file = temp_git_repo / "piped_output.md" # Run create-spec command with stdin and --output result = self.run_cli( @@ -645,7 +673,7 @@ def test_create_spec_pipe_mode_with_output_file(self, temp_non_git_dir): "--output", str(output_file), ], - cwd=temp_non_git_dir, + cwd=temp_git_repo, input=template_content, ) @@ -798,4 +826,3 @@ def test_create_spec_with_mcp_resource(self, temp_git_repo): # Verify template structure is preserved assert "# Task Template" in result.stdout assert "## Information from greeting service" in result.stdout - diff --git a/tests/templates/spec7.md b/tests/templates/spec7.md new file mode 100644 index 0000000..c93f0cc --- /dev/null +++ b/tests/templates/spec7.md @@ -0,0 +1,12 @@ +# Task Template + +## Ticket description + +{% set ticket1 = call_tool('linear-test', 'get_issue', {'id': 'MCP-1'}) %} +{% set ticket2 = call_tool('linear-test', 'get_issue', {'id': 'MCP-2'}) %} + +### Description +{{ ticket1.description }} + +### Other Description +{{ ticket2.description }} diff --git a/tests/templates/spec8.md b/tests/templates/spec8.md new file mode 100644 index 0000000..ab9c95f --- /dev/null +++ b/tests/templates/spec8.md @@ -0,0 +1,19 @@ +# Task Template + +## Ticket description + +{% set ticket = call_tool('jira', 'getJiraIssue', {'cloudId': '483da417-278f-434d-8baf-132455657f48', 'issueIdOrKey': 'SCRUM-5'}) %} + +### Description +{{ ticket.fields.description }} + +## Information from greeting service + +{{ get_resource('test-mcp', 'greeting://foobar') }} + +## Another ticket description + +{% set ticket = call_tool('jira', 'getJiraIssue', {'cloudId': '483da417-278f-434d-8baf-132455657f48', 'issueIdOrKey': 'SCRUM-4'}) %} + +### Description +{{ ticket.fields.description }} \ No newline at end of file diff --git a/util/terminal.py b/util/terminal.py new file mode 100644 index 0000000..ec20eb2 --- /dev/null +++ b/util/terminal.py @@ -0,0 +1,21 @@ +import os + + +def supports_hyperlinks(): + # Check if terminal supports hyperlinks + return ( + os.getenv("TERM_PROGRAM") in ["iTerm.app", "vscode"] + or os.getenv("COLORTERM") == "truecolor" + or "hyperlinks" in os.getenv("TERM", "") + ) + + +def underline(text): + return f"\033[4m{text}\033[0m" + + +def display_hyperlink(url): + if supports_hyperlinks(): + return f"\033]8;;{url}\033\\{url}\033]8;;\033\\" + else: + return url