From 697d5d12b2c355056e37e528c569fc032c9cdf2a Mon Sep 17 00:00:00 2001 From: Matt Carey Date: Tue, 3 Feb 2026 10:50:47 +0000 Subject: [PATCH 1/3] feat!: drop Zod v3 support BREAKING CHANGE: This SDK now requires Zod v4.0 or later. - Remove all Zod v3 imports, types, and runtime detection - Simplify `SchemaOutput` to use direct indexed access `S['_zod']['output']` - Remove `zod-to-json-schema` dependency (z4-mini has built-in JSON schema support) - Add closure-based handlers for prompts and tools to eliminate runtime type casts - Simplify `executeToolHandler` and `handleAutomaticTaskPolling` - Remove v3 test blocks from integration tests --- README.md | 2 +- .../client/src/experimental/tasks/client.ts | 6 +- packages/core/package.json | 3 +- packages/core/src/shared/protocol.ts | 8 +- packages/core/src/util/zodCompat.ts | 146 ++--------- packages/core/src/util/zodJsonSchemaCompat.ts | 24 +- packages/server/package.json | 3 +- .../server/src/experimental/tasks/server.ts | 4 +- packages/server/src/server/completable.ts | 1 - packages/server/src/server/mcp.ts | 226 +++++++++++++----- pnpm-lock.yaml | 15 +- pnpm-workspace.yaml | 3 +- test/helpers/src/fixtures/zodTestMatrix.ts | 15 +- 13 files changed, 213 insertions(+), 243 deletions(-) diff --git a/README.md b/README.md index 658769313..745887313 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ This monorepo publishes split packages: - **`@modelcontextprotocol/server`**: build MCP servers - **`@modelcontextprotocol/client`**: build MCP clients -Both packages have a **required peer dependency** on `zod` for schema validation. The SDK internally imports from `zod/v4`, but remains compatible with projects using Zod v3.25+. +Both packages have a **required peer dependency** on `zod` for schema validation. The SDK uses Zod v4. ### Middleware packages (optional) diff --git a/packages/client/src/experimental/tasks/client.ts b/packages/client/src/experimental/tasks/client.ts index df57e91a4..d5f310fbc 100644 --- a/packages/client/src/experimental/tasks/client.ts +++ b/packages/client/src/experimental/tasks/client.ts @@ -33,7 +33,7 @@ interface ClientInternal { request: ClientRequest | RequestT, resultSchema: T, options?: RequestOptions - ): AsyncGenerator>, void, void>; + ): AsyncGenerator & Result>, void, void>; isToolTask(toolName: string): boolean; getToolOutputValidator(toolName: string): ((data: unknown) => { valid: boolean; errorMessage?: string }) | undefined; } @@ -260,14 +260,14 @@ export class ExperimentalClientTasks< request: ClientRequest | RequestT, resultSchema: T, options?: RequestOptions - ): AsyncGenerator>, void, void> { + ): AsyncGenerator & Result>, void, void> { // Delegate to the client's underlying Protocol method type ClientWithRequestStream = { requestStream( request: ClientRequest | RequestT, resultSchema: U, options?: RequestOptions - ): AsyncGenerator>, void, void>; + ): AsyncGenerator & Result>, void, void>; }; return (this._client as unknown as ClientWithRequestStream).requestStream(request, resultSchema, options); } diff --git a/packages/core/package.json b/packages/core/package.json index 33bfe54ed..cbda5734c 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -37,8 +37,7 @@ "ajv": "catalog:runtimeShared", "ajv-formats": "catalog:runtimeShared", "json-schema-typed": "catalog:runtimeShared", - "zod": "catalog:runtimeShared", - "zod-to-json-schema": "catalog:runtimeShared" + "zod": "catalog:runtimeShared" }, "peerDependencies": { "@cfworker/json-schema": "catalog:runtimeShared", diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index f537aa86c..6909f17ff 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -981,14 +981,14 @@ export abstract class Protocol>, void, void> { + ): AsyncGenerator & Result>, void, void> { const { task } = options ?? {}; // For non-task requests, just yield the result if (!task) { try { const result = await this.request(request, resultSchema, options); - yield { type: 'result', result }; + yield { type: 'result', result: result as SchemaOutput & Result }; } catch (error) { yield { type: 'error', @@ -1025,7 +1025,7 @@ export abstract class Protocol & Result }; break; } @@ -1054,7 +1054,7 @@ export abstract class Protocol & Result }; return; } diff --git a/packages/core/src/util/zodCompat.ts b/packages/core/src/util/zodCompat.ts index 667a6897b..813494d83 100644 --- a/packages/core/src/util/zodCompat.ts +++ b/packages/core/src/util/zodCompat.ts @@ -1,32 +1,17 @@ // zod-compat.ts // ---------------------------------------------------- -// Unified types + helpers to accept Zod v3 and v4 (Mini) +// Types + helpers for Zod v4 (Mini) schemas // ---------------------------------------------------- -import type * as z3 from 'zod/v3'; -import * as z3rt from 'zod/v3'; import type * as z4 from 'zod/v4/core'; import * as z4mini from 'zod/v4-mini'; -// --- Unified schema types --- -export type AnySchema = z3.ZodTypeAny | z4.$ZodType; -export type AnyObjectSchema = z3.AnyZodObject | z4.$ZodObject | AnySchema; +// --- Schema types --- +export type AnySchema = z4.$ZodType; +export type AnyObjectSchema = z4.$ZodObject | AnySchema; export type ZodRawShapeCompat = Record; // --- Internal property access helpers --- -// These types help us safely access internal properties that differ between v3 and v4 -export interface ZodV3Internal { - _def?: { - typeName?: string; - value?: unknown; - values?: unknown[]; - shape?: Record | (() => Record); - description?: string; - }; - shape?: Record | (() => Record); - value?: unknown; -} - export interface ZodV4Internal { _zod?: { def?: { @@ -40,9 +25,10 @@ export interface ZodV4Internal { } // --- Type inference helpers --- -export type SchemaOutput = S extends z3.ZodTypeAny ? z3.infer : S extends z4.$ZodType ? z4.output : never; - -export type SchemaInput = S extends z3.ZodTypeAny ? z3.input : S extends z4.$ZodType ? z4.input : never; +// Use direct indexed access for better generic type inference +// This avoids the conditional type in z4.output which resolves to unknown for generic S +export type SchemaOutput = S['_zod']['output']; +export type SchemaInput = S['_zod']['input']; /** * Infers the output type from a ZodRawShapeCompat (raw shape object). @@ -52,25 +38,9 @@ export type ShapeOutput = { [K in keyof Shape]: SchemaOutput; }; -// --- Runtime detection --- -export function isZ4Schema(s: AnySchema): s is z4.$ZodType { - // Present on Zod 4 (Classic & Mini) schemas; absent on Zod 3 - const schema = s as unknown as ZodV4Internal; - return !!schema._zod; -} - // --- Schema construction --- export function objectFromShape(shape: ZodRawShapeCompat): AnyObjectSchema { - const values = Object.values(shape); - if (values.length === 0) return z4mini.object({}); // default to v4 Mini - - const allV4 = values.every(element => isZ4Schema(element)); - const allV3 = values.every(s => !isZ4Schema(s)); - - if (allV4) return z4mini.object(shape as Record); - if (allV3) return z3rt.object(shape as Record); - - throw new Error('Mixed Zod versions detected in object shape.'); + return z4mini.object(shape as Record); } // --- Unified parsing --- @@ -78,13 +48,7 @@ export function safeParse( schema: S, data: unknown ): { success: true; data: SchemaOutput } | { success: false; error: unknown } { - if (isZ4Schema(schema)) { - // Mini exposes top-level safeParse - const result = z4mini.safeParse(schema, data); - return result as { success: true; data: SchemaOutput } | { success: false; error: unknown }; - } - const v3Schema = schema as z3.ZodTypeAny; - const result = v3Schema.safeParse(data); + const result = z4mini.safeParse(schema, data); return result as { success: true; data: SchemaOutput } | { success: false; error: unknown }; } @@ -92,13 +56,7 @@ export async function safeParseAsync( schema: S, data: unknown ): Promise<{ success: true; data: SchemaOutput } | { success: false; error: unknown }> { - if (isZ4Schema(schema)) { - // Mini exposes top-level safeParseAsync - const result = await z4mini.safeParseAsync(schema, data); - return result as { success: true; data: SchemaOutput } | { success: false; error: unknown }; - } - const v3Schema = schema as z3.ZodTypeAny; - const result = await v3Schema.safeParseAsync(data); + const result = await z4mini.safeParseAsync(schema, data); return result as { success: true; data: SchemaOutput } | { success: false; error: unknown }; } @@ -106,16 +64,8 @@ export async function safeParseAsync( export function getObjectShape(schema: AnyObjectSchema | undefined): Record | undefined { if (!schema) return undefined; - // Zod v3 exposes `.shape`; Zod v4 keeps the shape on `_zod.def.shape` - let rawShape: Record | (() => Record) | undefined; - - if (isZ4Schema(schema)) { - const v4Schema = schema as unknown as ZodV4Internal; - rawShape = v4Schema._zod?.def?.shape; - } else { - const v3Schema = schema as unknown as ZodV3Internal; - rawShape = v3Schema.shape; - } + const v4Schema = schema as unknown as ZodV4Internal; + const rawShape = v4Schema._zod?.def?.shape; if (!rawShape) return undefined; @@ -133,22 +83,18 @@ export function getObjectShape(schema: AnyObjectSchema | undefined): Record) - // Raw shapes don't have _def or _zod properties and aren't schemas themselves if (typeof schema === 'object') { - // Check if it's actually a ZodRawShapeCompat (not a schema instance) - // by checking if it lacks schema-like internal properties - const asV3 = schema as unknown as ZodV3Internal; const asV4 = schema as unknown as ZodV4Internal; - // If it's not a schema instance (no _def or _zod), it might be a raw shape - if (!asV3._def && !asV4._zod) { + // If it's not a schema instance (no _zod), it might be a raw shape + if (!asV4._zod) { // Check if all values are schemas (heuristic to confirm it's a raw shape) const values = Object.values(schema); if ( @@ -157,9 +103,7 @@ export function normalizeObjectSchema(schema: AnySchema | ZodRawShapeCompat | un v => typeof v === 'object' && v !== null && - ((v as unknown as ZodV3Internal)._def !== undefined || - (v as unknown as ZodV4Internal)._zod !== undefined || - typeof (v as { parse?: unknown }).parse === 'function') + ((v as unknown as ZodV4Internal)._zod !== undefined || typeof (v as { parse?: unknown }).parse === 'function') ) ) { return objectFromShape(schema as ZodRawShapeCompat); @@ -167,21 +111,11 @@ export function normalizeObjectSchema(schema: AnySchema | ZodRawShapeCompat | un } } - // If we get here, it should be an AnySchema (not a raw shape) // Check if it's already an object schema - if (isZ4Schema(schema as AnySchema)) { - // Check if it's a v4 object - const v4Schema = schema as unknown as ZodV4Internal; - const def = v4Schema._zod?.def; - if (def && (def.type === 'object' || def.shape !== undefined)) { - return schema as AnyObjectSchema; - } - } else { - // Check if it's a v3 object - const v3Schema = schema as unknown as ZodV3Internal; - if (v3Schema.shape !== undefined) { - return schema as AnyObjectSchema; - } + const v4Schema = schema as unknown as ZodV4Internal; + const def = v4Schema._zod?.def; + if (def && (def.type === 'object' || def.shape !== undefined)) { + return schema as AnyObjectSchema; } return undefined; @@ -190,11 +124,9 @@ export function normalizeObjectSchema(schema: AnySchema | ZodRawShapeCompat | un // --- Error message extraction --- /** * Safely extracts an error message from a parse result error. - * Zod errors can have different structures, so we handle various cases. */ export function getParseErrorMessage(error: unknown): string { if (error && typeof error === 'object') { - // Try common error structures if ('message' in error && typeof error.message === 'string') { return error.message; } @@ -204,7 +136,6 @@ export function getParseErrorMessage(error: unknown): string { return String(firstIssue.message); } } - // Fallback: try to stringify the error try { return JSON.stringify(error); } catch { @@ -217,10 +148,6 @@ export function getParseErrorMessage(error: unknown): string { // --- Schema metadata access --- /** * Gets the description from a schema, if available. - * Works with both Zod v3 and v4. - * - * Both versions expose a `.description` getter that returns the description - * from their respective internal storage (v3: _def, v4: globalRegistry). */ export function getSchemaDescription(schema: AnySchema): string | undefined { return (schema as { description?: string }).description; @@ -228,47 +155,26 @@ export function getSchemaDescription(schema: AnySchema): string | undefined { /** * Checks if a schema is optional. - * Works with both Zod v3 and v4. */ export function isSchemaOptional(schema: AnySchema): boolean { - if (isZ4Schema(schema)) { - const v4Schema = schema as unknown as ZodV4Internal; - return v4Schema._zod?.def?.type === 'optional'; - } - const v3Schema = schema as unknown as ZodV3Internal; - // v3 has isOptional() method - if (typeof (schema as { isOptional?: () => boolean }).isOptional === 'function') { - return (schema as { isOptional: () => boolean }).isOptional(); - } - return v3Schema._def?.typeName === 'ZodOptional'; + const v4Schema = schema as unknown as ZodV4Internal; + return v4Schema._zod?.def?.type === 'optional'; } /** * Gets the literal value from a schema, if it's a literal schema. - * Works with both Zod v3 and v4. * Returns undefined if the schema is not a literal or the value cannot be determined. */ export function getLiteralValue(schema: AnySchema): unknown { - if (isZ4Schema(schema)) { - const v4Schema = schema as unknown as ZodV4Internal; - const def = v4Schema._zod?.def; - if (def) { - // Try various ways to get the literal value - if (def.value !== undefined) return def.value; - if (Array.isArray(def.values) && def.values.length > 0) { - return def.values[0]; - } - } - } - const v3Schema = schema as unknown as ZodV3Internal; - const def = v3Schema._def; + const v4Schema = schema as unknown as ZodV4Internal; + const def = v4Schema._zod?.def; if (def) { if (def.value !== undefined) return def.value; if (Array.isArray(def.values) && def.values.length > 0) { return def.values[0]; } } - // Fallback: check for direct value property (some Zod versions) + // Fallback: check for direct value property const directValue = (schema as { value?: unknown }).value; if (directValue !== undefined) return directValue; return undefined; diff --git a/packages/core/src/util/zodJsonSchemaCompat.ts b/packages/core/src/util/zodJsonSchemaCompat.ts index 68bd98edc..404c8d4d3 100644 --- a/packages/core/src/util/zodJsonSchemaCompat.ts +++ b/packages/core/src/util/zodJsonSchemaCompat.ts @@ -1,20 +1,17 @@ // zod-json-schema-compat.ts // ---------------------------------------------------- -// JSON Schema conversion for both Zod v3 and Zod v4 (Mini) -// v3 uses your vendored converter; v4 uses Mini's toJSONSchema +// JSON Schema conversion for Zod v4 (Mini) // ---------------------------------------------------- -import type * as z3 from 'zod/v3'; import type * as z4c from 'zod/v4/core'; import * as z4mini from 'zod/v4-mini'; -import { zodToJsonSchema } from 'zod-to-json-schema'; import type { AnyObjectSchema, AnySchema, SchemaOutput } from './zodCompat.js'; -import { isZ4Schema, safeParse } from './zodCompat.js'; +import { safeParse } from './zodCompat.js'; type JsonSchema = Record; -// Options accepted by call sites; we map them appropriately +// Options accepted by call sites type CommonOpts = { strictUnions?: boolean; pipeStrategy?: 'input' | 'output'; @@ -29,18 +26,9 @@ function mapMiniTarget(t: CommonOpts['target'] | undefined): 'draft-7' | 'draft- } export function toJsonSchemaCompat(schema: AnyObjectSchema, opts?: CommonOpts): JsonSchema { - if (isZ4Schema(schema)) { - // v4 branch — use Mini's built-in toJSONSchema - return z4mini.toJSONSchema(schema as z4c.$ZodType, { - target: mapMiniTarget(opts?.target), - io: opts?.pipeStrategy ?? 'input' - }) as JsonSchema; - } - - // v3 branch — use vendored converter - return zodToJsonSchema(schema as z3.ZodTypeAny, { - strictUnions: opts?.strictUnions ?? true, - pipeStrategy: opts?.pipeStrategy ?? 'input' + return z4mini.toJSONSchema(schema as z4c.$ZodType, { + target: mapMiniTarget(opts?.target), + io: opts?.pipeStrategy ?? 'input' }) as JsonSchema; } diff --git a/packages/server/package.json b/packages/server/package.json index b4fed9b4b..6445fa041 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -45,8 +45,7 @@ "client": "tsx scripts/cli.ts client" }, "dependencies": { - "zod": "catalog:runtimeShared", - "zod-to-json-schema": "catalog:runtimeShared" + "zod": "catalog:runtimeShared" }, "peerDependencies": { "@cfworker/json-schema": "catalog:runtimeShared", diff --git a/packages/server/src/experimental/tasks/server.ts b/packages/server/src/experimental/tasks/server.ts index 33bde3298..d9d887fe3 100644 --- a/packages/server/src/experimental/tasks/server.ts +++ b/packages/server/src/experimental/tasks/server.ts @@ -58,14 +58,14 @@ export class ExperimentalServerTasks< request: ServerRequest | RequestT, resultSchema: T, options?: RequestOptions - ): AsyncGenerator>, void, void> { + ): AsyncGenerator & Result>, void, void> { // Delegate to the server's underlying Protocol method type ServerWithRequestStream = { requestStream( request: ServerRequest | RequestT, resultSchema: U, options?: RequestOptions - ): AsyncGenerator>, void, void>; + ): AsyncGenerator & Result>, void, void>; }; return (this._server as unknown as ServerWithRequestStream).requestStream(request, resultSchema, options); } diff --git a/packages/server/src/server/completable.ts b/packages/server/src/server/completable.ts index 7174bff37..fed16704d 100644 --- a/packages/server/src/server/completable.ts +++ b/packages/server/src/server/completable.ts @@ -19,7 +19,6 @@ export type CompletableSchema = T & { /** * Wraps a Zod type to provide autocompletion capabilities. Useful for, e.g., prompt arguments in MCP. - * Works with both Zod v3 and v4 schemas. */ export function completable(schema: T, complete: CompleteCallback): CompletableSchema { Object.defineProperty(schema as object, COMPLETABLE_SYMBOL, { diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index b2c17b71e..69024a829 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -317,35 +317,8 @@ export class McpServer { args: unknown, extra: RequestHandlerExtra ): Promise { - const handler = tool.handler as AnyToolHandler; - const isTaskHandler = 'createTask' in handler; - - if (isTaskHandler) { - if (!extra.taskStore) { - throw new Error('No task store provided.'); - } - const taskExtra = { ...extra, taskStore: extra.taskStore }; - - if (tool.inputSchema) { - const typedHandler = handler as ToolTaskHandler; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler.createTask(args as any, taskExtra)); - } else { - const typedHandler = handler as ToolTaskHandler; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler.createTask as any)(taskExtra)); - } - } - - if (tool.inputSchema) { - const typedHandler = handler as ToolCallback; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler(args as any, extra)); - } else { - const typedHandler = handler as ToolCallback; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((typedHandler as any)(extra)); - } + // Executor encapsulates handler invocation with proper types + return tool.executor(args, extra); } /** @@ -360,15 +333,9 @@ export class McpServer { throw new Error('No task store provided for task-capable tool.'); } - // Validate input and create task + // Validate input and create task using the executor const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const handler = tool.handler as ToolTaskHandler; - const taskExtra = { ...extra, taskStore: extra.taskStore }; - - const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined - ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, taskExtra)) - : // eslint-disable-next-line @typescript-eslint/no-explicit-any - await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(taskExtra)); + const createTaskResult = (await tool.executor(args, extra)) as CreateTaskResult; // Poll until completion const taskId = createTaskResult.task.taskId; @@ -597,23 +564,8 @@ export class McpServer { throw new McpError(ErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); } - if (prompt.argsSchema) { - const argsObj = normalizeObjectSchema(prompt.argsSchema) as AnyObjectSchema; - const parseResult = await safeParseAsync(argsObj, request.params.arguments); - if (!parseResult.success) { - const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; - const errorMessage = getParseErrorMessage(error); - throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for prompt ${request.params.name}: ${errorMessage}`); - } - - const args = parseResult.data; - const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(args, extra)); - } else { - const cb = prompt.callback as PromptCallback; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve((cb as any)(extra)); - } + // Handler encapsulates parsing and callback invocation with proper types + return prompt.handler(request.params.arguments, extra); }); this._promptHandlersInitialized = true; @@ -752,11 +704,15 @@ export class McpServer { argsSchema: PromptArgsRawShape | undefined, callback: PromptCallback ): RegisteredPrompt { + // Track current schema and callback for handler regeneration + let currentArgsSchema = argsSchema; + let currentCallback = callback; + const registeredPrompt: RegisteredPrompt = { title, description, argsSchema: argsSchema === undefined ? undefined : objectFromShape(argsSchema), - callback, + handler: createPromptHandler(name, argsSchema, callback), enabled: true, disable: () => registeredPrompt.update({ enabled: false }), enable: () => registeredPrompt.update({ enabled: true }), @@ -768,8 +724,22 @@ export class McpServer { } if (updates.title !== undefined) registeredPrompt.title = updates.title; if (updates.description !== undefined) registeredPrompt.description = updates.description; - if (updates.argsSchema !== undefined) registeredPrompt.argsSchema = objectFromShape(updates.argsSchema); - if (updates.callback !== undefined) registeredPrompt.callback = updates.callback; + + // Track if we need to regenerate the handler + let needsHandlerRegen = false; + if (updates.argsSchema !== undefined) { + registeredPrompt.argsSchema = objectFromShape(updates.argsSchema); + currentArgsSchema = updates.argsSchema; + needsHandlerRegen = true; + } + if (updates.callback !== undefined) { + currentCallback = updates.callback as PromptCallback; + needsHandlerRegen = true; + } + if (needsHandlerRegen) { + registeredPrompt.handler = createPromptHandler(name, currentArgsSchema, currentCallback); + } + if (updates.enabled !== undefined) registeredPrompt.enabled = updates.enabled; this.sendPromptListChanged(); } @@ -804,6 +774,10 @@ export class McpServer { // Validate tool name according to SEP specification validateAndWarnToolName(name); + // Track current schema and handler for executor regeneration + let currentInputSchema = inputSchema; + let currentHandler = handler; + const registeredTool: RegisteredTool = { title, description, @@ -813,6 +787,7 @@ export class McpServer { execution, _meta, handler: handler, + executor: createToolExecutor(inputSchema, handler), enabled: true, disable: () => registeredTool.update({ enabled: false }), enable: () => registeredTool.update({ enabled: true }), @@ -827,9 +802,24 @@ export class McpServer { } if (updates.title !== undefined) registeredTool.title = updates.title; if (updates.description !== undefined) registeredTool.description = updates.description; - if (updates.paramsSchema !== undefined) registeredTool.inputSchema = objectFromShape(updates.paramsSchema); + + // Track if we need to regenerate the executor + let needsExecutorRegen = false; + if (updates.paramsSchema !== undefined) { + registeredTool.inputSchema = objectFromShape(updates.paramsSchema); + currentInputSchema = updates.paramsSchema; + needsExecutorRegen = true; + } + if (updates.callback !== undefined) { + registeredTool.handler = updates.callback; + currentHandler = updates.callback as AnyToolHandler; + needsExecutorRegen = true; + } + if (needsExecutorRegen) { + registeredTool.executor = createToolExecutor(currentInputSchema, currentHandler); + } + if (updates.outputSchema !== undefined) registeredTool.outputSchema = objectFromShape(updates.outputSchema); - if (updates.callback !== undefined) registeredTool.handler = updates.callback; if (updates.annotations !== undefined) registeredTool.annotations = updates.annotations; if (updates._meta !== undefined) registeredTool._meta = updates._meta; if (updates.enabled !== undefined) registeredTool.enabled = updates.enabled; @@ -1045,6 +1035,15 @@ export type ToolCallback = ToolCallback | ToolTaskHandler; +/** + * Internal executor type that encapsulates handler invocation with proper types. + * This allows type-safe handling without runtime type assertions at the call site. + */ +type ToolExecutor = ( + args: unknown, + extra: RequestHandlerExtra +) => Promise; + export type RegisteredTool = { title?: string; description?: string; @@ -1054,6 +1053,8 @@ export type RegisteredTool = { execution?: ToolExecution; _meta?: Record; handler: AnyToolHandler; + /** @internal */ + executor: ToolExecutor; enabled: boolean; enable(): void; disable(): void; @@ -1090,6 +1091,57 @@ function isZodTypeLike(value: unknown): value is AnySchema { ); } +/** + * Creates a type-safe tool executor that captures the schema and handler in a closure. + * This eliminates the need for type assertions at the call site. + */ +function createToolExecutor( + inputSchema: ZodRawShapeCompat | AnySchema | undefined, + handler: AnyToolHandler +): ToolExecutor { + const isTaskHandler = 'createTask' in handler; + + if (isTaskHandler) { + // Task handler path + const taskHandler = handler as ToolTaskHandler; + + if (inputSchema) { + const typedHandler = taskHandler as ToolTaskHandler; + return async (args, extra) => { + if (!extra.taskStore) { + throw new Error('No task store provided.'); + } + const taskExtra = { ...extra, taskStore: extra.taskStore }; + return typedHandler.createTask(args as ShapeOutput, taskExtra); + }; + } else { + const typedHandler = taskHandler as ToolTaskHandler; + return async (_args, extra) => { + if (!extra.taskStore) { + throw new Error('No task store provided.'); + } + const taskExtra = { ...extra, taskStore: extra.taskStore }; + return typedHandler.createTask(taskExtra); + }; + } + } else { + // Regular callback path + const callback = handler as ToolCallback; + + if (inputSchema) { + const typedCallback = callback as ToolCallback; + return async (args, extra) => { + return typedCallback(args as ShapeOutput, extra); + }; + } else { + const typedCallback = callback as ToolCallback; + return async (_args, extra) => { + return typedCallback(extra); + }; + } + } +} + /** * Checks if an object is a Zod schema instance (v3 or v4). * @@ -1219,11 +1271,21 @@ export type PromptCallback, extra: RequestHandlerExtra) => GetPromptResult | Promise : (extra: RequestHandlerExtra) => GetPromptResult | Promise; +/** + * Internal handler type that encapsulates parsing and callback invocation. + * This allows type-safe handling without runtime type assertions. + */ +type PromptHandler = ( + args: Record | undefined, + extra: RequestHandlerExtra +) => Promise; + export type RegisteredPrompt = { title?: string; description?: string; argsSchema?: AnyObjectSchema; - callback: PromptCallback; + /** @internal */ + handler: PromptHandler; enabled: boolean; enable(): void; disable(): void; @@ -1238,6 +1300,45 @@ export type RegisteredPrompt = { remove(): void; }; +/** + * Creates a type-safe prompt handler that captures the schema and callback in a closure. + * This eliminates the need for type assertions at the call site. + */ +function createPromptHandler( + name: string, + argsSchema: PromptArgsRawShape | undefined, + callback: PromptCallback +): PromptHandler { + if (argsSchema) { + // Capture the schema and callback with their types at registration time + const schemaObj = objectFromShape(argsSchema); + const typedCallback = callback as ( + args: ShapeOutput, + extra: RequestHandlerExtra + ) => GetPromptResult | Promise; + + return async (args, extra) => { + const parseResult = await safeParseAsync(schemaObj, args); + if (!parseResult.success) { + const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; + const errorMessage = getParseErrorMessage(error); + throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for prompt ${name}: ${errorMessage}`); + } + // parseResult.data is validated against schemaObj, so it matches ShapeOutput + return typedCallback(parseResult.data as ShapeOutput, extra); + }; + } else { + // No args schema - callback takes only extra + const typedCallback = callback as ( + extra: RequestHandlerExtra + ) => GetPromptResult | Promise; + + return async (_args, extra) => { + return typedCallback(extra); + }; + } +} + function promptArgumentsFromSchema(schema: AnyObjectSchema): PromptArgument[] { const shape = getObjectShape(schema); if (!shape) return []; @@ -1254,10 +1355,11 @@ function promptArgumentsFromSchema(schema: AnyObjectSchema): PromptArgument[] { }); } -function createCompletionResult(suggestions: string[]): CompleteResult { +function createCompletionResult(suggestions: readonly unknown[]): CompleteResult { + const values = suggestions.map(String).slice(0, 100); return { completion: { - values: suggestions.slice(0, 100), + values, total: suggestions.length, hasMore: suggestions.length > 100 } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index cf05bc33e..10c3f9729 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -118,11 +118,8 @@ catalogs: specifier: ^5.0.0 version: 5.0.1 zod: - specifier: ^3.25 || ^4.0 + specifier: ^4.0 version: 4.3.5 - zod-to-json-schema: - specifier: ^3.25.0 - version: 3.25.1 overrides: strip-ansi: 6.0.1 @@ -548,9 +545,6 @@ importers: zod: specifier: catalog:runtimeShared version: 4.3.5 - zod-to-json-schema: - specifier: catalog:runtimeShared - version: 3.25.1(zod@4.3.5) devDependencies: '@cfworker/json-schema': specifier: catalog:runtimeShared @@ -780,9 +774,6 @@ importers: zod: specifier: catalog:runtimeShared version: 4.3.5 - zod-to-json-schema: - specifier: catalog:runtimeShared - version: 3.25.1(zod@4.3.5) devDependencies: '@cfworker/json-schema': specifier: catalog:runtimeShared @@ -7273,10 +7264,6 @@ snapshots: dependencies: zod: 3.25.76 - zod-to-json-schema@3.25.1(zod@4.3.5): - dependencies: - zod: 4.3.5 - zod@3.25.76: {} zod@4.3.5: {} diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index df4e61ea1..504e164d6 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -47,8 +47,7 @@ catalogs: ajv-formats: ^3.0.1 json-schema-typed: ^8.0.2 pkce-challenge: ^5.0.0 - zod: ^3.25 || ^4.0 - zod-to-json-schema: ^3.25.0 + zod: ^4.0 enableGlobalVirtualStore: false diff --git a/test/helpers/src/fixtures/zodTestMatrix.ts b/test/helpers/src/fixtures/zodTestMatrix.ts index fc4ee63db..89bbbb501 100644 --- a/test/helpers/src/fixtures/zodTestMatrix.ts +++ b/test/helpers/src/fixtures/zodTestMatrix.ts @@ -1,21 +1,12 @@ -import * as z3 from 'zod/v3'; import * as z4 from 'zod/v4'; -// Shared Zod namespace type that exposes the common surface area used in tests. -export type ZNamespace = typeof z3 & typeof z4; +// Zod namespace type for tests +export type ZNamespace = typeof z4; export const zodTestMatrix = [ - { - zodVersionLabel: 'Zod v3', - z: z3 as ZNamespace, - isV3: true as const, - isV4: false as const - }, { zodVersionLabel: 'Zod v4', - z: z4 as ZNamespace, - isV3: false as const, - isV4: true as const + z: z4 as ZNamespace } ] as const; From 3ae419c1232e54548113fb8345be61a84e5f9eac Mon Sep 17 00:00:00 2001 From: Matt Carey Date: Tue, 3 Feb 2026 13:13:32 +0000 Subject: [PATCH 2/3] feat!: remove Zod schemas from public API and add typed request handlers - Create mapping.ts as single source of truth for method-to-schema mappings - Remove resultSchema parameter from Protocol.request() - now derived from method name - Add HandlerResultTypeMap for typed setRequestHandler return values - Update Client/Server setRequestHandler to use method-specific result types - Remove schema arguments from all call sites (examples, tests) - Fix test method names to use real MCP methods (ping, tools/call, etc.) BREAKING CHANGE: Protocol.request() no longer accepts a resultSchema parameter. The result type is now automatically inferred from the method name. --- examples/client/src/elicitationUrlExample.ts | 4 +- .../client/src/multipleClientsParallel.ts | 2 +- .../client/src/parallelToolCallsClient.ts | 8 +- examples/client/src/simpleOAuthClient.ts | 4 +- examples/client/src/simpleStreamableHttp.ts | 30 +- examples/client/src/ssePollingClient.ts | 3 +- .../streamableHttpWithSseFallbackClient.ts | 4 +- examples/server/src/simpleStreamableHttp.ts | 3 +- packages/client/src/client/client.ts | 41 ++- packages/core/src/index.ts | 1 + packages/core/src/shared/protocol.ts | 107 ++++--- packages/core/src/types/mapping.ts | 155 +++++++++ packages/core/src/types/types.ts | 16 +- packages/core/test/shared/protocol.test.ts | 298 +++++++++--------- packages/server/src/server/server.ts | 25 +- test/integration/test/server/mcp.test.ts | 30 +- 16 files changed, 436 insertions(+), 295 deletions(-) create mode 100644 packages/core/src/types/mapping.ts diff --git a/examples/client/src/elicitationUrlExample.ts b/examples/client/src/elicitationUrlExample.ts index 2f90b41d3..72005fcc3 100644 --- a/examples/client/src/elicitationUrlExample.ts +++ b/examples/client/src/elicitationUrlExample.ts @@ -656,7 +656,7 @@ async function listTools(): Promise { method: 'tools/list', params: {} }; - const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); + const toolsResult = await client.request(toolsRequest); console.log('Available tools:'); if (toolsResult.tools.length === 0) { @@ -687,7 +687,7 @@ async function callTool(name: string, args: Record): Promise { method: 'tools/list', params: {} }; - const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); + const toolsResult = await client.request(toolsRequest); console.log('Available tools:'); if (toolsResult.tools.length === 0) { @@ -159,7 +159,7 @@ async function startParallelNotificationTools(client: Client): Promise { console.log(`Starting tool call for ${caller}...`); return client - .request(request, CallToolResultSchema) + .request(request as CallToolRequest) .then(result => ({ caller, result })) .catch(error => { console.error(`Error in tool call for ${caller}:`, error); diff --git a/examples/client/src/simpleOAuthClient.ts b/examples/client/src/simpleOAuthClient.ts index 23ea05a99..80f709e73 100644 --- a/examples/client/src/simpleOAuthClient.ts +++ b/examples/client/src/simpleOAuthClient.ts @@ -252,7 +252,7 @@ class InteractiveOAuthClient { params: {} }; - const result = await this.client.request(request, ListToolsResultSchema); + const result = await this.client.request(request); if (result.tools && result.tools.length > 0) { console.log('\nšŸ“‹ Available tools:'); @@ -310,7 +310,7 @@ class InteractiveOAuthClient { } }; - const result = await this.client.request(request, CallToolResultSchema); + const result = await this.client.request(request); console.log(`\nšŸ”§ Tool '${toolName}' result:`); if (result.content) { diff --git a/examples/client/src/simpleStreamableHttp.ts b/examples/client/src/simpleStreamableHttp.ts index d3144c164..2b0743450 100644 --- a/examples/client/src/simpleStreamableHttp.ts +++ b/examples/client/src/simpleStreamableHttp.ts @@ -2,6 +2,7 @@ import { createInterface } from 'node:readline'; import type { CallToolRequest, + ElicitResult, GetPromptRequest, ListPromptsRequest, ListResourcesRequest, @@ -268,7 +269,7 @@ async function connect(url?: string): Promise { }; // Set up elicitation request handler with proper validation - client.setRequestHandler('elicitation/create', async request => { + client.setRequestHandler('elicitation/create', async (request): Promise => { if (request.params.mode !== 'form') { throw new McpError(ErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); } @@ -293,7 +294,7 @@ async function connect(url?: string): Promise { attempts++; console.log(`\nPlease provide the following information (attempt ${attempts}/${maxAttempts}):`); - const content: Record = {}; + const content: Record = {}; let inputCancelled = false; // Collect input for each field @@ -357,7 +358,7 @@ async function connect(url?: string): Promise { // Parse and validate the input try { if (answer === '' && field.default !== undefined) { - content[fieldName] = field.default; + content[fieldName] = field.default as string | number | boolean | string[]; } else if (answer === '' && !isRequired) { // Skip optional empty fields continue; @@ -365,7 +366,7 @@ async function connect(url?: string): Promise { throw new Error(`${fieldName} is required`); } else { // Parse the value based on type - let parsedValue: unknown; + let parsedValue: string | number | boolean | string[]; switch (field.type) { case 'boolean': { @@ -375,7 +376,7 @@ async function connect(url?: string): Promise { } case 'number': { parsedValue = Number.parseFloat(answer); - if (Number.isNaN(parsedValue as number)) { + if (Number.isNaN(parsedValue)) { throw new TypeError(`${fieldName} must be a valid number`); } @@ -383,7 +384,7 @@ async function connect(url?: string): Promise { } case 'integer': { parsedValue = Number.parseInt(answer, 10); - if (Number.isNaN(parsedValue as number)) { + if (Number.isNaN(parsedValue)) { throw new TypeError(`${fieldName} must be a valid integer`); } @@ -511,8 +512,7 @@ async function connect(url?: string): Promise { { method: 'resources/list', params: {} - }, - ListResourcesResultSchema + } ); console.log('Available resources count:', resourcesResult.resources.length); } catch { @@ -598,7 +598,7 @@ async function listTools(): Promise { method: 'tools/list', params: {} }; - const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); + const toolsResult = await client.request(toolsRequest); console.log('Available tools:'); if (toolsResult.tools.length === 0) { @@ -629,7 +629,7 @@ async function callTool(name: string, args: Record): Promise { method: 'prompts/list', params: {} }; - const promptsResult = await client.request(promptsRequest, ListPromptsResultSchema); + const promptsResult = await client.request(promptsRequest); console.log('Available prompts:'); if (promptsResult.prompts.length === 0) { console.log(' No prompts available'); @@ -785,7 +785,7 @@ async function getPrompt(name: string, args: Record): Promise { method: 'resources/list', params: {} }; - const resourcesResult = await client.request(resourcesRequest, ListResourcesResultSchema); + const resourcesResult = await client.request(resourcesRequest); console.log('Available resources:'); if (resourcesResult.resources.length === 0) { @@ -834,7 +834,7 @@ async function readResource(uri: string): Promise { }; console.log(`Reading resource: ${uri}`); - const result = await client.request(request, ReadResourceResultSchema); + const result = await client.request(request); console.log('Resource contents:'); for (const content of result.contents) { diff --git a/examples/client/src/ssePollingClient.ts b/examples/client/src/ssePollingClient.ts index 34ff768c4..28f49d0b5 100644 --- a/examples/client/src/ssePollingClient.ts +++ b/examples/client/src/ssePollingClient.ts @@ -78,10 +78,9 @@ async function main(): Promise { arguments: {} } }, - CallToolResultSchema, { // Track resumption tokens for debugging - onresumptiontoken: token => { + onresumptiontoken: (token: string) => { lastEventId = token; console.log(`[Event ID] ${token}`); } diff --git a/examples/client/src/streamableHttpWithSseFallbackClient.ts b/examples/client/src/streamableHttpWithSseFallbackClient.ts index 73a626764..a975126a4 100644 --- a/examples/client/src/streamableHttpWithSseFallbackClient.ts +++ b/examples/client/src/streamableHttpWithSseFallbackClient.ts @@ -135,7 +135,7 @@ async function listTools(client: Client): Promise { method: 'tools/list', params: {} }; - const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); + const toolsResult = await client.request(toolsRequest); console.log('Available tools:'); if (toolsResult.tools.length === 0) { @@ -168,7 +168,7 @@ async function startNotificationTool(client: Client): Promise { }; console.log('Calling notification tool...'); - const result = await client.request(request, CallToolResultSchema); + const result = await client.request(request); console.log('Tool result:'); for (const item of result.content) { diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index e016c217f..adc9c25bb 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -247,8 +247,7 @@ const getServer = () => { message, requestedSchema } - }, - ElicitResultSchema + } ); if (result.action === 'accept') { diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 9c2c87290..00d6523af 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -27,6 +27,7 @@ import type { RequestMethod, RequestOptions, RequestTypeMap, + HandlerResultTypeMap, Result, ServerCapabilities, SubscribeRequest, @@ -333,13 +334,13 @@ export class Client< handler: ( request: RequestTypeMap[M], extra: RequestHandlerExtra - ) => ClientResult | ResultT | Promise + ) => HandlerResultTypeMap[M] | Promise ): void { if (method === 'elicitation/create') { const wrappedHandler = async ( request: RequestTypeMap[M], extra: RequestHandlerExtra - ): Promise => { + ): Promise => { const validatedRequest = safeParse(ElicitRequestSchema, request); if (!validatedRequest.success) { // Type guard: if success is false, error is guaranteed to exist @@ -372,7 +373,7 @@ export class Client< : String(taskValidationResult.error); throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); } - return taskValidationResult.data; + return taskValidationResult.data as HandlerResultTypeMap[M]; } // For non-task requests, validate against ElicitResultSchema @@ -401,7 +402,7 @@ export class Client< } } - return validatedResult; + return validatedResult as HandlerResultTypeMap[M]; }; // Install the wrapped handler @@ -412,7 +413,7 @@ export class Client< const wrappedHandler = async ( request: RequestTypeMap[M], extra: RequestHandlerExtra - ): Promise => { + ): Promise => { const validatedRequest = safeParse(CreateMessageRequestSchema, request); if (!validatedRequest.success) { const errorMessage = @@ -434,7 +435,7 @@ export class Client< : String(taskValidationResult.error); throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); } - return taskValidationResult.data; + return taskValidationResult.data as HandlerResultTypeMap[M]; } // For non-task requests, validate against appropriate schema based on tools presence @@ -447,7 +448,7 @@ export class Client< throw new McpError(ErrorCode.InvalidParams, `Invalid sampling result: ${errorMessage}`); } - return validationResult.data; + return validationResult.data as HandlerResultTypeMap[M]; }; // Install the wrapped handler @@ -481,7 +482,6 @@ export class Client< clientInfo: this._clientInfo } }, - InitializeResultSchema, options ); @@ -686,19 +686,19 @@ export class Client< } async ping(options?: RequestOptions) { - return this.request({ method: 'ping' }, EmptyResultSchema, options); + return this.request({ method: 'ping' }, options); } async complete(params: CompleteRequest['params'], options?: RequestOptions) { - return this.request({ method: 'completion/complete', params }, CompleteResultSchema, options); + return this.request({ method: 'completion/complete', params }, options); } async setLoggingLevel(level: LoggingLevel, options?: RequestOptions) { - return this.request({ method: 'logging/setLevel', params: { level } }, EmptyResultSchema, options); + return this.request({ method: 'logging/setLevel', params: { level } }, options); } async getPrompt(params: GetPromptRequest['params'], options?: RequestOptions) { - return this.request({ method: 'prompts/get', params }, GetPromptResultSchema, options); + return this.request({ method: 'prompts/get', params }, options); } async listPrompts(params?: ListPromptsRequest['params'], options?: RequestOptions) { @@ -707,7 +707,7 @@ export class Client< console.debug('Client.listPrompts() called but server does not advertise prompts capability - returning empty list'); return { prompts: [] }; } - return this.request({ method: 'prompts/list', params }, ListPromptsResultSchema, options); + return this.request({ method: 'prompts/list', params }, options); } async listResources(params?: ListResourcesRequest['params'], options?: RequestOptions) { @@ -716,7 +716,7 @@ export class Client< console.debug('Client.listResources() called but server does not advertise resources capability - returning empty list'); return { resources: [] }; } - return this.request({ method: 'resources/list', params }, ListResourcesResultSchema, options); + return this.request({ method: 'resources/list', params }, options); } async listResourceTemplates(params?: ListResourceTemplatesRequest['params'], options?: RequestOptions) { @@ -727,19 +727,19 @@ export class Client< ); return { resourceTemplates: [] }; } - return this.request({ method: 'resources/templates/list', params }, ListResourceTemplatesResultSchema, options); + return this.request({ method: 'resources/templates/list', params }, options); } async readResource(params: ReadResourceRequest['params'], options?: RequestOptions) { - return this.request({ method: 'resources/read', params }, ReadResourceResultSchema, options); + return this.request({ method: 'resources/read', params }, options); } async subscribeResource(params: SubscribeRequest['params'], options?: RequestOptions) { - return this.request({ method: 'resources/subscribe', params }, EmptyResultSchema, options); + return this.request({ method: 'resources/subscribe', params }, options); } async unsubscribeResource(params: UnsubscribeRequest['params'], options?: RequestOptions) { - return this.request({ method: 'resources/unsubscribe', params }, EmptyResultSchema, options); + return this.request({ method: 'resources/unsubscribe', params }, options); } /** @@ -749,7 +749,6 @@ export class Client< */ async callTool( params: CallToolRequest['params'], - resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, options?: RequestOptions ) { // Guard: required-task tools need experimental API @@ -760,7 +759,7 @@ export class Client< ); } - const result = await this.request({ method: 'tools/call', params }, resultSchema, options); + const result = await this.request({ method: 'tools/call', params }, options); // Check if the tool has an outputSchema const validator = this.getToolOutputValidator(params.name); @@ -856,7 +855,7 @@ export class Client< console.debug('Client.listTools() called but server does not advertise tools capability - returning empty list'); return { tools: [] }; } - const result = await this.request({ method: 'tools/list', params }, ListToolsResultSchema, options); + const result = await this.request({ method: 'tools/list', params }, options); // Cache the tools and their output schemas for future validation this.cacheToolMetadata(result.tools); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index b7980fadb..01aa29805 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -8,6 +8,7 @@ export * from './shared/stdio.js'; export * from './shared/toolNameValidation.js'; export * from './shared/transport.js'; export * from './shared/uriTemplate.js'; +export * from './types/mapping.js'; export * from './types/types.js'; export * from './util/inMemory.js'; export * from './util/zodCompat.js'; diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 6909f17ff..cfe6eefcf 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -3,8 +3,11 @@ import { isTerminal } from '../experimental/tasks/interfaces.js'; import type { AuthInfo, CancelledNotification, + CancelTaskResult, ClientCapabilities, + CreateTaskResult, GetTaskPayloadRequest, + GetTaskPayloadResult, GetTaskRequest, GetTaskResult, JSONRPCErrorResponse, @@ -12,6 +15,7 @@ import type { JSONRPCRequest, JSONRPCResponse, JSONRPCResultResponse, + ListTasksResult, MessageExtraInfo, Notification, NotificationMethod, @@ -23,8 +27,6 @@ import type { RequestId, RequestInfo, RequestMeta, - RequestMethod, - RequestTypeMap, Result, ServerCapabilities, Task, @@ -36,7 +38,6 @@ import { CreateTaskResultSchema, ErrorCode, getNotificationSchema, - getRequestSchema, GetTaskResultSchema, isJSONRPCErrorResponse, isJSONRPCNotification, @@ -48,6 +49,8 @@ import { RELATED_TASK_META_KEY, TaskStatusNotificationSchema } from '../types/types.js'; +import type { HandlerResultTypeMap, RequestMethod, RequestTypeMap, ResultTypeMap } from '../types/mapping.js'; +import { getRequestSchema, getResultSchema } from '../types/mapping.js'; import type { AnySchema, SchemaOutput } from '../util/zodCompat.js'; import { safeParse } from '../util/zodCompat.js'; import { parseWithCompat } from '../util/zodJsonSchemaCompat.js'; @@ -287,7 +290,7 @@ export type RequestHandlerExtra(request: SendRequestT, resultSchema: U, options?: TaskRequestOptions) => Promise>; + sendRequest: (request: SendRequestT & { method: M }, options?: TaskRequestOptions) => Promise; /** * Closes the SSE stream for this request, triggering client reconnection. @@ -325,7 +328,7 @@ export abstract class Protocol) => Promise + (request: JSONRPCRequest, extra: RequestHandlerExtra) => Promise > = new Map(); private _requestHandlerAbortControllers: Map = new Map(); private _notificationHandlers: Map Promise> = new Map(); @@ -378,7 +381,7 @@ export abstract class Protocol ({}) as SendResultT + _request => ({}) ); // Install task handlers if TaskStore is provided @@ -395,11 +398,11 @@ export abstract class Protocol { - const handleTaskResult = async (): Promise => { + const handleTaskResult = async (): Promise => { const taskId = request.params.taskId; // Deliver queued messages @@ -476,7 +479,7 @@ export abstract class Protocol { + sendRequest: async (r, options?) => { // Include related-task metadata if this request is part of a task const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; if (relatedTaskId && !requestOptions.relatedTask) { @@ -741,7 +744,7 @@ export abstract class Protocol( - request: SendRequestT, - resultSchema: T, + protected async *requestStream( + request: SendRequestT & { method: M }, options?: RequestOptions - ): AsyncGenerator & Result>, void, void> { + ): AsyncGenerator, void, void> { const { task } = options ?? {}; // For non-task requests, just yield the result if (!task) { try { - const result = await this.request(request, resultSchema, options); - yield { type: 'result', result: result as SchemaOutput & Result }; + const result = await this.request(request, options); + yield { type: 'result', result: result as ResultTypeMap[M] & Result }; } catch (error) { yield { type: 'error', @@ -1000,10 +1002,11 @@ export abstract class Protocol & Result }; + const result = await this.getTaskResult({ taskId }, options); + yield { type: 'result', result: result as ResultTypeMap[M] & Result }; break; } @@ -1052,14 +1055,14 @@ export abstract class Protocol & Result }; + if (currentTask.status === 'input_required') { + const result = await this.getTaskResult({ taskId }, options); + yield { type: 'result', result: result as ResultTypeMap[M] & Result }; return; } // Wait before polling again - const pollInterval = task.pollInterval ?? this._options?.defaultTaskPollInterval ?? 1000; + const pollInterval = currentTask.pollInterval ?? this._options?.defaultTaskPollInterval ?? 1000; await new Promise(resolve => setTimeout(resolve, pollInterval)); // Check if cancelled @@ -1078,11 +1081,24 @@ export abstract class Protocol(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { + request( + request: SendRequestT & { method: M }, + options: RequestOptions & { task: TaskCreationParams } + ): Promise; + request( + request: SendRequestT & { method: M }, + options?: RequestOptions + ): Promise; + request( + request: SendRequestT & { method: M }, + options?: RequestOptions + ): Promise { + // Use CreateTaskResultSchema when task creation is requested, otherwise use the method's result schema + const resultSchema = options?.task ? CreateTaskResultSchema : getResultSchema(request.method); const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; // Send the request - return new Promise>((resolve, reject) => { + return new Promise((resolve, reject) => { const earlyReject = (error: unknown) => { reject(error); }; @@ -1181,7 +1197,7 @@ export abstract class Protocol); + resolve(parseResult.data as ResultTypeMap[M] | CreateTaskResult); } else { // Type guard: if success is false, error is guaranteed to exist reject(parseResult.error); @@ -1241,9 +1257,9 @@ export abstract class Protocol { + protected async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { // @ts-expect-error SendRequestT cannot directly contain GetTaskRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); + return this.request({ method: 'tasks/get', params }, options); } /** @@ -1251,13 +1267,12 @@ export abstract class Protocol( + protected async getTaskResult( params: GetTaskPayloadRequest['params'], - resultSchema: T, options?: RequestOptions - ): Promise> { + ): Promise { // @ts-expect-error SendRequestT cannot directly contain GetTaskPayloadRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/result', params }, resultSchema, options); + return this.request({ method: 'tasks/result', params }, options); } /** @@ -1265,9 +1280,9 @@ export abstract class Protocol> { + protected async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise { // @ts-expect-error SendRequestT cannot directly contain ListTasksRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); + return this.request({ method: 'tasks/list', params }, options); } /** @@ -1275,9 +1290,9 @@ export abstract class Protocol> { + protected async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise { // @ts-expect-error SendRequestT cannot directly contain CancelTaskRequest, but we ensure all type instantiations contain it anyways - return this.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); + return this.request({ method: 'tasks/cancel', params }, options); } /** @@ -1403,7 +1418,7 @@ export abstract class Protocol - ) => SendResultT | Promise + ) => HandlerResultTypeMap[M] | Promise ): void { this.assertRequestHandlerCapability(method); const schema = getRequestSchema(method); diff --git a/packages/core/src/types/mapping.ts b/packages/core/src/types/mapping.ts new file mode 100644 index 000000000..f8a29abef --- /dev/null +++ b/packages/core/src/types/mapping.ts @@ -0,0 +1,155 @@ +/** + * Single source of truth for request method → schema mappings. + * + * Each entry maps a method name to [RequestSchema, ResultSchema]. + * All type maps and runtime lookups are derived from these definitions. + */ + +import type { Infer } from 'zod/v4'; +import type { AnySchema } from '../util/zodCompat.js'; +import type { CreateTaskResult } from './types.js'; +import { + // Shared + PingRequestSchema, + EmptyResultSchema, + + // Client → Server + InitializeRequestSchema, + InitializeResultSchema, + CompleteRequestSchema, + CompleteResultSchema, + SetLevelRequestSchema, + GetPromptRequestSchema, + GetPromptResultSchema, + ListPromptsRequestSchema, + ListPromptsResultSchema, + ListResourcesRequestSchema, + ListResourcesResultSchema, + ListResourceTemplatesRequestSchema, + ListResourceTemplatesResultSchema, + ReadResourceRequestSchema, + ReadResourceResultSchema, + SubscribeRequestSchema, + UnsubscribeRequestSchema, + CallToolRequestSchema, + CallToolResultSchema, + ListToolsRequestSchema, + ListToolsResultSchema, + + // Server → Client + CreateMessageRequestSchema, + CreateMessageResultSchema, + ElicitRequestSchema, + ElicitResultSchema, + ListRootsRequestSchema, + ListRootsResultSchema, + + // Tasks (bidirectional) + GetTaskRequestSchema, + GetTaskResultSchema, + GetTaskPayloadRequestSchema, + GetTaskPayloadResultSchema, + ListTasksRequestSchema, + ListTasksResultSchema, + CancelTaskRequestSchema, + CancelTaskResultSchema, +} from './types.js'; + +type MethodSchema = readonly [request: AnySchema, result: AnySchema]; + +/** + * Client → Server request methods. + * These are requests that clients send and servers handle. + */ +export const clientMethodSchemas = { + 'initialize': [InitializeRequestSchema, InitializeResultSchema], + 'completion/complete': [CompleteRequestSchema, CompleteResultSchema], + 'logging/setLevel': [SetLevelRequestSchema, EmptyResultSchema], + 'prompts/get': [GetPromptRequestSchema, GetPromptResultSchema], + 'prompts/list': [ListPromptsRequestSchema, ListPromptsResultSchema], + 'resources/list': [ListResourcesRequestSchema, ListResourcesResultSchema], + 'resources/templates/list': [ListResourceTemplatesRequestSchema, ListResourceTemplatesResultSchema], + 'resources/read': [ReadResourceRequestSchema, ReadResourceResultSchema], + 'resources/subscribe': [SubscribeRequestSchema, EmptyResultSchema], + 'resources/unsubscribe': [UnsubscribeRequestSchema, EmptyResultSchema], + 'tools/call': [CallToolRequestSchema, CallToolResultSchema], + 'tools/list': [ListToolsRequestSchema, ListToolsResultSchema], +} as const satisfies Record; + +/** + * Server → Client request methods. + * These are requests that servers send and clients handle. + */ +export const serverMethodSchemas = { + 'sampling/createMessage': [CreateMessageRequestSchema, CreateMessageResultSchema], + 'elicitation/create': [ElicitRequestSchema, ElicitResultSchema], + 'roots/list': [ListRootsRequestSchema, ListRootsResultSchema], +} as const satisfies Record; + +/** + * Bidirectional request methods. + * These can be sent by either client or server. + */ +export const sharedMethodSchemas = { + 'ping': [PingRequestSchema, EmptyResultSchema], + 'tasks/get': [GetTaskRequestSchema, GetTaskResultSchema], + 'tasks/result': [GetTaskPayloadRequestSchema, GetTaskPayloadResultSchema], + 'tasks/list': [ListTasksRequestSchema, ListTasksResultSchema], + 'tasks/cancel': [CancelTaskRequestSchema, CancelTaskResultSchema], +} as const satisfies Record; + +/** + * Combined mapping of all request methods to their schemas. + */ +export const methodSchemas = { + ...clientMethodSchemas, + ...serverMethodSchemas, + ...sharedMethodSchemas, +} as const; + +// ============================================================================= +// Type Definitions (derived from the mappings) +// ============================================================================= + +/** All valid request method names */ +export type RequestMethod = keyof typeof methodSchemas; + +/** Methods that clients can send (client → server + shared) */ +export type ClientRequestMethod = keyof typeof clientMethodSchemas | keyof typeof sharedMethodSchemas; + +/** Methods that servers can send (server → client + shared) */ +export type ServerRequestMethod = keyof typeof serverMethodSchemas | keyof typeof sharedMethodSchemas; + +/** Maps method name → request type */ +export type RequestTypeMap = { + [M in RequestMethod]: Infer<(typeof methodSchemas)[M][0]> +}; + +/** Maps method name → result type */ +export type ResultTypeMap = { + [M in RequestMethod]: Infer<(typeof methodSchemas)[M][1]> +}; + +/** Methods that support task creation (can return CreateTaskResult) */ +export type TaskAugmentedMethod = 'tools/call' | 'sampling/createMessage' | 'elicitation/create'; + +/** Maps method name → handler result type (includes CreateTaskResult for task-augmented methods) */ +export type HandlerResultTypeMap = { + [M in RequestMethod]: M extends TaskAugmentedMethod + ? ResultTypeMap[M] | CreateTaskResult + : ResultTypeMap[M] +}; + +// ============================================================================= +// Runtime Lookups (direct property access - very fast) +// ============================================================================= + +/** Get the request schema for a method */ +export function getRequestSchema(method: M) { + return methodSchemas[method][0]; +} + +/** Get the result schema for a method */ +export function getResultSchema(method: M) { + return methodSchemas[method][1]; +} diff --git a/packages/core/src/types/types.ts b/packages/core/src/types/types.ts index 9d659693f..867976f50 100644 --- a/packages/core/src/types/types.ts +++ b/packages/core/src/types/types.ts @@ -2610,16 +2610,13 @@ export type ServerResult = Infer; type MethodToTypeMap = { [T in U as T extends { method: infer M extends string } ? M : never]: T; }; -export type RequestMethod = ClientRequest['method'] | ServerRequest['method']; export type NotificationMethod = ClientNotification['method'] | ServerNotification['method']; -export type RequestTypeMap = MethodToTypeMap; export type NotificationTypeMap = MethodToTypeMap; -/* Runtime schema lookup */ -type RequestSchemaType = (typeof ClientRequestSchema.options)[number] | (typeof ServerRequestSchema.options)[number]; +/* Runtime notification schema lookup */ type NotificationSchemaType = (typeof ClientNotificationSchema.options)[number] | (typeof ServerNotificationSchema.options)[number]; -function buildSchemaMap(schemas: readonly T[]): Record { +function buildNotificationSchemaMap(schemas: readonly T[]): Record { const map: Record = {}; for (const schema of schemas) { const method = schema.shape.method.value; @@ -2628,18 +2625,11 @@ function buildSchemaMap(sche return map; } -const requestSchemas = buildSchemaMap([...ClientRequestSchema.options, ...ServerRequestSchema.options] as const) as Record< - RequestMethod, - RequestSchemaType ->; -const notificationSchemas = buildSchemaMap([...ClientNotificationSchema.options, ...ServerNotificationSchema.options] as const) as Record< +const notificationSchemas = buildNotificationSchemaMap([...ClientNotificationSchema.options, ...ServerNotificationSchema.options] as const) as Record< NotificationMethod, NotificationSchemaType >; -export function getRequestSchema(method: M) { - return requestSchemas[method]; -} export function getNotificationSchema(method: M) { return notificationSchemas[method]; } diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index b5f6f40cb..b5b147e53 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -42,7 +42,7 @@ interface TestProtocol { // Protected task methods (exposed for testing) listTasks: (params?: { cursor?: string }) => Promise<{ tasks: Task[]; nextCursor?: string }>; cancelTask: (params: { taskId: string }) => Promise; - requestStream: (request: Request, schema: ZodType, options?: unknown) => AsyncGenerator>; + requestStream: (request: Request, options?: unknown) => AsyncGenerator>; } // Mock Transport class @@ -166,12 +166,8 @@ describe('protocol tests', () => { test('should throw a timeout error if the request exceeds the timeout', async () => { await protocol.connect(transport); - const request = { method: 'example', params: {} }; try { - const mockSchema: ZodType<{ result: string }> = z.object({ - result: z.string() - }); - await protocol.request(request, mockSchema, { + await protocol.request({ method: 'ping', params: {} }, { timeout: 0 }); } catch (error) { @@ -210,7 +206,7 @@ describe('protocol tests', () => { test('should preserve existing _meta when adding progressToken', async () => { await protocol.connect(transport); const request = { - method: 'example', + method: 'ping', params: { data: 'test', _meta: { @@ -226,7 +222,7 @@ describe('protocol tests', () => { // Start request but don't await - we're testing the sent message void protocol - .request(request, mockSchema, { + .request(request as any, { onprogress: onProgressMock }) .catch(() => { @@ -235,7 +231,7 @@ describe('protocol tests', () => { expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ - method: 'example', + method: 'ping', params: { data: 'test', _meta: { @@ -254,7 +250,7 @@ describe('protocol tests', () => { test('should create _meta with progressToken when no _meta exists', async () => { await protocol.connect(transport); const request = { - method: 'example', + method: 'ping', params: { data: 'test' } @@ -266,7 +262,7 @@ describe('protocol tests', () => { // Start request but don't await - we're testing the sent message void protocol - .request(request, mockSchema, { + .request(request as any, { onprogress: onProgressMock }) .catch(() => { @@ -275,7 +271,7 @@ describe('protocol tests', () => { expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ - method: 'example', + method: 'ping', params: { data: 'test', _meta: { @@ -292,7 +288,7 @@ describe('protocol tests', () => { test('should not modify _meta when onprogress is not provided', async () => { await protocol.connect(transport); const request = { - method: 'example', + method: 'ping', params: { data: 'test', _meta: { @@ -305,13 +301,13 @@ describe('protocol tests', () => { }); // Start request but don't await - we're testing the sent message - void protocol.request(request, mockSchema).catch(() => { + void protocol.request(request as any).catch(() => { // May not complete, ignore error }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ - method: 'example', + method: 'ping', params: { data: 'test', _meta: { @@ -328,7 +324,7 @@ describe('protocol tests', () => { test('should handle params being undefined with onprogress', async () => { await protocol.connect(transport); const request = { - method: 'example' + method: 'ping' }; const mockSchema: ZodType<{ result: string }> = z.object({ result: z.string() @@ -337,7 +333,7 @@ describe('protocol tests', () => { // Start request but don't await - we're testing the sent message void protocol - .request(request, mockSchema, { + .request(request as any, { onprogress: onProgressMock }) .catch(() => { @@ -346,7 +342,7 @@ describe('protocol tests', () => { expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ - method: 'example', + method: 'ping', params: { _meta: { progressToken: expect.any(Number) @@ -370,12 +366,12 @@ describe('protocol tests', () => { test('should not reset timeout when resetTimeoutOnProgress is false', async () => { await protocol.connect(transport); - const request = { method: 'example', params: {} }; + const request = { method: 'ping', params: {} }; const mockSchema: ZodType<{ result: string }> = z.object({ result: z.string() }); const onProgressMock = vi.fn(); - const requestPromise = protocol.request(request, mockSchema, { + const requestPromise = protocol.request(request as any, { timeout: 1000, resetTimeoutOnProgress: false, onprogress: onProgressMock @@ -408,12 +404,12 @@ describe('protocol tests', () => { test('should reset timeout when progress notification is received', async () => { await protocol.connect(transport); - const request = { method: 'example', params: {} }; + const request = { method: 'ping', params: {} }; const mockSchema: ZodType<{ result: string }> = z.object({ result: z.string() }); const onProgressMock = vi.fn(); - const requestPromise = protocol.request(request, mockSchema, { + const requestPromise = protocol.request(request as any, { timeout: 1000, resetTimeoutOnProgress: true, onprogress: onProgressMock @@ -440,21 +436,21 @@ describe('protocol tests', () => { transport.onmessage({ jsonrpc: '2.0', id: 0, - result: { result: 'success' } + result: {} }); } await Promise.resolve(); - await expect(requestPromise).resolves.toEqual({ result: 'success' }); + await expect(requestPromise).resolves.toEqual({}); }); test('should respect maxTotalTimeout', async () => { await protocol.connect(transport); - const request = { method: 'example', params: {} }; + const request = { method: 'ping', params: {} }; const mockSchema: ZodType<{ result: string }> = z.object({ result: z.string() }); const onProgressMock = vi.fn(); - const requestPromise = protocol.request(request, mockSchema, { + const requestPromise = protocol.request(request as any, { timeout: 1000, maxTotalTimeout: 150, resetTimeoutOnProgress: true, @@ -497,11 +493,11 @@ describe('protocol tests', () => { test('should timeout if no progress received within timeout period', async () => { await protocol.connect(transport); - const request = { method: 'example', params: {} }; + const request = { method: 'ping', params: {} }; const mockSchema: ZodType<{ result: string }> = z.object({ result: z.string() }); - const requestPromise = protocol.request(request, mockSchema, { + const requestPromise = protocol.request(request as any, { timeout: 100, resetTimeoutOnProgress: true }); @@ -511,12 +507,12 @@ describe('protocol tests', () => { test('should handle multiple progress notifications correctly', async () => { await protocol.connect(transport); - const request = { method: 'example', params: {} }; + const request = { method: 'ping', params: {} }; const mockSchema: ZodType<{ result: string }> = z.object({ result: z.string() }); const onProgressMock = vi.fn(); - const requestPromise = protocol.request(request, mockSchema, { + const requestPromise = protocol.request(request as any, { timeout: 1000, resetTimeoutOnProgress: true, onprogress: onProgressMock @@ -546,22 +542,22 @@ describe('protocol tests', () => { transport.onmessage({ jsonrpc: '2.0', id: 0, - result: { result: 'success' } + result: {} }); } await Promise.resolve(); - await expect(requestPromise).resolves.toEqual({ result: 'success' }); + await expect(requestPromise).resolves.toEqual({}); }); test('should handle progress notifications with message field', async () => { await protocol.connect(transport); - const request = { method: 'example', params: {} }; + const request = { method: 'ping', params: {} }; const mockSchema: ZodType<{ result: string }> = z.object({ result: z.string() }); const onProgressMock = vi.fn(); - const requestPromise = protocol.request(request, mockSchema, { + const requestPromise = protocol.request(request as any, { timeout: 1000, onprogress: onProgressMock }); @@ -614,11 +610,11 @@ describe('protocol tests', () => { transport.onmessage({ jsonrpc: '2.0', id: 0, - result: { result: 'success' } + result: {} }); } await Promise.resolve(); - await expect(requestPromise).resolves.toEqual({ result: 'success' }); + await expect(requestPromise).resolves.toEqual({}); }); }); @@ -1032,7 +1028,7 @@ describe('Task-based execution', () => { }); void protocol - .request(request, resultSchema, { + .request(request as any, { task: { ttl: 30000, pollInterval: 1000 @@ -1075,7 +1071,7 @@ describe('Task-based execution', () => { }); void protocol - .request(request, resultSchema, { + .request(request as any, { task: { ttl: 60000 } @@ -1112,7 +1108,7 @@ describe('Task-based execution', () => { content: z.array(z.object({ type: z.literal('text'), text: z.string() })) }); - const resultPromise = protocol.request(request, resultSchema, { + const resultPromise = protocol.request(request as any, { task: { ttl: 30000 } @@ -1127,16 +1123,9 @@ describe('Task-based execution', () => { it('should inject relatedTask metadata into _meta field', async () => { await protocol.connect(transport); - const request = { - method: 'notifications/message', - params: { data: 'test' } - }; - - const resultSchema = z.object({}); - // Start the request (don't await completion, just let it send) void protocol - .request(request, resultSchema, { + .request({ method: 'tools/call', params: { name: 'test', arguments: {} } }, { relatedTask: { taskId: 'parent-task-123' } @@ -1202,7 +1191,7 @@ describe('Task-based execution', () => { // Start the request (don't await completion, just let it send) void protocol - .request(request, resultSchema, { + .request(request as any, { task: { ttl: 60000, pollInterval: 1000 @@ -2403,7 +2392,7 @@ describe('Progress notification support for tasks', () => { // Start a task-augmented request with progress callback void protocol - .request(request, resultSchema, { + .request(request as any, { task: { ttl: 60000 }, onprogress: progressCallback }) @@ -2432,7 +2421,8 @@ describe('Progress notification support for tasks', () => { taskId, status: 'working', ttl: 60000, - createdAt: new Date().toISOString() + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString() } } }); @@ -2513,7 +2503,7 @@ describe('Progress notification support for tasks', () => { // Start a task-augmented request with progress callback void protocol - .request(request, resultSchema, { + .request(request as any, { task: { ttl: 60000 }, onprogress: progressCallback }) @@ -2629,7 +2619,7 @@ describe('Progress notification support for tasks', () => { }) }); - void protocol.request(request, resultSchema, { + void protocol.request(request as any, { task: { ttl: 60000 }, onprogress: progressCallback }); @@ -2649,7 +2639,8 @@ describe('Progress notification support for tasks', () => { taskId, status: 'working', ttl: 60000, - createdAt: new Date().toISOString() + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString() } } }); @@ -2727,7 +2718,7 @@ describe('Progress notification support for tasks', () => { }) }); - void protocol.request(request, resultSchema, { + void protocol.request(request as any, { task: { ttl: 60000 }, onprogress: progressCallback }); @@ -2747,7 +2738,8 @@ describe('Progress notification support for tasks', () => { taskId, status: 'working', ttl: 60000, - createdAt: new Date().toISOString() + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString() } } }); @@ -2822,7 +2814,7 @@ describe('Progress notification support for tasks', () => { }) }); - void protocol.request(request, resultSchema, { + void protocol.request(request as any, { task: { ttl: 60000 }, onprogress: progressCallback }); @@ -2842,7 +2834,8 @@ describe('Progress notification support for tasks', () => { taskId, status: 'working', ttl: 60000, - createdAt: new Date().toISOString() + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString() } } }); @@ -2893,7 +2886,7 @@ describe('Progress notification support for tasks', () => { const onProgressMock = vi.fn(); - void protocol.request(request, resultSchema, { + void protocol.request(request as any, { task: { ttl: 60000 }, @@ -2918,7 +2911,7 @@ describe('Progress notification support for tasks', () => { const onProgressMock = vi.fn(); - void protocol.request(request, resultSchema, { + void protocol.request(request as any, { task: { ttl: 30000 }, @@ -2968,7 +2961,7 @@ describe('Progress notification support for tasks', () => { const onProgressMock = vi.fn(); - void protocol.request(request, resultSchema, { + void protocol.request(request as any, { task: { ttl: 30000 }, @@ -2988,7 +2981,8 @@ describe('Progress notification support for tasks', () => { taskId: 'task-123', status: 'working', ttl: 30000, - createdAt: new Date().toISOString() + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString() } } }); @@ -3268,8 +3262,7 @@ describe('Message interception for task-related requests', () => { { method: 'ping', params: {} - }, - z.object({}), + } as any, { relatedTask: { taskId: task.taskId } } @@ -3317,8 +3310,7 @@ describe('Message interception for task-related requests', () => { { method: 'ping', params: {} - }, - z.object({}) + } as any ); // Verify queue exists (but we don't track size in the new API) @@ -3360,7 +3352,6 @@ describe('Message interception for task-related requests', () => { method: 'ping', params: {} }, - z.object({}), { relatedTask: { taskId: task.taskId } } @@ -3413,7 +3404,6 @@ describe('Message interception for task-related requests', () => { method: 'ping', params: {} }, - z.object({ message: z.string() }), { relatedTask: { taskId: task.taskId } } @@ -3429,7 +3419,7 @@ describe('Message interception for task-related requests', () => { message: { jsonrpc: '2.0', id: requestId, - result: { message: 'pong' } + result: {} }, timestamp: Date.now() }); @@ -3453,7 +3443,7 @@ describe('Message interception for task-related requests', () => { // Verify the response was routed correctly const result = await requestPromise; - expect(result).toEqual({ message: 'pong' }); + expect(result).toEqual({}); }); it('should log error when resolver is missing for side-channeled request', async () => { @@ -3483,7 +3473,6 @@ describe('Message interception for task-related requests', () => { method: 'ping', params: {} }, - z.object({ message: z.string() }), { relatedTask: { taskId: task.taskId } } @@ -3503,7 +3492,7 @@ describe('Message interception for task-related requests', () => { message: { jsonrpc: '2.0', id: requestId, - result: { message: 'pong' } + result: {} }, timestamp: Date.now() }); @@ -3557,7 +3546,6 @@ describe('Message interception for task-related requests', () => { method: 'ping', params: {} }, - z.object({}), { relatedTask: { taskId: task.taskId } } @@ -3575,7 +3563,6 @@ describe('Message interception for task-related requests', () => { method: 'ping', params: {} }, - z.object({}), { relatedTask: { taskId: task.taskId } } @@ -3634,15 +3621,12 @@ describe('Message Interception', () => { it('should queue requests with relatedTask metadata', async () => { await protocol.connect(transport); - const mockSchema = z.object({ result: z.string() }); - // Send a request with relatedTask metadata const requestPromise = protocol.request( { - method: 'test/request', - params: { data: 'test' } + method: 'ping', + params: {} }, - mockSchema, { relatedTask: { taskId: 'task-456' @@ -3656,7 +3640,7 @@ describe('Message Interception', () => { const queuedMessage = await queue!.dequeue('task-456'); assertQueuedRequest(queuedMessage); - expect(queuedMessage.message.method).toBe('test/request'); + expect(queuedMessage.message.method).toBe('ping'); // Verify resolver is stored in _requestResolvers map (not in the message) const requestId = queuedMessage.message.id as RequestId; @@ -3667,7 +3651,7 @@ describe('Message Interception', () => { transport.onmessage?.({ jsonrpc: '2.0', id: requestId, - result: { result: 'success' } + result: {} }); await requestPromise; }); @@ -3836,16 +3820,14 @@ describe('Message Interception', () => { it('should not queue requests without relatedTask metadata', async () => { await protocol.connect(transport); - const mockSchema = z.object({ result: z.string() }); const sendSpy = vi.spyOn(transport, 'send'); // Send a request without relatedTask metadata const requestPromise = protocol.request( { - method: 'test/request', - params: { data: 'test' } - }, - mockSchema + method: 'ping', + params: {} + } ); // Access the private _taskMessageQueue to verify no messages were queued @@ -3859,7 +3841,7 @@ describe('Message Interception', () => { transport.onmessage?.({ jsonrpc: '2.0', id: requestId, - result: { result: 'success' } + result: {} }); await requestPromise; }); @@ -3898,15 +3880,13 @@ describe('Message Interception', () => { await protocol.connect(transport); const taskId = 'extracted-task-999'; - const mockSchema = z.object({ result: z.string() }); // Send a request with relatedTask metadata const requestPromise = protocol.request( { - method: 'test/request', - params: { data: 'test' } + method: 'ping', + params: {} }, - mockSchema, { relatedTask: { taskId: taskId @@ -3921,11 +3901,11 @@ describe('Message Interception', () => { // Clean up the pending request const queuedMessage = await queue!.dequeue(taskId); assertQueuedRequest(queuedMessage); - expect(queuedMessage.message.method).toBe('test/request'); + expect(queuedMessage.message.method).toBe('ping'); transport.onmessage?.({ jsonrpc: '2.0', id: queuedMessage.message.id, - result: { result: 'success' } + result: {} }); await requestPromise; }); @@ -4043,14 +4023,12 @@ describe('Message Interception', () => { await protocol.connect(transport); const relatedTask = { taskId: 'task-meta-456' }; - const mockSchema = z.object({ result: z.string() }); const requestPromise = protocol.request( { - method: 'test/request', - params: { data: 'test' } + method: 'ping', + params: {} }, - mockSchema, { relatedTask } ); @@ -4067,7 +4045,7 @@ describe('Message Interception', () => { transport.onmessage?.({ jsonrpc: '2.0', id: (queuedMessage!.message as JSONRPCRequest).id, - result: { result: 'success' } + result: {} }); await requestPromise; }); @@ -4235,7 +4213,7 @@ describe('Queue lifecycle management', () => { // Queue a request (catch rejection to avoid unhandled promise rejection) const requestPromise = protocol - .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { + .request({ method: 'ping', params: {} }, { relatedTask: { taskId } }) .catch(err => err); @@ -4308,7 +4286,7 @@ describe('Queue lifecycle management', () => { // Queue a request (catch the rejection to avoid unhandled promise rejection) const requestPromise = protocol - .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { + .request({ method: 'ping', params: {} }, { relatedTask: { taskId } }) .catch(err => err); @@ -4341,19 +4319,19 @@ describe('Queue lifecycle management', () => { // Queue multiple requests (catch rejections to avoid unhandled promise rejections) const request1Promise = protocol - .request({ method: 'test/request1', params: { data: 'test1' } }, z.object({ result: z.string() }), { + .request({ method: 'ping', params: {} }, { relatedTask: { taskId } }) .catch(err => err); const request2Promise = protocol - .request({ method: 'test/request2', params: { data: 'test2' } }, z.object({ result: z.string() }), { + .request({ method: 'ping', params: {} }, { relatedTask: { taskId } }) .catch(err => err); const request3Promise = protocol - .request({ method: 'test/request3', params: { data: 'test3' } }, z.object({ result: z.string() }), { + .request({ method: 'ping', params: {} }, { relatedTask: { taskId } }) .catch(err => err); @@ -4391,7 +4369,7 @@ describe('Queue lifecycle management', () => { // Queue a request (catch rejection to avoid unhandled promise rejection) const requestPromise = protocol - .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { + .request({ method: 'ping', params: {} }, { relatedTask: { taskId } }) .catch(err => err); @@ -4441,8 +4419,7 @@ describe('requestStream() method', () => { const streamPromise = (async () => { const messages = []; const stream = (protocol as unknown as TestProtocol).requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema + { method: 'tools/call', params: { name: 'test', arguments: {} } } ); for await (const message of stream) { messages.push(message); @@ -4484,8 +4461,7 @@ describe('requestStream() method', () => { const streamPromise = (async () => { const messages = []; const stream = (protocol as unknown as TestProtocol).requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema + { method: 'tools/call', params: { name: 'test', arguments: {} } } ); for await (const message of stream) { messages.push(message); @@ -4535,7 +4511,6 @@ describe('requestStream() method', () => { const messages = []; const stream = (protocol as unknown as TestProtocol).requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema, { signal: abortController.signal } @@ -4566,8 +4541,7 @@ describe('requestStream() method', () => { const messagesPromise = toArrayAsync( (protocol as unknown as TestProtocol).requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema + { method: 'tools/call', params: { name: 'test', arguments: {} } } ) ); @@ -4609,7 +4583,6 @@ describe('requestStream() method', () => { const messagesPromise = toArrayAsync( (protocol as unknown as TestProtocol).requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema, { timeout: 100 } @@ -4651,7 +4624,6 @@ describe('requestStream() method', () => { const messages = await toArrayAsync( (protocol as unknown as TestProtocol).requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema, { signal: abortController.signal } @@ -4679,8 +4651,7 @@ describe('requestStream() method', () => { const messagesPromise = toArrayAsync( (protocol as unknown as TestProtocol).requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema + { method: 'tools/call', params: { name: 'test', arguments: {} } } ) ); @@ -4719,57 +4690,72 @@ describe('requestStream() method', () => { test('should yield error as terminal message for task failure', async () => { const transport = new MockTransport(); - const mockTaskStore = createMockTaskStore(); const protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ defaultTaskPollInterval: 10 }); // Short poll interval for test await protocol.connect(transport); - const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol).requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema - ) - ); - - // Simulate task creation response - await new Promise(resolve => setTimeout(resolve, 10)); const taskId = 'test-task-123'; - transport.onmessage?.({ - jsonrpc: '2.0', - id: 0, - result: { - _meta: { - task: { - taskId, - status: 'working', - createdAt: new Date().toISOString(), - pollInterval: 100 - } - } + const now = new Date().toISOString(); + let getTaskCount = 0; + + // Intercept outgoing requests and respond appropriately + const originalSend = transport.send.bind(transport); + transport.send = vi.fn(async (message: JSONRPCMessage) => { + await originalSend(message); + const msg = message as { id: number; method: string }; + + // Respond to the initial tools/call request with CreateTaskResult + if (msg.method === 'tools/call') { + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: msg.id, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: now, + lastUpdatedAt: now + } + } + }); + }, 5); } - }); - - // Wait for task creation to be processed - await new Promise(resolve => setTimeout(resolve, 20)); - // Update task to failed status - const failedTask = { - taskId, - status: 'failed' as const, - createdAt: new Date().toISOString(), - pollInterval: 100, - ttl: null, - statusMessage: 'Task failed' - }; - mockTaskStore.getTask.mockResolvedValue(failedTask); + // Respond to tasks/get polling requests + // GetTaskResult is TaskSchema merged with ResultSchema - fields are at top level + if (msg.method === 'tasks/get') { + getTaskCount++; + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: msg.id, + result: { + taskId, + // Return failed status on first poll + status: 'failed', + ttl: 60000, + createdAt: now, + lastUpdatedAt: now, + statusMessage: 'Task failed' + } + }); + }, 5); + } + }); - // Collect messages - const messages = await messagesPromise; + const messages = await toArrayAsync( + (protocol as unknown as TestProtocol).requestStream( + { method: 'tools/call', params: { name: 'test', arguments: {} } }, + { task: { ttl: 60000 } } + ) + ); // Verify error is terminal and last message expect(messages.length).toBeGreaterThan(0); @@ -4794,8 +4780,7 @@ describe('requestStream() method', () => { const messages = await toArrayAsync( (protocol as unknown as TestProtocol).requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema + { method: 'tools/call', params: { name: 'test', arguments: {} } } ) ); @@ -4819,8 +4804,7 @@ describe('requestStream() method', () => { const messagesPromise = toArrayAsync( (protocol as unknown as TestProtocol).requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema + { method: 'tools/call', params: { name: 'test', arguments: {} } } ) ); diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index ca43272cf..fdf70ede8 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -24,6 +24,7 @@ import type { RequestMethod, RequestOptions, RequestTypeMap, + HandlerResultTypeMap, ResourceUpdatedNotification, Result, ServerCapabilities, @@ -220,13 +221,13 @@ export class Server< handler: ( request: RequestTypeMap[M], extra: RequestHandlerExtra - ) => ServerResult | ResultT | Promise + ) => HandlerResultTypeMap[M] | Promise ): void { if (method === 'tools/call') { const wrappedHandler = async ( request: RequestTypeMap[M], extra: RequestHandlerExtra - ): Promise => { + ): Promise => { const validatedRequest = safeParse(CallToolRequestSchema, request); if (!validatedRequest.success) { const errorMessage = @@ -248,7 +249,7 @@ export class Server< : String(taskValidationResult.error); throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); } - return taskValidationResult.data; + return taskValidationResult.data as HandlerResultTypeMap[M]; } // For non-task requests, validate against CallToolResultSchema @@ -259,7 +260,7 @@ export class Server< throw new McpError(ErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); } - return validationResult.data; + return validationResult.data as HandlerResultTypeMap[M]; }; // Install the wrapped handler @@ -464,7 +465,7 @@ export class Server< } async ping() { - return this.request({ method: 'ping' }, EmptyResultSchema); + return this.request({ method: 'ping' }); } /** @@ -533,11 +534,9 @@ export class Server< } } - // Use different schemas based on whether tools are provided - if (params.tools) { - return this.request({ method: 'sampling/createMessage', params }, CreateMessageResultWithToolsSchema, options); - } - return this.request({ method: 'sampling/createMessage', params }, CreateMessageResultSchema, options); + // The result type varies based on whether tools are provided, but the mapping + // uses a single schema. The overloads handle type narrowing. + return this.request({ method: 'sampling/createMessage', params }, options) as Promise; } /** @@ -557,7 +556,7 @@ export class Server< } const urlParams = params as ElicitRequestURLParams; - return this.request({ method: 'elicitation/create', params: urlParams }, ElicitResultSchema, options); + return this.request({ method: 'elicitation/create', params: urlParams }, options); } case 'form': { if (!this._clientCapabilities?.elicitation?.form) { @@ -567,7 +566,7 @@ export class Server< const formParams: ElicitRequestFormParams = params.mode === 'form' ? (params as ElicitRequestFormParams) : { ...(params as ElicitRequestFormParams), mode: 'form' }; - const result = await this.request({ method: 'elicitation/create', params: formParams }, ElicitResultSchema, options); + const result = await this.request({ method: 'elicitation/create', params: formParams }, options); if (result.action === 'accept' && result.content && formParams.requestedSchema) { try { @@ -621,7 +620,7 @@ export class Server< } async listRoots(params?: ListRootsRequest['params'], options?: RequestOptions) { - return this.request({ method: 'roots/list', params }, ListRootsResultSchema, options); + return this.request({ method: 'roots/list', params }, options); } /** diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index eae10f087..0f70be719 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -808,7 +808,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + const result = await client.request({ method: 'tools/list' }); expect(result.tools).toHaveLength(1); expect(result.tools[0]!.name).toBe('test'); @@ -855,7 +855,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + const result = await client.request({ method: 'tools/list' }); expect(result.tools).toHaveLength(1); expect(result.tools[0]!.name).toBe('test'); @@ -903,7 +903,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + const result = await client.request({ method: 'tools/list' }); expect(result.tools).toHaveLength(1); expect(result.tools[0]!.name).toBe('test'); @@ -1708,7 +1708,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + const result = await client.request({ method: 'tools/list' }); expect(result.tools).toHaveLength(1); expect(result.tools[0]!.name).toBe('test-with-meta'); @@ -1744,7 +1744,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); - const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + const result = await client.request({ method: 'tools/list' }); expect(result.tools).toHaveLength(1); expect(result.tools[0]!.name).toBe('test-without-meta'); @@ -1809,7 +1809,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + const result = await client.request({ method: 'tools/list' }); expect(result.tools).toHaveLength(1); expect(result.tools[0]!.name).toBe('task-tool'); @@ -1878,7 +1878,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + const result = await client.request({ method: 'tools/list' }); expect(result.tools).toHaveLength(1); expect(result.tools[0]!.name).toBe('optional-task-tool'); @@ -2202,7 +2202,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); // Verify both resources are registered - let result = await client.request({ method: 'resources/list' }, ListResourcesResultSchema); + let result = await client.request({ method: 'resources/list' }); expect(result.resources).toHaveLength(2); @@ -2218,7 +2218,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(notifications).toMatchObject([{ method: 'notifications/resources/list_changed' }]); // Verify the resource was removed - result = await client.request({ method: 'resources/list' }, ListResourcesResultSchema); + result = await client.request({ method: 'resources/list' }); expect(result.resources).toHaveLength(1); expect(result.resources[0]!.uri).toBe('test://resource2'); @@ -2261,7 +2261,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); // Verify template is registered - const result = await client.request({ method: 'resources/templates/list' }, ListResourceTemplatesResultSchema); + const result = await client.request({ method: 'resources/templates/list' }); expect(result.resourceTemplates).toHaveLength(1); expect(notifications).toHaveLength(0); @@ -2276,7 +2276,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(notifications).toMatchObject([{ method: 'notifications/resources/list_changed' }]); // Verify the template was removed - const result2 = await client.request({ method: 'resources/templates/list' }, ListResourceTemplatesResultSchema); + const result2 = await client.request({ method: 'resources/templates/list' }); expect(result2.resourceTemplates).toHaveLength(0); }); @@ -3221,7 +3221,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); // Verify both prompts are registered - let result = await client.request({ method: 'prompts/list' }, ListPromptsResultSchema); + let result = await client.request({ method: 'prompts/list' }); expect(result.prompts).toHaveLength(2); expect(result.prompts.map(p => p.name).toSorted()).toEqual(['prompt1', 'prompt2']); @@ -3238,7 +3238,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { expect(notifications).toMatchObject([{ method: 'notifications/prompts/list_changed' }]); // Verify the prompt was removed - result = await client.request({ method: 'prompts/list' }, ListPromptsResultSchema); + result = await client.request({ method: 'prompts/list' }); expect(result.prompts).toHaveLength(1); expect(result.prompts[0]!.name).toBe('prompt2'); @@ -4076,7 +4076,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + const result = await client.request({ method: 'tools/list' }); expect(result.tools).toHaveLength(4); @@ -5366,7 +5366,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + const result = await client.request({ method: 'tools/list' }); expect(result.tools).toHaveLength(4); From cc7a95ad535ad370fdd4cc92f3aac0bbc78bc597 Mon Sep 17 00:00:00 2001 From: Matt Carey Date: Tue, 3 Feb 2026 13:21:28 +0000 Subject: [PATCH 3/3] fix: use HandlerResultTypeMap in setRequestHandler and remove schema args - Update Server/Client setRequestHandler to use HandlerResultTypeMap[M] - Remove schema arguments from all this.request() calls - Clean up unused schema imports --- packages/client/src/client/client.ts | 11 ----------- packages/server/src/server/server.ts | 9 +-------- 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 00d6523af..106e3f767 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -4,7 +4,6 @@ import type { ClientNotification, ClientRequest, ClientResult, - CompatibilityCallToolResultSchema, CompleteRequest, GetPromptRequest, Implementation, @@ -39,28 +38,18 @@ import { AjvJsonSchemaValidator, assertClientRequestTaskCapability, assertToolsCallTaskCapability, - CallToolResultSchema, - CompleteResultSchema, CreateMessageRequestSchema, CreateMessageResultSchema, CreateMessageResultWithToolsSchema, CreateTaskResultSchema, ElicitRequestSchema, ElicitResultSchema, - EmptyResultSchema, ErrorCode, - GetPromptResultSchema, - InitializeResultSchema, LATEST_PROTOCOL_VERSION, ListChangedOptionsBaseSchema, - ListPromptsResultSchema, - ListResourcesResultSchema, - ListResourceTemplatesResultSchema, - ListToolsResultSchema, McpError, mergeCapabilities, Protocol, - ReadResourceResultSchema, safeParse, SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index fdf70ede8..16552448c 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -40,14 +40,9 @@ import { assertToolsCallTaskCapability, CallToolRequestSchema, CallToolResultSchema, - CreateMessageResultSchema, - CreateMessageResultWithToolsSchema, CreateTaskResultSchema, - ElicitResultSchema, - EmptyResultSchema, ErrorCode, LATEST_PROTOCOL_VERSION, - ListRootsResultSchema, LoggingLevelSchema, McpError, mergeCapabilities, @@ -534,9 +529,7 @@ export class Server< } } - // The result type varies based on whether tools are provided, but the mapping - // uses a single schema. The overloads handle type narrowing. - return this.request({ method: 'sampling/createMessage', params }, options) as Promise; + return this.request({ method: 'sampling/createMessage', params }, options); } /**