Skip to content
64 changes: 34 additions & 30 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from pydantic import ValidationError

from mcp.client._transport import TransportStreams
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
INVALID_REQUEST,
PARSE_ERROR,
ErrorData,
InitializeResult,
JSONRPCError,
Expand Down Expand Up @@ -163,6 +166,11 @@ async def _handle_sse_event(

except Exception as exc: # pragma: no cover
logger.exception("Error parsing SSE message")
if original_request_id is not None:
error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse SSE message: {exc}")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
await read_stream_writer.send(error_msg)
return True
await read_stream_writer.send(exc)
return False
else: # pragma: no cover
Expand Down Expand Up @@ -260,7 +268,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:

if response.status_code == 404: # pragma: no branch
if isinstance(message, JSONRPCRequest): # pragma: no branch
await self._send_session_terminated_error(ctx.read_stream_writer, message.id)
error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated")
session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
await ctx.read_stream_writer.send(session_message)
return

response.raise_for_status()
Expand All @@ -272,20 +282,24 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
if isinstance(message, JSONRPCRequest):
content_type = response.headers.get("content-type", "").lower()
if content_type.startswith("application/json"):
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
await self._handle_json_response(
response, ctx.read_stream_writer, is_initialization, request_id=message.id
)
elif content_type.startswith("text/event-stream"):
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type( # pragma: no cover
content_type, # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
) # pragma: no cover
logger.error(f"Unexpected content type: {content_type}")
error_data = ErrorData(code=INVALID_REQUEST, message=f"Unexpected content type: {content_type}")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
await ctx.read_stream_writer.send(error_msg)

async def _handle_json_response(
self,
response: httpx.Response,
read_stream_writer: StreamWriter,
is_initialization: bool = False,
*,
request_id: RequestId,
) -> None:
"""Handle JSON response from the server."""
try:
Expand All @@ -298,9 +312,11 @@ async def _handle_json_response(

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except Exception as exc: # pragma: no cover
except (httpx.StreamError, ValidationError) as exc:
logger.exception("Error parsing JSON response")
await read_stream_writer.send(exc)
error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse JSON response: {exc}")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data))
await read_stream_writer.send(error_msg)

async def _handle_sse_response(
self,
Expand All @@ -312,6 +328,11 @@ async def _handle_sse_response(
last_event_id: str | None = None
retry_interval_ms: int | None = None

# The caller (_handle_post_request) only reaches here inside
# isinstance(message, JSONRPCRequest), so this is always a JSONRPCRequest.
assert isinstance(ctx.session_message.message, JSONRPCRequest)
original_request_id = ctx.session_message.message.id

try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse(): # pragma: no branch
Expand All @@ -326,6 +347,7 @@ async def _handle_sse_response(
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id=original_request_id,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
Expand All @@ -334,8 +356,8 @@ async def _handle_sse_response(
if is_complete:
await response.aclose()
return # Normal completion, no reconnect needed
except Exception as e:
logger.debug(f"SSE stream ended: {e}") # pragma: no cover
except Exception:
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover

# Stream ended without response - reconnect if we received an event with ID
if last_event_id is not None: # pragma: no branch
Expand Down Expand Up @@ -400,24 +422,6 @@ async def _handle_reconnection(
# Try to reconnect again if we still have an event ID
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)

async def _handle_unexpected_content_type(
self, content_type: str, read_stream_writer: StreamWriter
) -> None: # pragma: no cover
"""Handle unexpected content type in response."""
error_msg = f"Unexpected content type: {content_type}" # pragma: no cover
logger.error(error_msg) # pragma: no cover
await read_stream_writer.send(ValueError(error_msg)) # pragma: no cover

async def _send_session_terminated_error(self, read_stream_writer: StreamWriter, request_id: RequestId) -> None:
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=ErrorData(code=32600, message="Session terminated"),
)
session_message = SessionMessage(jsonrpc_error)
await read_stream_writer.send(session_message)

async def post_writer(
self,
client: httpx.AsyncClient,
Expand Down Expand Up @@ -467,8 +471,8 @@ async def handle_request_async():
else:
await handle_request_async()

except Exception:
logger.exception("Error in post_writer") # pragma: no cover
except Exception: # pragma: lax no cover
logger.exception("Error in post_writer")
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/types/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class JSONRPCError(BaseModel):
"""A response to a request that indicates an error occurred."""

jsonrpc: Literal["2.0"]
id: str | int
id: RequestId
error: ErrorData


Expand Down
Loading