Skip to content

Commit 35a31ed

Browse files
committed
notifications and client side
1 parent 341ad92 commit 35a31ed

File tree

16 files changed

+3332
-6
lines changed

16 files changed

+3332
-6
lines changed

src/mcp/client/session.py

Lines changed: 211 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,95 @@ async def __call__(
4949
) -> None: ... # pragma: no branch
5050

5151

52+
# Experimental: Task handler protocols for server -> client requests
53+
class GetTaskHandlerFnT(Protocol):
54+
"""Handler for tasks/get requests from server.
55+
56+
WARNING: This is experimental and may change without notice.
57+
"""
58+
59+
async def __call__(
60+
self,
61+
context: RequestContext["ClientSession", Any],
62+
params: types.GetTaskRequestParams,
63+
) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch
64+
65+
66+
class GetTaskResultHandlerFnT(Protocol):
67+
"""Handler for tasks/result requests from server.
68+
69+
WARNING: This is experimental and may change without notice.
70+
"""
71+
72+
async def __call__(
73+
self,
74+
context: RequestContext["ClientSession", Any],
75+
params: types.GetTaskPayloadRequestParams,
76+
) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch
77+
78+
79+
class ListTasksHandlerFnT(Protocol):
80+
"""Handler for tasks/list requests from server.
81+
82+
WARNING: This is experimental and may change without notice.
83+
"""
84+
85+
async def __call__(
86+
self,
87+
context: RequestContext["ClientSession", Any],
88+
params: types.PaginatedRequestParams | None,
89+
) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch
90+
91+
92+
class CancelTaskHandlerFnT(Protocol):
93+
"""Handler for tasks/cancel requests from server.
94+
95+
WARNING: This is experimental and may change without notice.
96+
"""
97+
98+
async def __call__(
99+
self,
100+
context: RequestContext["ClientSession", Any],
101+
params: types.CancelTaskRequestParams,
102+
) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch
103+
104+
105+
class TaskAugmentedSamplingFnT(Protocol):
106+
"""Handler for task-augmented sampling/createMessage requests from server.
107+
108+
When server sends a CreateMessageRequest with task field, this callback
109+
is invoked. The callback should create a task, spawn background work,
110+
and return CreateTaskResult immediately.
111+
112+
WARNING: This is experimental and may change without notice.
113+
"""
114+
115+
async def __call__(
116+
self,
117+
context: RequestContext["ClientSession", Any],
118+
params: types.CreateMessageRequestParams,
119+
task_metadata: types.TaskMetadata,
120+
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
121+
122+
123+
class TaskAugmentedElicitationFnT(Protocol):
124+
"""Handler for task-augmented elicitation/create requests from server.
125+
126+
When server sends an ElicitRequest with task field, this callback
127+
is invoked. The callback should create a task, spawn background work,
128+
and return CreateTaskResult immediately.
129+
130+
WARNING: This is experimental and may change without notice.
131+
"""
132+
133+
async def __call__(
134+
self,
135+
context: RequestContext["ClientSession", Any],
136+
params: types.ElicitRequestParams,
137+
task_metadata: types.TaskMetadata,
138+
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
139+
140+
52141
class MessageHandlerFnT(Protocol):
53142
async def __call__(
54143
self,
@@ -97,6 +186,69 @@ async def _default_logging_callback(
97186
pass
98187

99188

189+
# Default handlers for experimental task requests (return "not supported" errors)
190+
async def _default_get_task_handler(
191+
context: RequestContext["ClientSession", Any],
192+
params: types.GetTaskRequestParams,
193+
) -> types.GetTaskResult | types.ErrorData:
194+
return types.ErrorData(
195+
code=types.METHOD_NOT_FOUND,
196+
message="tasks/get not supported",
197+
)
198+
199+
200+
async def _default_get_task_result_handler(
201+
context: RequestContext["ClientSession", Any],
202+
params: types.GetTaskPayloadRequestParams,
203+
) -> types.GetTaskPayloadResult | types.ErrorData:
204+
return types.ErrorData(
205+
code=types.METHOD_NOT_FOUND,
206+
message="tasks/result not supported",
207+
)
208+
209+
210+
async def _default_list_tasks_handler(
211+
context: RequestContext["ClientSession", Any],
212+
params: types.PaginatedRequestParams | None,
213+
) -> types.ListTasksResult | types.ErrorData:
214+
return types.ErrorData(
215+
code=types.METHOD_NOT_FOUND,
216+
message="tasks/list not supported",
217+
)
218+
219+
220+
async def _default_cancel_task_handler(
221+
context: RequestContext["ClientSession", Any],
222+
params: types.CancelTaskRequestParams,
223+
) -> types.CancelTaskResult | types.ErrorData:
224+
return types.ErrorData(
225+
code=types.METHOD_NOT_FOUND,
226+
message="tasks/cancel not supported",
227+
)
228+
229+
230+
async def _default_task_augmented_sampling_callback(
231+
context: RequestContext["ClientSession", Any],
232+
params: types.CreateMessageRequestParams,
233+
task_metadata: types.TaskMetadata,
234+
) -> types.CreateTaskResult | types.ErrorData:
235+
return types.ErrorData(
236+
code=types.INVALID_REQUEST,
237+
message="Task-augmented sampling not supported",
238+
)
239+
240+
241+
async def _default_task_augmented_elicitation_callback(
242+
context: RequestContext["ClientSession", Any],
243+
params: types.ElicitRequestParams,
244+
task_metadata: types.TaskMetadata,
245+
) -> types.CreateTaskResult | types.ErrorData:
246+
return types.ErrorData(
247+
code=types.INVALID_REQUEST,
248+
message="Task-augmented elicitation not supported",
249+
)
250+
251+
100252
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
101253

102254

@@ -120,6 +272,14 @@ def __init__(
120272
logging_callback: LoggingFnT | None = None,
121273
message_handler: MessageHandlerFnT | None = None,
122274
client_info: types.Implementation | None = None,
275+
tasks_capability: types.ClientTasksCapability | None = None,
276+
# Experimental: Task handlers for server -> client requests
277+
get_task_handler: GetTaskHandlerFnT | None = None,
278+
get_task_result_handler: GetTaskResultHandlerFnT | None = None,
279+
list_tasks_handler: ListTasksHandlerFnT | None = None,
280+
cancel_task_handler: CancelTaskHandlerFnT | None = None,
281+
task_augmented_sampling_callback: TaskAugmentedSamplingFnT | None = None,
282+
task_augmented_elicitation_callback: TaskAugmentedElicitationFnT | None = None,
123283
) -> None:
124284
super().__init__(
125285
read_stream,
@@ -134,9 +294,21 @@ def __init__(
134294
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
135295
self._logging_callback = logging_callback or _default_logging_callback
136296
self._message_handler = message_handler or _default_message_handler
297+
self._tasks_capability = tasks_capability
137298
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
138299
self._server_capabilities: types.ServerCapabilities | None = None
139300
self._experimental: ExperimentalClientFeatures | None = None
301+
# Experimental: Task handlers
302+
self._get_task_handler = get_task_handler or _default_get_task_handler
303+
self._get_task_result_handler = get_task_result_handler or _default_get_task_result_handler
304+
self._list_tasks_handler = list_tasks_handler or _default_list_tasks_handler
305+
self._cancel_task_handler = cancel_task_handler or _default_cancel_task_handler
306+
self._task_augmented_sampling_callback = (
307+
task_augmented_sampling_callback or _default_task_augmented_sampling_callback
308+
)
309+
self._task_augmented_elicitation_callback = (
310+
task_augmented_elicitation_callback or _default_task_augmented_elicitation_callback
311+
)
140312

141313
async def initialize(self) -> types.InitializeResult:
142314
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -162,6 +334,7 @@ async def initialize(self) -> types.InitializeResult:
162334
elicitation=elicitation,
163335
experimental=None,
164336
roots=roots,
337+
tasks=self._tasks_capability,
165338
),
166339
clientInfo=self._client_info,
167340
),
@@ -187,7 +360,7 @@ def get_server_capabilities(self) -> types.ServerCapabilities | None:
187360
return self._server_capabilities
188361

189362
@property
190-
def experimental(self) -> "ExperimentalClientFeatures":
363+
def experimental(self) -> ExperimentalClientFeatures:
191364
"""Experimental APIs for tasks and other features.
192365
193366
WARNING: These APIs are experimental and may change without notice.
@@ -534,13 +707,21 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
534707
match responder.request.root:
535708
case types.CreateMessageRequest(params=params):
536709
with responder:
537-
response = await self._sampling_callback(ctx, params)
710+
# Check if this is a task-augmented request
711+
if params.task is not None:
712+
response = await self._task_augmented_sampling_callback(ctx, params, params.task)
713+
else:
714+
response = await self._sampling_callback(ctx, params)
538715
client_response = ClientResponse.validate_python(response)
539716
await responder.respond(client_response)
540717

541718
case types.ElicitRequest(params=params):
542719
with responder:
543-
response = await self._elicitation_callback(ctx, params)
720+
# Check if this is a task-augmented request
721+
if params.task is not None:
722+
response = await self._task_augmented_elicitation_callback(ctx, params, params.task)
723+
else:
724+
response = await self._elicitation_callback(ctx, params)
544725
client_response = ClientResponse.validate_python(response)
545726
await responder.respond(client_response)
546727

@@ -553,7 +734,33 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
553734
case types.PingRequest(): # pragma: no cover
554735
with responder:
555736
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
556-
case _:
737+
738+
# Experimental: Task management requests from server
739+
case types.GetTaskRequest(params=params):
740+
with responder:
741+
response = await self._get_task_handler(ctx, params)
742+
client_response = ClientResponse.validate_python(response)
743+
await responder.respond(client_response)
744+
745+
case types.GetTaskPayloadRequest(params=params):
746+
with responder:
747+
response = await self._get_task_result_handler(ctx, params)
748+
client_response = ClientResponse.validate_python(response)
749+
await responder.respond(client_response)
750+
751+
case types.ListTasksRequest(params=params):
752+
with responder:
753+
response = await self._list_tasks_handler(ctx, params)
754+
client_response = ClientResponse.validate_python(response)
755+
await responder.respond(client_response)
756+
757+
case types.CancelTaskRequest(params=params):
758+
with responder:
759+
response = await self._cancel_task_handler(ctx, params)
760+
client_response = ClientResponse.validate_python(response)
761+
await responder.respond(client_response)
762+
763+
case _: # pragma: no cover
557764
raise NotImplementedError()
558765

559766
async def _handle_incoming(

src/mcp/shared/experimental/tasks/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
- TaskStore: Abstract interface for task state storage
66
- TaskContext: Context object for task work to interact with state/notifications
77
- InMemoryTaskStore: Reference implementation for testing/development
8+
- TaskMessageQueue: FIFO queue for task messages delivered via tasks/result
9+
- InMemoryTaskMessageQueue: Reference implementation for message queue
810
- Helper functions: run_task, is_terminal, create_task_state, generate_task_id
911
1012
Architecture:
1113
- TaskStore is pure storage - it doesn't know about execution
14+
- TaskMessageQueue stores messages to be delivered via tasks/result
1215
- TaskContext wraps store + session, providing a clean API for task work
1316
- run_task is optional convenience for spawning in-process tasks
1417
@@ -24,15 +27,31 @@
2427
task_execution,
2528
)
2629
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
30+
from mcp.shared.experimental.tasks.message_queue import (
31+
InMemoryTaskMessageQueue,
32+
QueuedMessage,
33+
TaskMessageQueue,
34+
)
35+
from mcp.shared.experimental.tasks.result_handler import (
36+
TaskResultHandler,
37+
create_task_result_handler,
38+
)
2739
from mcp.shared.experimental.tasks.store import TaskStore
40+
from mcp.shared.experimental.tasks.task_session import TaskSession
2841

2942
__all__ = [
3043
"TaskStore",
3144
"TaskContext",
45+
"TaskSession",
46+
"TaskResultHandler",
3247
"InMemoryTaskStore",
48+
"TaskMessageQueue",
49+
"InMemoryTaskMessageQueue",
50+
"QueuedMessage",
3351
"run_task",
3452
"task_execution",
3553
"is_terminal",
3654
"create_task_state",
3755
"generate_task_id",
56+
"create_task_result_handler",
3857
]

src/mcp/shared/experimental/tasks/in_memory_task_store.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
For production, consider implementing TaskStore with a database or distributed cache.
99
"""
1010

11+
import asyncio
1112
from dataclasses import dataclass, field
1213
from datetime import datetime, timedelta, timezone
1314

@@ -46,6 +47,7 @@ class InMemoryTaskStore(TaskStore):
4647
def __init__(self, page_size: int = 10) -> None:
4748
self._tasks: dict[str, StoredTask] = {}
4849
self._page_size = page_size
50+
self._update_events: dict[str, asyncio.Event] = {}
4951

5052
def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
5153
"""Calculate expiry time from TTL in milliseconds."""
@@ -111,8 +113,10 @@ async def update_task(
111113
if stored is None:
112114
raise ValueError(f"Task with ID {task_id} not found")
113115

114-
if status is not None:
116+
status_changed = False
117+
if status is not None and stored.task.status != status:
115118
stored.task.status = status
119+
status_changed = True
116120

117121
if status_message is not None:
118122
stored.task.statusMessage = status_message
@@ -121,6 +125,10 @@ async def update_task(
121125
if status is not None and is_terminal(status) and stored.task.ttl is not None:
122126
stored.expires_at = self._calculate_expiry(stored.task.ttl)
123127

128+
# Notify waiters if status changed
129+
if status_changed:
130+
await self.notify_update(task_id)
131+
124132
return Task(**stored.task.model_dump())
125133

126134
async def store_result(self, task_id: str, result: Result) -> None:
@@ -175,11 +183,31 @@ async def delete_task(self, task_id: str) -> bool:
175183
del self._tasks[task_id]
176184
return True
177185

186+
async def wait_for_update(self, task_id: str) -> None:
187+
"""Wait until the task status changes."""
188+
if task_id not in self._tasks:
189+
raise ValueError(f"Task with ID {task_id} not found")
190+
191+
# Get or create the event for this task
192+
if task_id not in self._update_events:
193+
self._update_events[task_id] = asyncio.Event()
194+
195+
event = self._update_events[task_id]
196+
# Clear before waiting so we wait for NEW updates
197+
event.clear()
198+
await event.wait()
199+
200+
async def notify_update(self, task_id: str) -> None:
201+
"""Signal that a task has been updated."""
202+
if task_id in self._update_events:
203+
self._update_events[task_id].set()
204+
178205
# --- Testing/debugging helpers ---
179206

180207
def cleanup(self) -> None:
181208
"""Cleanup all tasks (useful for testing or graceful shutdown)."""
182209
self._tasks.clear()
210+
self._update_events.clear()
183211

184212
def get_all_tasks(self) -> list[Task]:
185213
"""Get all tasks (useful for debugging). Returns copies to prevent modification."""

0 commit comments

Comments
 (0)