Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:

async def handle_message(session_message: SessionMessage) -> None:
logger.debug(f"Sending client message: {session_message}")
response = await client.post(
endpoint_url,
Expand All @@ -143,6 +144,13 @@ async def post_writer(endpoint_url: str):
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")

async for session_message in write_stream_reader:
async with anyio.create_task_group() as tg_local:
session_message.context.run(
tg_local.start_soon, handle_message, session_message
)

except Exception: # pragma: lax no cover
logger.exception("Error in post_writer")
finally:
Expand Down
11 changes: 8 additions & 3 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,8 @@ async def post_writer(
"""Handle writing requests to the server."""
try:
async with write_stream_reader:
async for session_message in write_stream_reader:

async def handle_message(session_message: SessionMessage) -> None:
message = session_message.message
metadata = (
session_message.metadata
Expand Down Expand Up @@ -471,8 +472,12 @@ async def handle_request_async():
else:
await handle_request_async()

except Exception: # pragma: lax no cover
logger.exception("Error in post_writer")
async for session_message in write_stream_reader:
async with anyio.create_task_group() as tg_local:
session_message.context.run(tg_local.start_soon, handle_message, session_message)

except Exception:
logger.exception("Error in post_writer") # pragma: no cover
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
Expand Down
9 changes: 8 additions & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,14 @@ async def run(
async for message in session.incoming_messages:
logger.debug("Received message: %s", message)

tg.start_soon(
if isinstance(message, RequestResponder) and message.context is not None:
logger.debug("Got a context to propagate, %s", message.context)
context = message.context
else:
context = contextvars.copy_context()

context.run(
tg.start_soon,
self._handle_message,
message,
session,
Expand Down
4 changes: 3 additions & 1 deletion src/mcp/shared/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
to support transport-specific features like resumability.
"""

import contextvars
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from dataclasses import dataclass, field

from mcp.types import JSONRPCMessage, RequestId

Expand Down Expand Up @@ -46,4 +47,5 @@ class SessionMessage:
"""A message with specific metadata for transport-specific features."""

message: JSONRPCMessage
context: contextvars.Context = field(default_factory=contextvars.copy_context)
metadata: MessageMetadata = None
18 changes: 14 additions & 4 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextvars
import logging
from collections.abc import Callable
from contextlib import AsyncExitStack
Expand Down Expand Up @@ -77,11 +78,13 @@ def __init__(
session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any],
message_metadata: MessageMetadata = None,
context: contextvars.Context | None = None,
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self.message_metadata = message_metadata
self.context = context
self._session = session
self._completed = False
self._cancel_scope = anyio.CancelScope()
Expand Down Expand Up @@ -330,10 +333,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
async def _receive_loop(self) -> None:
async with self._read_stream, self._write_stream:
try:
async for message in self._read_stream:
if isinstance(message, Exception): # pragma: no cover
await self._handle_incoming(message)
elif isinstance(message.message, JSONRPCRequest):

async def handle_message(message: SessionMessage) -> None:
if isinstance(message.message, JSONRPCRequest):
try:
validated_request = self._receive_request_adapter.validate_python(
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
Expand All @@ -346,6 +348,7 @@ async def _receive_loop(self) -> None:
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
context=message.context,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)
Expand Down Expand Up @@ -403,6 +406,13 @@ async def _receive_loop(self) -> None:
else: # Response or error
await self._handle_response(message)

async for message in self._read_stream:
if isinstance(message, Exception): # pragma: no cover
await self._handle_incoming(message)
else:
async with anyio.create_task_group() as tg:
message.context.run(tg.start_soon, handle_message, message)

except anyio.ClosedResourceError:
# This is expected when the client disconnects abruptly.
# Without this handler, the exception would propagate up and
Expand Down
186 changes: 186 additions & 0 deletions tests/test_context_propagation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import contextvars
import multiprocessing
import socket
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Literal

import httpx
import pytest
import uvicorn
from inline_snapshot import snapshot
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response

import mcp.types as types
from mcp import Client
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamable_http_client
from mcp.server import MCPServer
from tests.test_helpers import wait_for_server

TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial")


@contextmanager
def set_test_contextvar(value: str) -> Iterator[None]:
token = TEST_CONTEXTVAR.set(value)
try:
yield
finally:
TEST_CONTEXTVAR.reset(token)


# Sends header CLIENT_HEADER with a configured value
class SendClientHeaderTransport(httpx.AsyncHTTPTransport):
def __init__(self) -> None:
super().__init__()
self.client_header_value: str = "initial"

async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
request.headers["CLIENT_HEADER"] = self.client_header_value
return await super().handle_async_request(request)


# Intercepts the httpx call to capture the contextvar's value
class ContextCapturingTransport(httpx.AsyncHTTPTransport):
def __init__(self):
super().__init__()
self.captured_context_var: str | None = None

async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
self.captured_context_var = TEST_CONTEXTVAR.get()
return await super().handle_async_request(request)


def create_server() -> MCPServer:
mcp = MCPServer("test_server")

# tool that returns the value of TEST_CONTEXT_VAR.
@mcp.tool()
async def my_tool() -> str:
return TEST_CONTEXTVAR.get()

return mcp


@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


def run_server(transport: Literal["sse", "streamable_http"], port: int): # pragma: no cover
class ContextVarMiddleware(BaseHTTPMiddleware): # pragma: lax no cover
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
actual_value = request.headers.get("CLIENT_HEADER")
with set_test_contextvar(f"from middleware CLIENT_HEADER={actual_value}"):
return await call_next(request)

server = create_server()

match transport:
case "sse":
app = server.sse_app(host="127.0.0.1")
case "streamable_http":
app = server.streamable_http_app(host="127.0.0.1")

app.add_middleware(ContextVarMiddleware)

uvicorn.run(app, host="127.0.0.1", port=port, log_level="error")


@contextmanager
def start_server_process(transport: Literal["sse", "streamable_http"], port: int):
"""Start server in a separate process."""
process = multiprocessing.Process(target=run_server, args=(transport, port))

process.start()
try:
wait_for_server(port)
yield process
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_memory_transport_client_to_server():
async with Client(create_server()) as client:
with set_test_contextvar("client_value"):
result = await client.call_tool(name="my_tool")

assert isinstance(result, types.CallToolResult)
assert result.content == snapshot([types.TextContent(text="client_value")])


@pytest.mark.anyio
async def test_streamable_http_asgi_to_mcpserver(server_port: int):
with start_server_process("streamable_http", server_port):
async with (
SendClientHeaderTransport() as transport,
httpx.AsyncClient(transport=transport) as http_client,
Client(streamable_http_client(f"http://127.0.0.1:{server_port}/mcp", http_client=http_client)) as client,
):
transport.client_header_value = "expected_value"
result = await client.call_tool("my_tool")
assert result.content == snapshot([types.TextContent(text="from middleware CLIENT_HEADER=expected_value")])


@pytest.mark.anyio
async def test_streamable_http_mcpclient_to_httpx(server_port: int):
with start_server_process("streamable_http", server_port):
async with (
ContextCapturingTransport() as transport,
httpx.AsyncClient(transport=transport) as http_client,
Client(streamable_http_client(f"http://127.0.0.1:{server_port}/mcp", http_client=http_client)) as client,
):
with set_test_contextvar("client_value_list"):
await client.list_tools()
assert transport.captured_context_var == snapshot("client_value_list")

with set_test_contextvar("client_value_call_tool"): # pragma: lax no cover
await client.call_tool("my_tool")
assert transport.captured_context_var == snapshot("client_value_call_tool")


@pytest.mark.anyio
async def test_sse_asgi_to_mcpserver(server_port: int):
transport = SendClientHeaderTransport()

def client_factory(
headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None
) -> httpx.AsyncClient:
return httpx.AsyncClient(transport=transport, headers=headers, timeout=timeout, auth=auth)

with start_server_process("sse", server_port):
async with Client(
sse_client(f"http://127.0.0.1:{server_port}/sse", httpx_client_factory=client_factory)
) as client:
transport.client_header_value = "expected_value"
result = await client.call_tool("my_tool")
assert result.content == snapshot([types.TextContent(text="from middleware CLIENT_HEADER=expected_value")])


@pytest.mark.anyio
async def test_sse_mcpclient_to_httpx(server_port: int):
transport = ContextCapturingTransport()

def client_factory(
headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None
) -> httpx.AsyncClient:
return httpx.AsyncClient(transport=transport, headers=headers, timeout=timeout, auth=auth)

with start_server_process("sse", server_port):
async with Client(
sse_client(f"http://127.0.0.1:{server_port}/sse", httpx_client_factory=client_factory)
) as client:
with set_test_contextvar("client_value_list"):
await client.list_tools()
assert transport.captured_context_var == snapshot("client_value_list")

with set_test_contextvar("client_value_call_tool"): # pragma: lax no cover
await client.call_tool("my_tool")
assert transport.captured_context_var == snapshot("client_value_call_tool")
Loading
Loading