diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 9c2c87290..ba92a8765 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -24,11 +24,11 @@ import type { ReadResourceRequest, Request, RequestHandlerExtra, - RequestMethod, RequestOptions, RequestTypeMap, Result, ServerCapabilities, + ServerToClientRequestMethod, SubscribeRequest, Tool, Transport, @@ -235,7 +235,7 @@ export class Client< RequestT extends Request = Request, NotificationT extends Notification = Notification, ResultT extends Result = Result -> extends Protocol { +> extends Protocol { private _serverCapabilities?: ServerCapabilities; private _serverVersion?: Implementation; private _capabilities: ClientCapabilities; @@ -328,7 +328,7 @@ export class Client< /** * Override request handler registration to enforce client-side validation for elicitation. */ - public override setRequestHandler( + public override setRequestHandler( method: M, handler: ( request: RequestTypeMap[M], diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index f537aa86c..8cf7648a7 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -320,7 +320,12 @@ type TimeoutInfo = { * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. */ -export abstract class Protocol { +export abstract class Protocol< + SendRequestT extends Request, + SendNotificationT extends Notification, + SendResultT extends Result, + ReceiveRequestMethod extends RequestMethod = RequestMethod +> { private _transport?: Transport; private _requestMessageId = 0; private _requestHandlers: Map< @@ -375,7 +380,7 @@ export abstract class Protocol ({}) as SendResultT @@ -385,7 +390,7 @@ export abstract class Protocol { + this._setRequestHandlerInternal('tasks/get', async (request, extra) => { const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); if (!task) { throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); @@ -398,7 +403,7 @@ export abstract class Protocol { + this._setRequestHandlerInternal('tasks/result', async (request, extra) => { const handleTaskResult = async (): Promise => { const taskId = request.params.taskId; @@ -485,7 +490,7 @@ export abstract class Protocol { + this._setRequestHandlerInternal('tasks/list', async (request, extra) => { try { const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, extra.sessionId); return { @@ -501,7 +506,7 @@ export abstract class Protocol { + this._setRequestHandlerInternal('tasks/cancel', async (request, extra) => { try { // Get the current task to check if it's in a terminal state, in case the implementation is not atomic const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); @@ -1393,12 +1398,32 @@ export abstract class Protocol( + method: M, + handler: ( + request: RequestTypeMap[M], + extra: RequestHandlerExtra + ) => SendResultT | Promise + ): void { + const schema = getRequestSchema(method); + + this._requestHandlers.set(method, (request, extra) => { + const parsed = parseWithCompat(schema, request) as RequestTypeMap[M]; + return Promise.resolve(handler(parsed, extra)); + }); + } + /** * Registers a handler to invoke when this protocol object receives a request with the given method. * * Note that this will replace any previous request handler for the same method. */ - setRequestHandler( + setRequestHandler( method: M, handler: ( request: RequestTypeMap[M], @@ -1417,14 +1442,14 @@ export abstract class Protocol; export type NotificationTypeMap = MethodToTypeMap; +// Narrowed method types for Client and Server request handlers +export type ServerToClientRequestMethod = ServerRequest['method']; +export type ClientToServerRequestMethod = ClientRequest['method']; + /* Runtime schema lookup */ type RequestSchemaType = (typeof ClientRequestSchema.options)[number] | (typeof ServerRequestSchema.options)[number]; type NotificationSchemaType = (typeof ClientNotificationSchema.options)[number] | (typeof ServerNotificationSchema.options)[number]; diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index ca43272cf..bec489daa 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -1,5 +1,6 @@ import type { ClientCapabilities, + ClientToServerRequestMethod, CreateMessageRequest, CreateMessageRequestParamsBase, CreateMessageRequestParamsWithTools, @@ -21,7 +22,6 @@ import type { ProtocolOptions, Request, RequestHandlerExtra, - RequestMethod, RequestOptions, RequestTypeMap, ResourceUpdatedNotification, @@ -130,7 +130,7 @@ export class Server< RequestT extends Request = Request, NotificationT extends Notification = Notification, ResultT extends Result = Result -> extends Protocol { +> extends Protocol { private _clientCapabilities?: ClientCapabilities; private _clientVersion?: Implementation; private _capabilities: ServerCapabilities; @@ -215,7 +215,7 @@ export class Server< /** * Override request handler registration to enforce server-side validation for tools/call. */ - public override setRequestHandler( + public override setRequestHandler( method: M, handler: ( request: RequestTypeMap[M],