Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import logging
import os
import sys
from datetime import timedelta
from urllib.parse import ParseResult, parse_qs, urlparse

import httpx
Expand Down Expand Up @@ -263,8 +262,8 @@ async def _run_session(server_url: str, oauth_auth: OAuthClientProvider) -> None
async with streamablehttp_client(
url=server_url,
auth=oauth_auth,
timeout=timedelta(seconds=30),
sse_read_timeout=timedelta(seconds=60),
timeout=30.0,
sse_read_timeout=60.0,
) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ async def _default_redirect_handler(authorization_url: str) -> None:
async with sse_client(
url=self.server_url,
auth=oauth_auth,
timeout=60,
timeout=60.0,
) as (read_stream, write_stream):
await self._run_session(read_stream, write_stream, None)
else:
Expand Down
5 changes: 2 additions & 3 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from datetime import timedelta
from typing import Any, Protocol, overload

import anyio.lowlevel
Expand Down Expand Up @@ -113,7 +112,7 @@ def __init__(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
Expand Down Expand Up @@ -369,7 +368,7 @@ async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
Expand Down
29 changes: 14 additions & 15 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from types import TracebackType
from typing import Any, TypeAlias, overload

Expand Down Expand Up @@ -41,11 +40,11 @@ class SseServerParameters(BaseModel):
# Optional headers to include in requests.
headers: dict[str, Any] | None = None

# HTTP timeout for regular operations.
timeout: float = 5
# HTTP timeout for regular operations (in seconds).
timeout: float = 5.0

# Timeout for SSE read operations.
sse_read_timeout: float = 60 * 5
# Timeout for SSE read operations (in seconds).
sse_read_timeout: float = 300.0


class StreamableHttpParameters(BaseModel):
Expand All @@ -57,11 +56,11 @@ class StreamableHttpParameters(BaseModel):
# Optional headers to include in requests.
headers: dict[str, Any] | None = None

# HTTP timeout for regular operations.
timeout: timedelta = timedelta(seconds=30)
# HTTP timeout for regular operations (in seconds).
timeout: float = 30.0

# Timeout for SSE read operations.
sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
# Timeout for SSE read operations (in seconds).
sse_read_timeout: float = 300.0

# Close the client session when the transport closes.
terminate_on_close: bool = True
Expand All @@ -76,7 +75,7 @@ class StreamableHttpParameters(BaseModel):
class ClientSessionParameters:
"""Parameters for establishing a client session to an MCP server."""

read_timeout_seconds: timedelta | None = None
read_timeout_seconds: float | None = None
sampling_callback: SamplingFnT | None = None
elicitation_callback: ElicitationFnT | None = None
list_roots_callback: ListRootsFnT | None = None
Expand Down Expand Up @@ -197,7 +196,7 @@ async def call_tool(
self,
name: str,
arguments: dict[str, Any],
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
Expand All @@ -210,7 +209,7 @@ async def call_tool(
name: str,
*,
args: dict[str, Any],
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
meta: dict[str, Any] | None = None,
) -> types.CallToolResult: ...
Expand All @@ -219,7 +218,7 @@ async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
Expand Down Expand Up @@ -314,8 +313,8 @@ async def _establish_session(
httpx_client = create_mcp_http_client(
headers=server_params.headers,
timeout=httpx.Timeout(
server_params.timeout.total_seconds(),
read=server_params.sse_read_timeout.total_seconds(),
server_params.timeout,
read=server_params.sse_read_timeout,
),
)
await session_stack.enter_async_context(httpx_client)
Expand Down
8 changes: 4 additions & 4 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
timeout: float = 5.0,
sse_read_timeout: float = 300.0,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
on_session_created: Callable[[str], None] | None = None,
Expand All @@ -46,8 +46,8 @@ async def sse_client(
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
timeout: HTTP timeout for regular operations (in seconds).
sse_read_timeout: Timeout for SSE read operations (in seconds).
auth: Optional HTTPX authentication handler.
on_session_created: Optional callback invoked with the session ID when received.
"""
Expand Down
8 changes: 4 additions & 4 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def __init__(
self,
url: str,
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
timeout: float = 30.0,
sse_read_timeout: float = 300.0,
auth: httpx.Auth | None = None,
) -> None: ...

Expand All @@ -118,8 +118,8 @@ def __init__(
Args:
url: The endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
timeout: HTTP timeout for regular operations (in seconds).
sse_read_timeout: Timeout for SSE read operations (in seconds).
auth: Optional HTTPX authentication handler.
"""
# Check for deprecated parameters and issue runtime warning
Expand Down
3 changes: 1 addition & 2 deletions src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Any

import anyio
Expand Down Expand Up @@ -49,7 +48,7 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS
@asynccontextmanager
async def create_connected_server_and_client_session(
server: Server[Any] | FastMCP,
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
Expand Down
9 changes: 4 additions & 5 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from collections.abc import Callable
from contextlib import AsyncExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Protocol, TypeVar

Expand Down Expand Up @@ -189,7 +188,7 @@ def __init__(
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
read_timeout_seconds: float | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
Expand Down Expand Up @@ -241,7 +240,7 @@ async def send_request(
self,
request: SendRequestT,
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
request_read_timeout_seconds: float | None = None,
metadata: MessageMetadata = None,
progress_callback: ProgressFnT | None = None,
) -> ReceiveResultT:
Expand Down Expand Up @@ -283,9 +282,9 @@ async def send_request(
# request read timeout takes precedence over session read timeout
timeout = None
if request_read_timeout_seconds is not None: # pragma: no cover
timeout = request_read_timeout_seconds.total_seconds()
timeout = request_read_timeout_seconds
elif self._session_read_timeout_seconds is not None: # pragma: no cover
timeout = self._session_read_timeout_seconds.total_seconds()
timeout = self._session_read_timeout_seconds

try:
with anyio.fail_after(timeout):
Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ async def test_disconnect_non_existent_server(self):
"mcp.client.session_group.mcp.stdio_client",
),
(
SseServerParameters(url="http://test.com/sse", timeout=10),
SseServerParameters(url="http://test.com/sse", timeout=10.0),
"sse",
"mcp.client.session_group.sse_client",
), # url, headers, timeout, sse_read_timeout
Expand Down
7 changes: 2 additions & 5 deletions tests/issues/test_88_random_error.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Test to reproduce issue #88: Random error thrown on response."""

from collections.abc import Sequence
from datetime import timedelta
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -93,11 +92,9 @@ async def client(
assert not slow_request_lock.is_set()

# Second call should timeout (slow operation with minimal timeout)
# Use 10ms timeout to trigger quickly without waiting
# Use very small timeout to trigger quickly without waiting
with pytest.raises(McpError) as exc_info:
await session.call_tool(
"slow", read_timeout_seconds=timedelta(microseconds=1)
) # artificial timeout that always fails
await session.call_tool("slow", read_timeout_seconds=0.000001) # artificial timeout that always fails
assert "Timed out while waiting" in str(exc_info.value)

# release the slow request not to have hanging process
Expand Down
4 changes: 1 addition & 3 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,10 @@ async def mock_server():
async def make_request(client_session: ClientSession):
try:
# Use a short timeout since we expect this to fail
from datetime import timedelta

await client_session.send_request(
ClientRequest(types.PingRequest()),
types.EmptyResult,
request_read_timeout_seconds=timedelta(seconds=0.5),
request_read_timeout_seconds=0.5,
)
pytest.fail("Expected timeout") # pragma: no cover
except McpError as e:
Expand Down
5 changes: 2 additions & 3 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import socket
import time
from collections.abc import Generator
from datetime import timedelta
from typing import Any
from unittest.mock import MagicMock

Expand Down Expand Up @@ -2370,8 +2369,8 @@ async def test_streamable_http_transport_deprecated_params_ignored(basic_server:
transport = StreamableHTTPTransport( # pyright: ignore[reportDeprecated]
url=f"{basic_server_url}/mcp",
headers={"X-Should-Be-Ignored": "ignored"},
timeout=999,
sse_read_timeout=timedelta(seconds=999),
timeout=999.0,
sse_read_timeout=999.0,
auth=None,
)

Expand Down