Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 65 additions & 44 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,47 +331,61 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
await event_source.response.aclose()
break

async def _send_error_response(self, ctx: RequestContext, error: Exception) -> None:
"""Send an error response to the client."""
error_data = ErrorData(code=32000, message=str(error))
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=ctx.session_message.message.root.id, error=error_data)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
await ctx.read_stream_writer.send(session_message)

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._prepare_headers()
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)

async with ctx.client.stream(
"POST",
self.url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
headers=headers,
) as response:
if response.status_code == 202:
logger.debug("Received 202 Accepted")
return
try:
async with ctx.client.stream(
"POST",
self.url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
headers=headers,
) as response:
if response.status_code == 202:
logger.debug("Received 202 Accepted")
return

if response.status_code == 404: # pragma: no branch
if isinstance(message.root, JSONRPCRequest):
await self._send_session_terminated_error( # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
message.root.id, # pragma: no cover
) # pragma: no cover
return # pragma: no cover
if response.status_code == 404: # pragma: no branch
if isinstance(message.root, JSONRPCRequest):
await self._send_session_terminated_error( # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
message.root.id, # pragma: no cover
) # pragma: no cover
return # pragma: no cover

response.raise_for_status()
response.raise_for_status()
if is_initialization:
self._maybe_extract_session_id_from_response(response)

# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = response.headers.get(CONTENT_TYPE, "").lower()
if content_type.startswith(JSON):
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
elif content_type.startswith(SSE):
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type( # pragma: no cover
content_type, # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
) # pragma: no cover
except Exception as exc:
if is_initialization:
self._maybe_extract_session_id_from_response(response)

# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = response.headers.get(CONTENT_TYPE, "").lower()
if content_type.startswith(JSON):
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
elif content_type.startswith(SSE):
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type( # pragma: no cover
content_type, # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
) # pragma: no cover
raise exc
else:
await self._send_error_response(ctx, exc)

async def _handle_json_response(
self,
Expand Down Expand Up @@ -406,7 +420,7 @@ async def _handle_sse_response(

try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse(): # pragma: no branch
async for sse in event_source.aiter_sse():
# Track last event ID for potential reconnection
if sse.id:
last_event_id = sse.id
Expand All @@ -426,13 +440,17 @@ async def _handle_sse_response(
if is_complete:
await response.aclose()
return # Normal completion, no reconnect needed
except Exception as e: # pragma: no cover
logger.debug(f"SSE stream ended: {e}")

# Stream ended without response - reconnect if we received an event with ID
if last_event_id is not None: # pragma: no branch
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
# Stream ended without response - try to reconnect if we have an event ID
if last_event_id is not None:
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
else:
# No event ID received, can't reconnect - report error
raise Exception("SSE stream ended without completing")
except Exception as exc:
logger.exception("Error handling SSE response")
await self._send_error_response(ctx, exc)

async def _handle_reconnection(
self,
Expand All @@ -441,11 +459,14 @@ async def _handle_reconnection(
retry_interval_ms: int | None = None,
attempt: int = 0,
) -> None:
"""Reconnect with Last-Event-ID to resume stream after server disconnect."""
"""Reconnect with Last-Event-ID to resume stream after server disconnect.

Raises:
Exception: If max reconnection attempts exceeded or reconnection fails.
"""
# Bail if max retries exceeded
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
return
if attempt >= MAX_RECONNECTION_ATTEMPTS:
raise Exception(f"SSE stream reconnection failed after {MAX_RECONNECTION_ATTEMPTS} attempts")

# Always wait - use server value or default
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
Expand Down Expand Up @@ -492,7 +513,7 @@ async def _handle_reconnection(
# Stream ended again without response - reconnect again (reset attempt counter)
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
except Exception as e: # pragma: no cover
except Exception as e:
logger.debug(f"Reconnection failed: {e}")
# Try to reconnect again if we still have an event ID
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
Expand Down
Loading