Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,6 @@ ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
Expand Down Expand Up @@ -189,4 +179,5 @@ marimo/_lsp/
__marimo__/

# ContextKit
.cxk
.cxk
task.md
5 changes: 5 additions & 0 deletions auth_server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""OAuth2 auth server for handling MCP authorization callbacks."""

from .auth_server import AuthServer

__all__ = ["AuthServer"]
185 changes: 185 additions & 0 deletions auth_server/auth_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""OAuth2 authentication server for handling MCP client callbacks."""

import asyncio
import logging
import threading

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse


class AuthServer:
"""FastAPI-based OAuth2 callback server for MCP authentication."""

def __init__(self, host: str = "localhost", port: int = 41008):
"""Initialize the auth server."""
self.host = host
self.port = port
self.app = FastAPI()
self.server: uvicorn.Server | None = None
self.server_thread: threading.Thread | None = None

# Storage for the callback response
self._callback_future: asyncio.Future | None = None
self._callback_code: str | None = None
self._callback_state: str | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None

# Setup routes
self._setup_routes()

def _setup_routes(self):
"""Setup FastAPI routes."""

@self.app.get("/callback")
async def callback_endpoint(request: Request):
"""Handle OAuth2 callback."""
query_params = dict(request.query_params)

# Extract code and state from query parameters
code = query_params.get("code")
state = query_params.get("state")
error = query_params.get("error")

if error:
error_description = query_params.get("error_description", "Unknown error")
if self._callback_future and not self._callback_future.done():
# Schedule the exception in the main event loop
if self._main_loop:
self._main_loop.call_soon_threadsafe(
self._callback_future.set_exception,
Exception(f"OAuth error: {error} - {error_description}"),
)
return HTMLResponse(
content=f"<html><body><h2>Authentication Error</h2><p>{error}: {error_description}</p></body></html>", # noqa: E501
status_code=400,
)

if not code:
if self._callback_future and not self._callback_future.done():
# Schedule the exception in the main event loop
if self._main_loop:
self._main_loop.call_soon_threadsafe(
self._callback_future.set_exception, Exception("No authorization code received")
)
Comment on lines +47 to +65
Copy link

Copilot AI Aug 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The callback error handling logic is duplicated in multiple places. Consider extracting this into a helper method like _set_callback_exception() to reduce code duplication.

Suggested change
if self._callback_future and not self._callback_future.done():
# Schedule the exception in the main event loop
if self._main_loop:
self._main_loop.call_soon_threadsafe(
self._callback_future.set_exception,
Exception(f"OAuth error: {error} - {error_description}"),
)
return HTMLResponse(
content=f"<html><body><h2>Authentication Error</h2><p>{error}: {error_description}</p></body></html>", # noqa: E501
status_code=400,
)
if not code:
if self._callback_future and not self._callback_future.done():
# Schedule the exception in the main event loop
if self._main_loop:
self._main_loop.call_soon_threadsafe(
self._callback_future.set_exception, Exception("No authorization code received")
)
self._set_callback_exception(Exception(f"OAuth error: {error} - {error_description}"))
return HTMLResponse(
content=f"<html><body><h2>Authentication Error</h2><p>{error}: {error_description}</p></body></html>", # noqa: E501
status_code=400,
)
if not code:
self._set_callback_exception(Exception("No authorization code received"))

Copilot uses AI. Check for mistakes.
return HTMLResponse(
content="<html><body><h2>Error</h2><p>No authorization code received</p></body></html>",
status_code=400,
)

# Store the callback data
self._callback_code = code
self._callback_state = state

# Signal that callback was received
if self._callback_future and not self._callback_future.done():
# Schedule the future resolution in the main event loop
if self._main_loop:
self._main_loop.call_soon_threadsafe(self._callback_future.set_result, (code, state))
logging.info("OAuth2 callback handled successfully")
else:
logging.error("Main event loop is not set; cannot set callback future result")
else:
logging.error("Callback future is not set or already done; cannot set result")

return HTMLResponse(
content="""
<html>
<body>
<h2>Authorization Successful</h2>
<p>You have successfully authorized the MCP client. You can now close this window.</p>
</body>
</html>
"""
)

async def start(self):
"""Start the auth server in a background thread."""
logging.info(f"Starting auth server at http://{self.host}:{self.port}")
if self.server_thread and self.server_thread.is_alive():
logging.info("Auth server is already running")
return # Already running

config = uvicorn.Config(
app=self.app,
host=self.host,
port=self.port,
log_level="warning", # Reduce log noise
access_log=False,
)
self.server = uvicorn.Server(config)

# Start server in a separate thread to avoid blocking
self.server_thread = threading.Thread(target=self._run_server, daemon=True)
self.server_thread.start()

# Wait a bit for the server to start
await asyncio.sleep(0.1)

def _run_server(self):
"""Run the server in the thread."""
if self.server is not None:
asyncio.run(self.server.serve())

async def handle_callback(self) -> tuple[str, str | None]:
"""
Wait for and handle the OAuth2 callback.

Returns:
tuple: (code, state) from the OAuth2 callback

Raises:
Exception: If there's an error in the OAuth2 flow
"""
# Capture the current event loop
self._main_loop = asyncio.get_running_loop()

# Create a future to wait for the callback
self._callback_future = asyncio.Future()

try:
# Wait for the callback to be received
logging.info("Waiting for OAuth2 callback...")
code, state = await self._callback_future
logging.info(f"Received callback with code: {code}, state: {state}")
return code, state
finally:
# Clean up
logging.info("Cleaning up auth server state after callback")
self._callback_future = None
self._main_loop = None

async def stop(self):
"""Stop the auth server and clean up state."""
if self.server:
self.server.should_exit = True

if self.server_thread and self.server_thread.is_alive():
# Wait a bit for graceful shutdown
self.server_thread.join(timeout=1.0)

# Clear state
self._callback_code = None
self._callback_state = None
self._callback_future = None
self._main_loop = None
self.server = None
self.server_thread = None

logging.info("Auth server stopped")

@property
def callback_url(self) -> str:
"""Get the callback URL for this auth server."""
return f"http://{self.host}:{self.port}/callback"

async def __aenter__(self):
"""Async context manager entry."""
await self.start()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.stop()
return None
159 changes: 84 additions & 75 deletions commands/create_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys

from engine import TemplateEngine, TemplateParseError
from mcp_client import get_session_manager
from prompt import PromptHelper
from state import State
from util.parse import parse_input_string
Expand All @@ -14,84 +15,92 @@ async def handle_create_spec(
output_file: str | None = None,
var_overrides: list[str] | None = None,
):
prompt_helper = PromptHelper(state)

# Detect piped input (stdin not a TTY) and ensure there's data before using it
stdin_piped = not sys.stdin.isatty()

# Determine template source
template_engine: TemplateEngine
if spec_template:
# Resolve relative paths against current working directory
template_path = os.path.abspath(spec_template)

# Check if template file exists
if not os.path.exists(template_path):
logging.error(f"Error: Template file '{spec_template}' not found")
# Pre-initialize all MCP client sessions
session_manager = get_session_manager()
try:
await session_manager.initialize_all_sessions(state)
prompt_helper = PromptHelper(state)

# Detect piped input (stdin not a TTY) and ensure there's data before using it
stdin_piped = not sys.stdin.isatty()

# Determine template source
template_engine: TemplateEngine
if spec_template:
# Resolve relative paths against current working directory
template_path = os.path.abspath(spec_template)

# Check if template file exists
if not os.path.exists(template_path):
logging.error(f"Error: Template file '{spec_template}' not found")
sys.exit(1)

template_engine = TemplateEngine.from_file(template_path, state, prompt_helper)
elif stdin_piped:
try:
template_str = sys.stdin.read()
except Exception as e:
logging.error(f"Error: Failed to read from stdin: {e}")
sys.exit(1)
if not template_str:
logging.error("Error: No data received on stdin for template")
sys.exit(1)

template_engine = TemplateEngine.from_string(template_str, state, prompt_helper)
else:
logging.error("Error: Missing spec_template argument (or provide template via stdin)")
sys.exit(1)

template_engine = TemplateEngine.from_file(template_path, state, prompt_helper)
elif stdin_piped:
try:
template_str = sys.stdin.read()
except Exception as e:
logging.error(f"Error: Failed to read from stdin: {e}")
variables = template_engine.get_variables()

# Parse var_overrides into a dictionary
provided_vars = {}
if var_overrides:
for var_override in var_overrides:
if "=" not in var_override:
logging.error(f"Error: Invalid variable format '{var_override}'. Use KEY=VALUE format.")
sys.exit(1)
key, value = var_override.split("=", 1)
provided_vars[key] = value

# Collect values for each variable
collected_vars = {}
if variables:
logging.info("Collecting values for template variables:")
for var in sorted(variables):
if var in provided_vars:
raw_value = provided_vars[var]

else:
raw_value = await prompt_helper.collect_var_value(var)
logging.info(f" {var}: {raw_value}")

collected_vars[var] = parse_input_string(raw_value)

else:
logging.info("No variables found in template")

# Render the template with collected variables
rendered_content = await template_engine.render_async(**collected_vars)

# Output to file or stdout
if output_file:
output_path = os.path.abspath(output_file)
with open(output_path, "w") as f:
f.write(rendered_content)
logging.info(f"Rendered template saved to: {output_path}")
else:
logging.debug("\nRendered template:")
print(rendered_content)

except TemplateParseError as e:
logging.error(f"Error: {e}")
sys.exit(1)
if not template_str:
logging.error("Error: No data received on stdin for template")
except Exception as e:
logging.error(f"Error: Failed to process template: {e}")
sys.exit(1)

template_engine = TemplateEngine.from_string(template_str, state, prompt_helper)
else:
logging.error("Error: Missing spec_template argument (or provide template via stdin)")
sys.exit(1)

try:
variables = template_engine.get_variables()

# Parse var_overrides into a dictionary
provided_vars = {}
if var_overrides:
for var_override in var_overrides:
if "=" not in var_override:
logging.error(f"Error: Invalid variable format '{var_override}'. Use KEY=VALUE format.")
sys.exit(1)
key, value = var_override.split("=", 1)
provided_vars[key] = value

# Collect values for each variable
collected_vars = {}
if variables:
logging.info("Collecting values for template variables:")
for var in sorted(variables):
if var in provided_vars:
raw_value = provided_vars[var]

else:
raw_value = await prompt_helper.collect_var_value(var)
logging.info(f" {var}: {raw_value}")

collected_vars[var] = parse_input_string(raw_value)

else:
logging.info("No variables found in template")

# Render the template with collected variables
rendered_content = await template_engine.render_async(**collected_vars)

# Output to file or stdout
if output_file:
output_path = os.path.abspath(output_file)
with open(output_path, "w") as f:
f.write(rendered_content)
logging.info(f"Rendered template saved to: {output_path}")
else:
logging.debug("\nRendered template:")
print(rendered_content)

except TemplateParseError as e:
logging.error(f"Error: {e}")
sys.exit(1)
except Exception as e:
logging.error(f"Error: Failed to process template: {e}")
sys.exit(1)
finally:
# Clean up all MCP sessions
await session_manager.cleanup()
Loading