Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions packages/client/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ import type {
ReadResourceRequest,
Request,
RequestHandlerExtra,
RequestMethod,
RequestOptions,
RequestTypeMap,
Result,
ServerCapabilities,
ServerToClientRequestMethod,
SubscribeRequest,
Tool,
Transport,
Expand Down Expand Up @@ -235,7 +235,7 @@ export class Client<
RequestT extends Request = Request,
NotificationT extends Notification = Notification,
ResultT extends Result = Result
> extends Protocol<ClientRequest | RequestT, ClientNotification | NotificationT, ClientResult | ResultT> {
> extends Protocol<ClientRequest | RequestT, ClientNotification | NotificationT, ClientResult | ResultT, ServerToClientRequestMethod> {
private _serverCapabilities?: ServerCapabilities;
private _serverVersion?: Implementation;
private _capabilities: ClientCapabilities;
Expand Down Expand Up @@ -328,7 +328,7 @@ export class Client<
/**
* Override request handler registration to enforce client-side validation for elicitation.
*/
public override setRequestHandler<M extends RequestMethod>(
public override setRequestHandler<M extends ServerToClientRequestMethod>(
method: M,
handler: (
request: RequestTypeMap[M],
Expand Down
43 changes: 34 additions & 9 deletions packages/core/src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,12 @@ type TimeoutInfo = {
* Implements MCP protocol framing on top of a pluggable transport, including
* features like request/response linking, notifications, and progress.
*/
export abstract class Protocol<SendRequestT extends Request, SendNotificationT extends Notification, SendResultT extends Result> {
export abstract class Protocol<
SendRequestT extends Request,
SendNotificationT extends Notification,
SendResultT extends Result,
ReceiveRequestMethod extends RequestMethod = RequestMethod
> {
private _transport?: Transport;
private _requestMessageId = 0;
private _requestHandlers: Map<
Expand Down Expand Up @@ -375,7 +380,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
this._onprogress(notification);
});

this.setRequestHandler(
this._setRequestHandlerInternal(
'ping',
// Automatic pong by default.
_request => ({}) as SendResultT
Expand All @@ -385,7 +390,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
this._taskStore = _options?.taskStore;
this._taskMessageQueue = _options?.taskMessageQueue;
if (this._taskStore) {
this.setRequestHandler('tasks/get', async (request, extra) => {
this._setRequestHandlerInternal('tasks/get', async (request, extra) => {
const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId);
if (!task) {
throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found');
Expand All @@ -398,7 +403,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
} as unknown as SendResultT;
});

this.setRequestHandler('tasks/result', async (request, extra) => {
this._setRequestHandlerInternal('tasks/result', async (request, extra) => {
const handleTaskResult = async (): Promise<SendResultT> => {
const taskId = request.params.taskId;

Expand Down Expand Up @@ -485,7 +490,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
return await handleTaskResult();
});

this.setRequestHandler('tasks/list', async (request, extra) => {
this._setRequestHandlerInternal('tasks/list', async (request, extra) => {
try {
const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, extra.sessionId);
return {
Expand All @@ -501,7 +506,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
}
});

this.setRequestHandler('tasks/cancel', async (request, extra) => {
this._setRequestHandlerInternal('tasks/cancel', async (request, extra) => {
try {
// Get the current task to check if it's in a terminal state, in case the implementation is not atomic
const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId);
Expand Down Expand Up @@ -1393,12 +1398,32 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
await this._transport.send(jsonrpcNotification, options);
}

/**
* Internal method for registering built-in request handlers without the ReceiveRequestMethod constraint.
* Used by Protocol constructor for handlers like ping, tasks/*, which apply to all protocol instances.
* These handlers don't require capability checks as they are core protocol functionality.
*/
private _setRequestHandlerInternal<M extends RequestMethod>(
method: M,
handler: (
request: RequestTypeMap[M],
extra: RequestHandlerExtra<SendRequestT, SendNotificationT>
) => SendResultT | Promise<SendResultT>
): void {
const schema = getRequestSchema(method);

this._requestHandlers.set(method, (request, extra) => {
const parsed = parseWithCompat(schema, request) as RequestTypeMap[M];
return Promise.resolve(handler(parsed, extra));
});
}

/**
* Registers a handler to invoke when this protocol object receives a request with the given method.
*
* Note that this will replace any previous request handler for the same method.
*/
setRequestHandler<M extends RequestMethod>(
setRequestHandler<M extends ReceiveRequestMethod>(
method: M,
handler: (
request: RequestTypeMap[M],
Expand All @@ -1417,14 +1442,14 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
/**
* Removes the request handler for the given method.
*/
removeRequestHandler(method: RequestMethod): void {
removeRequestHandler(method: ReceiveRequestMethod): void {
this._requestHandlers.delete(method);
}

/**
* Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed.
*/
assertCanSetRequestHandler(method: RequestMethod): void {
assertCanSetRequestHandler(method: ReceiveRequestMethod): void {
if (this._requestHandlers.has(method)) {
throw new Error(`A request handler for ${method} already exists, which would be overridden`);
}
Expand Down
4 changes: 4 additions & 0 deletions packages/core/src/types/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2615,6 +2615,10 @@ export type NotificationMethod = ClientNotification['method'] | ServerNotificati
export type RequestTypeMap = MethodToTypeMap<ClientRequest | ServerRequest>;
export type NotificationTypeMap = MethodToTypeMap<ClientNotification | ServerNotification>;

// Narrowed method types for Client and Server request handlers
export type ServerToClientRequestMethod = ServerRequest['method'];
export type ClientToServerRequestMethod = ClientRequest['method'];

/* Runtime schema lookup */
type RequestSchemaType = (typeof ClientRequestSchema.options)[number] | (typeof ServerRequestSchema.options)[number];
type NotificationSchemaType = (typeof ClientNotificationSchema.options)[number] | (typeof ServerNotificationSchema.options)[number];
Expand Down
6 changes: 3 additions & 3 deletions packages/server/src/server/server.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type {
ClientCapabilities,
ClientToServerRequestMethod,
CreateMessageRequest,
CreateMessageRequestParamsBase,
CreateMessageRequestParamsWithTools,
Expand All @@ -21,7 +22,6 @@ import type {
ProtocolOptions,
Request,
RequestHandlerExtra,
RequestMethod,
RequestOptions,
RequestTypeMap,
ResourceUpdatedNotification,
Expand Down Expand Up @@ -130,7 +130,7 @@ export class Server<
RequestT extends Request = Request,
NotificationT extends Notification = Notification,
ResultT extends Result = Result
> extends Protocol<ServerRequest | RequestT, ServerNotification | NotificationT, ServerResult | ResultT> {
> extends Protocol<ServerRequest | RequestT, ServerNotification | NotificationT, ServerResult | ResultT, ClientToServerRequestMethod> {
private _clientCapabilities?: ClientCapabilities;
private _clientVersion?: Implementation;
private _capabilities: ServerCapabilities;
Expand Down Expand Up @@ -215,7 +215,7 @@ export class Server<
/**
* Override request handler registration to enforce server-side validation for tools/call.
*/
public override setRequestHandler<M extends RequestMethod>(
public override setRequestHandler<M extends ClientToServerRequestMethod>(
method: M,
handler: (
request: RequestTypeMap[M],
Expand Down
Loading