diff --git a/commands/mcp.py b/commands/mcp.py index deafcd4..a7489b6 100644 --- a/commands/mcp.py +++ b/commands/mcp.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from mcp_client import SSEServerConfig, StdioServerConfig +from mcp_client import HTTPServerConfig, SSEServerConfig, StdioServerConfig from state import State @@ -18,6 +18,7 @@ class MCPAddStdioContext(BaseModel): class MCPAddHttpContext(BaseModel): server_name: str url: str + headers: list[str] | None = None class MCPCommandContext(BaseModel): @@ -52,7 +53,21 @@ async def handle_mcp(state: State, context: MCPCommandContext): ) elif context.subcommand == "add-http" and context.add_http: - await handle_add_http(state, context.add_http.server_name, context.add_http.url) + headers_dict = {} + if context.add_http.headers: + for header in context.add_http.headers: + if "=" in header: + key, value = header.split("=", 1) + headers_dict[key] = value + else: + raise ValueError(f"Invalid header format: {header}. Use KEY=VALUE format.") + + await handle_add_http( + state, + context.add_http.server_name, + context.add_http.url, + headers_dict if headers_dict else None, + ) async def handle_add_sse(state: State, server_name: str, url: str): @@ -89,5 +104,14 @@ async def handle_add_stdio( print(f"Environment variables: {env}") -async def handle_add_http(state: State, server_name: str, url: str): - print(f"HTTP server support not implemented yet. Would add '{server_name}' with URL: {url}") +async def handle_add_http(state: State, server_name: str, url: str, headers: dict[str, str] | None = None): + if server_name in state.mcp_config.mcpServers: + raise ValueError(f"Server '{server_name}' already exists") + + server_config = HTTPServerConfig(url=url, headers=headers) + state.mcp_config.mcpServers[server_name] = server_config + state.save_mcp_config() + + print(f"Added HTTP server '{server_name}' with URL: {url}") + if headers: + print(f"Headers: {headers}") diff --git a/cxk.py b/cxk.py index 4069499..fe27f9a 100644 --- a/cxk.py +++ b/cxk.py @@ -44,10 +44,11 @@ async def main(): add_stdio_parser.add_argument("--env", action="append", help="Environment variable (key=value)") add_stdio_parser.add_argument("command_line", nargs=argparse.ONE_OR_MORE, help="Command to run") - # cxk mcp add-http [server-name] [url] + # cxk mcp add-http [server-name] [url] --header [header] add_http_parser = mcp_subparsers.add_parser("add-http", help="Add HTTP MCP server") add_http_parser.add_argument("server_name", help="Name of the server") add_http_parser.add_argument("url", help="URL of the HTTP server") + add_http_parser.add_argument("--header", action="append", help="HTTP header (key=value)") args = parser.parse_args() @@ -91,7 +92,7 @@ async def main(): elif args.mcp_command == "add-http": mcp_context = MCPCommandContext( subcommand="add-http", - add_http=MCPAddHttpContext(server_name=args.server_name, url=args.url), + add_http=MCPAddHttpContext(server_name=args.server_name, url=args.url, headers=args.header), ) await handle_mcp(state, mcp_context) diff --git a/mcp_client/__init__.py b/mcp_client/__init__.py index f7e7c86..d5b51e5 100644 --- a/mcp_client/__init__.py +++ b/mcp_client/__init__.py @@ -1,11 +1,12 @@ """MCP (Model Context Protocol) client implementation.""" from .client_session_provider import get_client_session_by_server -from .config import MCPServersConfig, SSEServerConfig, StdioServerConfig +from .config import HTTPServerConfig, MCPServersConfig, SSEServerConfig, StdioServerConfig from .session_manager import MCPSessionManager, get_session_manager from .token_storage import KeychainTokenStorageWithFallback __all__ = [ + "HTTPServerConfig", "MCPServersConfig", "SSEServerConfig", "StdioServerConfig", diff --git a/mcp_client/client_session_provider.py b/mcp_client/client_session_provider.py index 2c629f0..6db2b4d 100644 --- a/mcp_client/client_session_provider.py +++ b/mcp_client/client_session_provider.py @@ -14,6 +14,7 @@ from pydantic import AnyUrl from auth_server import AuthServer +from mcp_client.config import HTTPServerConfig from util.terminal import display_hyperlink from .mcp_logger import get_mcp_log_file @@ -45,23 +46,30 @@ async def get_stdio_session(server_params: StdioServerParameters, config_dir: Pa @asynccontextmanager -async def get_streamablehttp_session(server_url: str, server_name: str, state: "State"): - token_storage = state.get_token_storage(server_name) - oauth_auth = OAuthClientProvider( - server_url=server_url, - client_metadata=OAuthClientMetadata( - client_name="ContextKit MCP Client", - redirect_uris=[AnyUrl("http://localhost:3000/callback")], - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], - scope="user", - ), - storage=token_storage, - redirect_handler=handle_redirect, - callback_handler=handle_callback, - ) +async def get_streamablehttp_session( + server_config: HTTPServerConfig, server_name: str, state: "State", auth_server: AuthServer | None = None +): + if auth_server is None: + raise ValueError("AuthServer must be provided for HTTP sessions") + + # If the server_config.headers includes an Authorization header, skip oauth auth (PAT or API key used directly) + oauth_auth = None + if "Authorization" not in (server_config.headers or {}): + token_storage = state.get_token_storage(server_name) + oauth_auth = OAuthClientProvider( + server_url=server_config.url, + client_metadata=OAuthClientMetadata( + client_name="ContextKit MCP Client", + redirect_uris=[AnyUrl(auth_server.callback_url)], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + ), + storage=token_storage, + redirect_handler=handle_redirect, + callback_handler=auth_server.handle_callback, + ) # Connect to a streamable HTTP server - async with streamablehttp_client(server_url, auth=oauth_auth) as ( + async with streamablehttp_client(server_config.url, headers=server_config.headers, auth=oauth_auth) as ( read_stream, write_stream, _, diff --git a/mcp_client/config.py b/mcp_client/config.py index cabf088..cb858eb 100644 --- a/mcp_client/config.py +++ b/mcp_client/config.py @@ -45,10 +45,25 @@ def validate_type(cls, v): return v +class HTTPServerConfig(BaseServerConfig): + """Configuration for HTTP-based MCP servers.""" + + type: str = Field(default="http", description="Server transport type") + url: str = Field(..., description="URL endpoint for the HTTP server") + headers: dict[str, str] | None = Field(default=None, description="HTTP headers") + + @field_validator("type") + @classmethod + def validate_type(cls, v): + if v != "http": + raise ValueError('type must be "http" for HTTPServerConfig') + return v + + class MCPServersConfig(BaseModel): """Root configuration containing all MCP servers.""" - mcpServers: dict[str, StdioServerConfig | SSEServerConfig] = Field( + mcpServers: dict[str, StdioServerConfig | SSEServerConfig | HTTPServerConfig] = Field( ..., description="Dictionary of server name to server configuration" ) @@ -74,6 +89,8 @@ def validate_servers(cls, values): if server_type == "sse": validated_servers[name] = SSEServerConfig(**config) + elif server_type == "http": + validated_servers[name] = HTTPServerConfig(**config) else: validated_servers[name] = StdioServerConfig(**config) else: diff --git a/mcp_client/session_manager.py b/mcp_client/session_manager.py index 76280cf..fa1390c 100644 --- a/mcp_client/session_manager.py +++ b/mcp_client/session_manager.py @@ -6,7 +6,7 @@ from mcp import ClientSession -from .config import SSEServerConfig, StdioServerConfig +from .config import HTTPServerConfig, SSEServerConfig, StdioServerConfig if TYPE_CHECKING: from state import State @@ -35,6 +35,7 @@ async def initialize_all_sessions(self, state: "State") -> None: # Use shared auth server for all SSE connections during initialization from auth_server import AuthServer + async with AuthServer() as auth_server: for server_name, server_config in state.mcp_config.mcpServers.items(): try: @@ -57,6 +58,9 @@ async def initialize_all_sessions(self, state: "State") -> None: from .client_session_provider import get_sse_session session_cm = get_sse_session(server_config.url, server_name, state, auth_server) + elif isinstance(server_config, HTTPServerConfig): + from .client_session_provider import get_streamablehttp_session + session_cm = get_streamablehttp_session(server_config, server_name, state, auth_server) else: raise ValueError(f"Unsupported server type for '{server_name}': {type(server_config)}") diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index ae2b8d2..1891a0e 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -193,24 +193,68 @@ def test_mcp_add_stdio_with_env(self, temp_git_repo): assert server_config["args"] == ["server.js"] assert server_config["env"] == {"API_KEY": "test123", "DEBUG": "true"} - def test_mcp_add_http_placeholder(self, temp_git_repo): - """Test adding HTTP MCP server (placeholder functionality).""" + def test_mcp_add_http(self, temp_git_repo): + """Test adding HTTP MCP server.""" # Initialize project first init_result = self.run_cli(["init"], cwd=temp_git_repo) assert init_result.returncode == 0 - # Add HTTP server (should show placeholder message) + # Add HTTP server result = self.run_cli( ["mcp", "add-http", "test-http", "http://example.com/api"], cwd=temp_git_repo, ) assert result.returncode == 0 - assert ( - "HTTP server support not implemented yet. Would add 'test-http' with URL: http://example.com/api" - in result.stdout + assert "Added HTTP server 'test-http' with URL: http://example.com/api" in result.stdout + + # Verify the configuration was saved + config_file = temp_git_repo / ".cxk" / "mcp.json" + assert config_file.exists() + + config_data = json.loads(config_file.read_text()) + assert "test-http" in config_data["mcpServers"] + server_config = config_data["mcpServers"]["test-http"] + assert server_config["type"] == "http" + assert server_config["url"] == "http://example.com/api" + + def test_mcp_add_http_with_headers(self, temp_git_repo): + """Test adding HTTP MCP server with headers.""" + # Initialize project first + init_result = self.run_cli(["init"], cwd=temp_git_repo) + assert init_result.returncode == 0 + + # Add HTTP server with headers + result = self.run_cli( + [ + "mcp", + "add-http", + "test-http-headers", + "http://example.com/api", + "--header", + "Authorization=Bearer token123", + "--header", + "Content-Type=application/json", + ], + cwd=temp_git_repo, ) + assert result.returncode == 0 + assert "Added HTTP server 'test-http-headers' with URL: http://example.com/api" in result.stdout + assert "Headers: {'Authorization': 'Bearer token123', 'Content-Type': 'application/json'}" in result.stdout + + # Verify the configuration was saved with headers + config_file = temp_git_repo / ".cxk" / "mcp.json" + assert config_file.exists() + + config_data = json.loads(config_file.read_text()) + assert "test-http-headers" in config_data["mcpServers"] + server_config = config_data["mcpServers"]["test-http-headers"] + assert server_config["type"] == "http" + assert server_config["url"] == "http://example.com/api" + assert server_config["headers"]["Authorization"] == "Bearer token123" + assert server_config["headers"]["Content-Type"] == "application/json" + def test_mcp_duplicate_server_name(self, temp_git_repo): """Test adding MCP server with duplicate name should fail.""" # Initialize project first diff --git a/tests/templates/spec9.md b/tests/templates/spec9.md new file mode 100644 index 0000000..7adae2f --- /dev/null +++ b/tests/templates/spec9.md @@ -0,0 +1,8 @@ +# Task Template + +## Issue description + +{% set issue = call_tool('github', 'get_issue', {'issue_number': 17, 'owner': 'eyalzh', 'repo': 'browser-control-mcp'}) %} + +### Github Issue Description +{{ issue.body }} \ No newline at end of file