From c76a53ba0cfbbce2c3774900c825cb2baf31cf38 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Tue, 3 Feb 2026 20:19:23 +0000 Subject: [PATCH] refactor: extract task orchestration from Protocol into TaskManager Extract all task-related logic from Protocol into a dedicated TaskManager class with consolidated lifecycle API and NullTaskManager (Null Object pattern). Protocol delegates to TaskManager via 4 lifecycle methods: - processInboundRequest: context extraction + function wrapping - processOutboundRequest: metadata augmentation + queue routing - processInboundResponse: side-channel handling + progress preservation - processOutboundNotification: task-related notification routing NullTaskManager provides no-op defaults, eliminating all conditional checks in Protocol. --- .changeset/extract-task-manager.md | 14 + examples/server/src/simpleStreamableHttp.ts | 3 +- .../client/src/experimental/tasks/client.ts | 47 +- .../core/src/experimental/tasks/interfaces.ts | 3 +- packages/core/src/index.ts | 1 + packages/core/src/shared/protocol.ts | 927 ++---------------- packages/core/src/shared/taskManager.ts | 855 ++++++++++++++++ packages/core/test/shared/protocol.test.ts | 381 +++---- .../shared/protocolTransportHandling.test.ts | 2 +- .../server/src/experimental/tasks/server.ts | 42 +- test/integration/test/client/client.test.ts | 26 +- test/integration/test/helpers/mcp.ts | 6 +- test/integration/test/server.test.ts | 20 +- test/integration/test/server/mcp.test.ts | 16 +- test/integration/test/taskLifecycle.test.ts | 6 +- 15 files changed, 1216 insertions(+), 1133 deletions(-) create mode 100644 .changeset/extract-task-manager.md create mode 100644 packages/core/src/shared/taskManager.ts diff --git a/.changeset/extract-task-manager.md b/.changeset/extract-task-manager.md new file mode 100644 index 000000000..c70a7a6e1 --- /dev/null +++ b/.changeset/extract-task-manager.md @@ -0,0 +1,14 @@ +--- +"@modelcontextprotocol/core": minor +"@modelcontextprotocol/client": minor +"@modelcontextprotocol/server": minor +--- + +refactor: extract task orchestration from Protocol into TaskManager + +**Breaking changes:** +- `extra.taskId` → `extra.task?.taskId` +- `extra.taskStore` → `extra.task?.taskStore` +- `extra.taskRequestedTtl` → `extra.task?.requestedTtl` +- `ProtocolOptions` no longer accepts `taskStore`/`taskMessageQueue` — pass via `TaskManagerOptions` in `ClientOptions`/`ServerOptions` +- Abstract methods `assertTaskCapability`/`assertTaskHandlerCapability` removed from Protocol diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index dc86b17bd..5568668d0 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -47,8 +47,7 @@ const getServer = () => { }, { capabilities: { logging: {}, tasks: { requests: { tools: { call: {} } } } }, - taskStore, // Enable task support - taskMessageQueue: new InMemoryTaskMessageQueue() + tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } } ); diff --git a/packages/client/src/experimental/tasks/client.ts b/packages/client/src/experimental/tasks/client.ts index 8fb16b2c2..7f5a4cc7b 100644 --- a/packages/client/src/experimental/tasks/client.ts +++ b/packages/client/src/experimental/tasks/client.ts @@ -25,11 +25,6 @@ import type { Client } from '../../client/client.js'; * @internal */ interface ClientInternal { - requestStream( - request: Request, - resultSchema: T, - options?: RequestOptions - ): AsyncGenerator>, void, void>; isToolTask(toolName: string): boolean; getToolOutputValidator(toolName: string): ((data: unknown) => { valid: boolean; errorMessage?: string }) | undefined; } @@ -90,7 +85,6 @@ export class ExperimentalClientTasks { ): AsyncGenerator>, void, void> { // Access Client's internal methods const clientInternal = this._client as unknown as ClientInternal; - // Add task creation parameters if server supports it and not explicitly provided const optionsWithTask = { ...options, @@ -99,7 +93,7 @@ export class ExperimentalClientTasks { task: options?.task ?? (clientInternal.isToolTask(params.name) ? {} : undefined) }; - const stream = clientInternal.requestStream({ method: 'tools/call', params }, CallToolResultSchema, optionsWithTask); + const stream = this._client.tasks.requestStream({ method: 'tools/call', params }, CallToolResultSchema, optionsWithTask); // Get the validator for this tool (if it has an output schema) const validator = clientInternal.getToolOutputValidator(params.name); @@ -170,9 +164,7 @@ export class ExperimentalClientTasks { * @experimental */ async getTask(taskId: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - type ClientWithGetTask = { getTask(params: { taskId: string }, options?: RequestOptions): Promise }; - return (this._client as unknown as ClientWithGetTask).getTask({ taskId }, options); + return this._client.tasks.getTask({ taskId }, options); } /** @@ -186,16 +178,7 @@ export class ExperimentalClientTasks { * @experimental */ async getTaskResult(taskId: string, resultSchema?: T, options?: RequestOptions): Promise> { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - getTaskResult: ( - params: { taskId: string }, - resultSchema?: U, - options?: RequestOptions - ) => Promise>; - } - ).getTaskResult({ taskId }, resultSchema, options); + return this._client.tasks.getTaskResult({ taskId }, resultSchema!, options); } /** @@ -208,12 +191,7 @@ export class ExperimentalClientTasks { * @experimental */ async listTasks(cursor?: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - listTasks: (params?: { cursor?: string }, options?: RequestOptions) => Promise; - } - ).listTasks(cursor ? { cursor } : undefined, options); + return this._client.tasks.listTasks(cursor ? { cursor } : undefined, options); } /** @@ -225,12 +203,7 @@ export class ExperimentalClientTasks { * @experimental */ async cancelTask(taskId: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - cancelTask: (params: { taskId: string }, options?: RequestOptions) => Promise; - } - ).cancelTask({ taskId }, options); + return this._client.tasks.cancelTask({ taskId }, options); } /** @@ -252,14 +225,6 @@ export class ExperimentalClientTasks { resultSchema: T, options?: RequestOptions ): AsyncGenerator>, void, void> { - // Delegate to the client's underlying Protocol method - type ClientWithRequestStream = { - requestStream( - request: Request, - resultSchema: U, - options?: RequestOptions - ): AsyncGenerator>, void, void>; - }; - return (this._client as unknown as ClientWithRequestStream).requestStream(request, resultSchema, options); + return this._client.tasks.requestStream(request, resultSchema, options); } } diff --git a/packages/core/src/experimental/tasks/interfaces.ts b/packages/core/src/experimental/tasks/interfaces.ts index 57a24639a..913030179 100644 --- a/packages/core/src/experimental/tasks/interfaces.ts +++ b/packages/core/src/experimental/tasks/interfaces.ts @@ -3,7 +3,8 @@ * WARNING: These APIs are experimental and may change without notice. */ -import type { RequestTaskStore, ServerContext } from '../../shared/protocol.js'; +import type { ServerContext } from '../../shared/protocol.js'; +import type { RequestTaskStore } from '../../shared/taskManager.js'; import type { JSONRPCErrorResponse, JSONRPCNotification, diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 63bd0034c..3aef00b2e 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -6,6 +6,7 @@ export * from './shared/metadataUtils.js'; export * from './shared/protocol.js'; export * from './shared/responseMessage.js'; export * from './shared/stdio.js'; +export * from './shared/taskManager.js'; export * from './shared/toolNameValidation.js'; export * from './shared/transport.js'; export * from './shared/uriTemplate.js'; diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index dfa98a171..695a5ae9e 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1,6 +1,4 @@ import { SdkError, SdkErrorCode } from '../errors/sdkErrors.js'; -import type { CreateTaskOptions, QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; -import { isTerminal } from '../experimental/tasks/interfaces.js'; import type { AuthInfo, CancelledNotification, @@ -11,9 +9,6 @@ import type { ElicitRequestFormParams, ElicitRequestURLParams, ElicitResult, - GetTaskPayloadRequest, - GetTaskRequest, - GetTaskResult, JSONRPCErrorResponse, JSONRPCNotification, JSONRPCRequest, @@ -35,31 +30,23 @@ import type { RequestTypeMap, Result, ServerCapabilities, - Task, - TaskCreationParams, - TaskStatusNotification + TaskCreationParams } from '../types/types.js'; import { - CancelTaskResultSchema, - CreateTaskResultSchema, getNotificationSchema, getRequestSchema, - GetTaskResultSchema, isJSONRPCErrorResponse, isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResultResponse, - isTaskAugmentedRequestParams, - ListTasksResultSchema, ProtocolError, ProtocolErrorCode, - RELATED_TASK_META_KEY, - SUPPORTED_PROTOCOL_VERSIONS, - TaskStatusNotificationSchema + SUPPORTED_PROTOCOL_VERSIONS } from '../types/types.js'; -import type { AnyObjectSchema, AnySchema, SchemaOutput } from '../util/schema.js'; +import type { AnySchema, SchemaOutput } from '../util/schema.js'; import { parseSchema } from '../util/schema.js'; -import type { ResponseMessage } from './responseMessage.js'; +import type { TaskContext, TaskManagerHost, TaskManagerOptions, TaskRequestOptions } from './taskManager.js'; +import { NullTaskManager, TaskManager } from './taskManager.js'; import type { Transport, TransportSendOptions } from './transport.js'; /** @@ -95,28 +82,10 @@ export type ProtocolOptions = { */ debouncedNotificationMethods?: string[]; /** - * Optional task storage implementation. If provided, enables task-related request handlers + * Task manager options. If a taskStore is provided, enables task-related request handlers * and provides task storage capabilities to request handlers. */ - taskStore?: TaskStore; - /** - * Optional task message queue implementation for managing server-initiated messages - * that will be delivered through the tasks/result response stream. - */ - taskMessageQueue?: TaskMessageQueue; - /** - * Default polling interval (in milliseconds) for task status checks when no pollInterval - * is provided by the server. Defaults to 5000ms if not specified. - */ - defaultTaskPollInterval?: number; - /** - * Maximum number of messages that can be queued per task for side-channel delivery. - * If undefined, the queue size is unbounded. - * When the limit is exceeded, the TaskMessageQueue implementation's enqueue() method - * will throw an error. It's the implementation's responsibility to handle overflow - * appropriately (e.g., by failing the task, dropping messages, etc.). - */ - maxTaskQueueSize?: number; + tasks?: TaskManagerOptions; }; /** @@ -187,78 +156,6 @@ export type NotificationOptions = { relatedTask?: RelatedTaskMetadata; }; -/** - * Options that can be given per request. - */ -// relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. -export type TaskRequestOptions = Omit; - -/** - * Request-scoped TaskStore interface. - */ -export interface RequestTaskStore { - /** - * Creates a new task with the given creation parameters. - * The implementation generates a unique taskId and createdAt timestamp. - * - * @param taskParams - The task creation parameters from the request - * @returns The created task object - */ - createTask(taskParams: CreateTaskOptions): Promise; - - /** - * Gets the current status of a task. - * - * @param taskId - The task identifier - * @returns The task object - * @throws If the task does not exist - */ - getTask(taskId: string): Promise; - - /** - * Stores the result of a task and sets its final status. - * - * @param taskId - The task identifier - * @param status - The final status: 'completed' for success, 'failed' for errors - * @param result - The result to store - */ - storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; - - /** - * Retrieves the stored result of a task. - * - * @param taskId - The task identifier - * @returns The stored result - */ - getTaskResult(taskId: string): Promise; - - /** - * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). - * - * @param taskId - The task identifier - * @param status - The new status - * @param statusMessage - Optional diagnostic message for failed tasks or other status information - */ - updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; - - /** - * Lists tasks, optionally starting from a pagination cursor. - * - * @param cursor - Optional cursor for pagination - * @returns An object containing the tasks array and an optional nextCursor - */ - listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; -} - -/** - * Task context provided to request handlers when task storage is configured. - */ -export type TaskContext = { - id?: string; - store: RequestTaskStore; - requestedTtl?: number | null; -}; - /** * Base context provided to all request handlers. */ @@ -400,13 +297,7 @@ export abstract class Protocol { private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); - // Maps task IDs to progress tokens to keep handlers alive after CreateTaskResult - private _taskProgressTokens: Map = new Map(); - - private _taskStore?: TaskStore; - private _taskMessageQueue?: TaskMessageQueue; - - private _requestResolvers: Map void> = new Map(); + private _taskManager: TaskManager; protected _supportedProtocolVersions: string[]; @@ -437,6 +328,24 @@ export abstract class Protocol { constructor(private _options?: ProtocolOptions) { this._supportedProtocolVersions = _options?.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; + // Create TaskManager from protocol options + this._taskManager = _options?.tasks?.taskStore ? new TaskManager(_options.tasks) : new NullTaskManager(); + + // Bind TaskManager to this Protocol as its host + const host: TaskManagerHost = { + request: (request, resultSchema, options) => this.request(request, resultSchema, options), + notification: (notification, options) => this.notification(notification, options), + reportError: error => this._onerror(error), + removeProgressHandler: token => this._progressHandlers.delete(token), + registerHandler: (method, handler) => { + this._requestHandlers.set(method, (request, ctx) => handler(request, ctx)); + }, + sendOnResponseStream: async (message, relatedRequestId) => { + await this._transport?.send(message, { relatedRequestId }); + } + }; + this._taskManager.bind(host); + this.setNotificationHandler('notifications/cancelled', notification => { this._oncancel(notification); }); @@ -450,175 +359,16 @@ export abstract class Protocol { // Automatic pong by default. _request => ({}) as Result ); + } - // Install task handlers if TaskStore is provided - this._taskStore = _options?.taskStore; - this._taskMessageQueue = _options?.taskMessageQueue; - if (this._taskStore) { - this.setRequestHandler('tasks/get', async (request, ctx) => { - const task = await this._taskStore!.getTask(request.params.taskId, ctx.sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); - } - - // Per spec: tasks/get responses SHALL NOT include related-task metadata - // as the taskId parameter is the source of truth - return { - ...task - } as Result; - }); - - this.setRequestHandler('tasks/result', async (request, ctx) => { - const handleTaskResult = async (): Promise => { - const taskId = request.params.taskId; - - // Deliver queued messages - if (this._taskMessageQueue) { - let queuedMessage: QueuedMessage | undefined; - while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, ctx.sessionId))) { - // Handle response and error messages by routing them to the appropriate resolver - if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { - const message = queuedMessage.message; - const requestId = message.id; - - // Lookup resolver in _requestResolvers map - const resolver = this._requestResolvers.get(requestId as RequestId); - - if (resolver) { - // Remove resolver from map after invocation - this._requestResolvers.delete(requestId as RequestId); - - // Invoke resolver with response or error - if (queuedMessage.type === 'response') { - resolver(message as JSONRPCResultResponse); - } else { - // Convert JSONRPCError to ProtocolError - const errorMessage = message as JSONRPCErrorResponse; - const error = new ProtocolError( - errorMessage.error.code, - errorMessage.error.message, - errorMessage.error.data - ); - resolver(error); - } - } else { - // Handle missing resolver gracefully with error logging - const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; - this._onerror(new Error(`${messageType} handler missing for request ${requestId}`)); - } - - // Continue to next message - continue; - } - - // Send the message on the response stream by passing the relatedRequestId - // This tells the transport to write the message to the tasks/result response stream - await this._transport?.send(queuedMessage.message, { relatedRequestId: ctx.mcpReq.id }); - } - } - - // Now check task status - const task = await this._taskStore!.getTask(taskId, ctx.sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${taskId}`); - } - - // Block if task is not terminal (we've already delivered all queued messages above) - if (!isTerminal(task.status)) { - // Wait for status change or new messages - await this._waitForTaskUpdate(taskId, ctx.mcpReq.signal); - - // After waking up, recursively call to deliver any new messages or result - return await handleTaskResult(); - } - - // If task is terminal, return the result - if (isTerminal(task.status)) { - const result = await this._taskStore!.getTaskResult(taskId, ctx.sessionId); - - this._clearTaskQueue(taskId); - - return { - ...result, - _meta: { - ...result._meta, - [RELATED_TASK_META_KEY]: { - taskId: taskId - } - } - } as Result; - } - - return await handleTaskResult(); - }; - - return await handleTaskResult(); - }); - - this.setRequestHandler('tasks/list', async (request, ctx) => { - try { - const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, ctx.sessionId); - return { - tasks, - nextCursor, - _meta: {} - } as Result; - } catch (error) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` - ); - } - }); - - this.setRequestHandler('tasks/cancel', async (request, ctx) => { - try { - // Get the current task to check if it's in a terminal state, in case the implementation is not atomic - const task = await this._taskStore!.getTask(request.params.taskId, ctx.sessionId); - - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); - } - - // Reject cancellation of terminal tasks - if (isTerminal(task.status)) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); - } - - await this._taskStore!.updateTaskStatus( - request.params.taskId, - 'cancelled', - 'Client cancelled task execution.', - ctx.sessionId - ); - - this._clearTaskQueue(request.params.taskId); - - const cancelledTask = await this._taskStore!.getTask(request.params.taskId, ctx.sessionId); - if (!cancelledTask) { - // Task was deleted during cancellation (e.g., cleanup happened) - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Task not found after cancellation: ${request.params.taskId}` - ); - } - - return { - _meta: {}, - ...cancelledTask - } as Result; - } catch (error) { - // Re-throw ProtocolError as-is - if (error instanceof ProtocolError) { - throw error; - } - throw new ProtocolError( - ProtocolErrorCode.InvalidRequest, - `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` - ); - } - }); - } + /** + * Access the TaskManager for task orchestration. + * Always available; returns a NullTaskManager when no task store is configured. + * The NullTaskManager supports outbound methods (requestStream, getTask, etc.) + * but does not handle inbound task requests. + */ + get tasks(): TaskManager { + return this._taskManager; } /** @@ -722,7 +472,7 @@ export abstract class Protocol { const responseHandlers = this._responseHandlers; this._responseHandlers = new Map(); this._progressHandlers.clear(); - this._taskProgressTokens.clear(); + this._taskManager.onClose(); this._pendingDebouncedNotifications.clear(); const error = new SdkError(SdkErrorCode.ConnectionClosed, 'Connection closed'); @@ -759,8 +509,12 @@ export abstract class Protocol { // Capture the current transport at request time to ensure responses go to the correct client const capturedTransport = this._transport; - // Extract taskId from request metadata if present (needed early for method not found case) - const relatedTaskId = request.params?._meta?.[RELATED_TASK_META_KEY]?.taskId; + // Delegate task context extraction to TaskManager + const taskResult = this._taskManager.processInboundRequest(request, { + sessionId: capturedTransport?.sessionId, + sendNotification: (notification, options?) => this.notification(notification, { ...options, relatedRequestId: request.id }), + sendRequest: (r, resultSchema, options?) => this.request(r, resultSchema, { ...options, relatedRequestId: request.id }) + }); if (handler === undefined) { const errorResponse: JSONRPCErrorResponse = { @@ -773,34 +527,22 @@ export abstract class Protocol { }; // Queue or send the error response based on whether this is a task-related request - if (relatedTaskId && this._taskMessageQueue) { - this._enqueueTaskMessage( - relatedTaskId, - { - type: 'error', - message: errorResponse, - timestamp: Date.now() - }, - capturedTransport?.sessionId - ).catch(error => this._onerror(new Error(`Failed to enqueue error response: ${error}`))); - } else { - capturedTransport - ?.send(errorResponse) - .catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); - } + taskResult + .routeResponse(errorResponse) + .then(routed => { + if (!routed) { + capturedTransport + ?.send(errorResponse) + .catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); + } + }) + .catch(error => this._onerror(new Error(`Failed to enqueue error response: ${error}`))); return; } const abortController = new AbortController(); this._requestHandlerAbortControllers.set(request.id, abortController); - const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; - const taskStore = this._taskStore ? this.requestTaskStore(request, capturedTransport?.sessionId) : undefined; - - const task: TaskContext | undefined = taskStore - ? { id: relatedTaskId, store: taskStore, requestedTtl: taskCreationParams?.ttl } - : undefined; - const baseCtx: BaseContext = { sessionId: capturedTransport?.sessionId, mcpReq: { @@ -808,27 +550,11 @@ export abstract class Protocol { method: request.method, _meta: request.params?._meta, signal: abortController.signal, - send: async (r, resultSchema, options?) => { - const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; - if (relatedTaskId && !requestOptions.relatedTask) { - requestOptions.relatedTask = { taskId: relatedTaskId }; - } - const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; - if (effectiveTaskId && taskStore) { - await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); - } - return await this.request(r, resultSchema, requestOptions); - }, - notify: async notification => { - const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; - if (relatedTaskId) { - notificationOptions.relatedTask = { taskId: relatedTaskId }; - } - await this.notification(notification, notificationOptions); - } + send: taskResult.sendRequest, + notify: taskResult.sendNotification }, http: extra?.authInfo ? { authInfo: extra.authInfo } : undefined, - task + task: taskResult.taskContext }; const ctx = this.buildContext(baseCtx, extra); @@ -836,8 +562,7 @@ export abstract class Protocol { Promise.resolve() .then(() => { // If this request asked for task creation, check capability first - if (taskCreationParams) { - // Check if the request method supports task creation + if (taskResult.hasTaskCreationParams) { this.assertTaskHandlerCapability(request.method); } }) @@ -856,17 +581,10 @@ export abstract class Protocol { }; // Queue or send the response based on whether this is a task-related request - await (relatedTaskId && this._taskMessageQueue - ? this._enqueueTaskMessage( - relatedTaskId, - { - type: 'response', - message: response, - timestamp: Date.now() - }, - capturedTransport?.sessionId - ) - : capturedTransport?.send(response)); + const routed = await taskResult.routeResponse(response); + if (!routed) { + await capturedTransport?.send(response); + } }, async error => { if (abortController.signal.aborted) { @@ -885,17 +603,10 @@ export abstract class Protocol { }; // Queue or send the error response based on whether this is a task-related request - await (relatedTaskId && this._taskMessageQueue - ? this._enqueueTaskMessage( - relatedTaskId, - { - type: 'error', - message: errorResponse, - timestamp: Date.now() - }, - capturedTransport?.sessionId - ) - : capturedTransport?.send(errorResponse)); + const routed = await taskResult.routeResponse(errorResponse); + if (!routed) { + await capturedTransport?.send(errorResponse); + } } ) .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))) @@ -936,16 +647,9 @@ export abstract class Protocol { private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void { const messageId = Number(response.id); - // Check if this is a response to a queued request - const resolver = this._requestResolvers.get(messageId); - if (resolver) { - this._requestResolvers.delete(messageId); - if (isJSONRPCResultResponse(response)) { - resolver(response); - } else { - const error = new ProtocolError(response.error.code, response.error.message, response.error.data); - resolver(error); - } + // Delegate to TaskManager for task-related response handling + const { consumed, preserveProgress } = this._taskManager.processInboundResponse(response, messageId); + if (consumed) { return; } @@ -959,19 +663,7 @@ export abstract class Protocol { this._cleanupTimeout(messageId); // Keep progress handler alive for CreateTaskResult responses - let isTaskResponse = false; - if (isJSONRPCResultResponse(response) && response.result && typeof response.result === 'object') { - const result = response.result as Record; - if (result.task && typeof result.task === 'object') { - const task = result.task as Record; - if (typeof task.taskId === 'string') { - isTaskResponse = true; - this._taskProgressTokens.set(task.taskId, messageId); - } - } - } - - if (!isTaskResponse) { + if (!preserveProgress) { this._progressHandlers.delete(messageId); } @@ -1016,149 +708,26 @@ export abstract class Protocol { protected abstract assertRequestHandlerCapability(method: string): void; /** - * A method to check if task creation is supported for the given request method. + * A method to check if task creation is supported for a given request method. * * This should be implemented by subclasses. */ protected abstract assertTaskCapability(method: string): void; /** - * A method to check if task handler is supported by the local side, for the given method to be handled. + * A method to check if task handling is supported for a given request method. * * This should be implemented by subclasses. */ protected abstract assertTaskHandlerCapability(method: string): void; - /** - * Sends a request and returns an AsyncGenerator that yields response messages. - * The generator is guaranteed to end with either a 'result' or 'error' message. - * - * @example - * ```typescript - * const stream = protocol.requestStream(request, resultSchema, options); - * for await (const message of stream) { - * switch (message.type) { - * case 'taskCreated': - * console.log('Task created:', message.task.taskId); - * break; - * case 'taskStatus': - * console.log('Task status:', message.task.status); - * break; - * case 'result': - * console.log('Final result:', message.result); - * break; - * case 'error': - * console.error('Error:', message.error); - * break; - * } - * } - * ``` - * - * @experimental Use `client.experimental.tasks.requestStream()` to access this method. - */ - protected async *requestStream( - request: Request, - resultSchema: T, - options?: RequestOptions - ): AsyncGenerator>, void, void> { - const { task } = options ?? {}; - - // For non-task requests, just yield the result - if (!task) { - try { - const result = await this.request(request, resultSchema, options); - yield { type: 'result', result }; - } catch (error) { - yield { - type: 'error', - error: error instanceof Error ? error : new Error(String(error)) - }; - } - return; - } - - // For task-augmented requests, we need to poll for status - // First, make the request to create the task - let taskId: string | undefined; - try { - // Send the request and get the CreateTaskResult - const createResult = await this.request(request, CreateTaskResultSchema, options); - - // Extract taskId from the result - if (createResult.task) { - taskId = createResult.task.taskId; - yield { type: 'taskCreated', task: createResult.task }; - } else { - throw new ProtocolError(ProtocolErrorCode.InternalError, 'Task creation did not return a task'); - } - - // Poll for task completion - while (true) { - // Get current task status - const task = await this.getTask({ taskId }, options); - yield { type: 'taskStatus', task }; - - // Check if task is terminal - if (isTerminal(task.status)) { - switch (task.status) { - case 'completed': { - // Get the final result - const result = await this.getTaskResult({ taskId }, resultSchema, options); - yield { type: 'result', result }; - - break; - } - case 'failed': { - yield { - type: 'error', - error: new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} failed`) - }; - - break; - } - case 'cancelled': { - yield { - type: 'error', - error: new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} was cancelled`) - }; - - break; - } - // No default - } - return; - } - - // When input_required, call tasks/result to deliver queued messages - // (elicitation, sampling) via SSE and block until terminal - if (task.status === 'input_required') { - const result = await this.getTaskResult({ taskId }, resultSchema, options); - yield { type: 'result', result }; - return; - } - - // Wait before polling again - const pollInterval = task.pollInterval ?? this._options?.defaultTaskPollInterval ?? 1000; - await new Promise(resolve => setTimeout(resolve, pollInterval)); - - // Check if cancelled - options?.signal?.throwIfAborted(); - } - } catch (error) { - yield { - type: 'error', - error: error instanceof Error ? error : new Error(String(error)) - }; - } - } - /** * Sends a request and waits for a response. * * Do not use this method to emit notifications! Use notification() instead. */ request(request: Request, resultSchema: T, options?: RequestOptions): Promise> { - const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; + const { relatedRequestId, resumptionToken, onresumptiontoken, task } = options ?? {}; // Send the request return new Promise>((resolve, reject) => { @@ -1205,25 +774,6 @@ export abstract class Protocol { }; } - // Augment with task creation parameters if provided - if (task) { - jsonrpcRequest.params = { - ...jsonrpcRequest.params, - task: task - }; - } - - // Augment with related task metadata if relatedTask is provided - if (relatedTask) { - jsonrpcRequest.params = { - ...jsonrpcRequest.params, - _meta: { - ...jsonrpcRequest.params?._meta, - [RELATED_TASK_META_KEY]: relatedTask - } - }; - } - const cancel = (reason: unknown) => { this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); @@ -1278,33 +828,22 @@ export abstract class Protocol { this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); - // Queue request if related to a task - const relatedTaskId = relatedTask?.taskId; - if (relatedTaskId) { - // Store the response resolver for this request so responses can be routed back - const responseResolver = (response: JSONRPCResultResponse | Error) => { - const handler = this._responseHandlers.get(messageId); - if (handler) { - handler(response); - } else { - // Log error when resolver is missing, but don't fail - this._onerror(new Error(`Response handler missing for side-channeled request ${messageId}`)); - } - }; - this._requestResolvers.set(messageId, responseResolver); + // Delegate task augmentation and routing to TaskManager + const responseHandler = (response: JSONRPCResultResponse | Error) => { + const handler = this._responseHandlers.get(messageId); + if (handler) { + handler(response); + } else { + this._onerror(new Error(`Response handler missing for side-channeled request ${messageId}`)); + } + }; - this._enqueueTaskMessage(relatedTaskId, { - type: 'request', - message: jsonrpcRequest, - timestamp: Date.now() - }).catch(error => { - this._cleanupTimeout(messageId); - reject(error); - }); + const { queued } = this._taskManager.processOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, error => { + this._cleanupTimeout(messageId); + reject(error); + }); - // Don't send through transport - queued messages are delivered via tasks/result only - // This prevents duplicate delivery for bidirectional transports - } else { + if (!queued) { // No related task - send through transport normally this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { this._cleanupTimeout(messageId); @@ -1314,46 +853,6 @@ export abstract class Protocol { }); } - /** - * Gets the current status of a task. - * - * @experimental Use `client.experimental.tasks.getTask()` to access this method. - */ - protected async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { - return this.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); - } - - /** - * Retrieves the result of a completed task. - * - * @experimental Use `client.experimental.tasks.getTaskResult()` to access this method. - */ - protected async getTaskResult( - params: GetTaskPayloadRequest['params'], - resultSchema: T, - options?: RequestOptions - ): Promise> { - return this.request({ method: 'tasks/result', params }, resultSchema, options); - } - - /** - * Lists tasks, optionally starting from a pagination cursor. - * - * @experimental Use `client.experimental.tasks.listTasks()` to access this method. - */ - protected async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { - return this.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); - } - - /** - * Cancels a specific task. - * - * @experimental Use `client.experimental.tasks.cancelTask()` to access this method. - */ - protected async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { - return this.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); - } - /** * Emits a notification, which is a one-way message that does not expect a response. */ @@ -1364,30 +863,10 @@ export abstract class Protocol { this.assertNotificationCapability(notification.method as NotificationMethod); - // Queue notification if related to a task - const relatedTaskId = options?.relatedTask?.taskId; - if (relatedTaskId) { - // Build the JSONRPC notification with metadata - const jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: '2.0', - params: { - ...notification.params, - _meta: { - ...notification.params?._meta, - [RELATED_TASK_META_KEY]: options.relatedTask - } - } - }; - - await this._enqueueTaskMessage(relatedTaskId, { - type: 'notification', - message: jsonrpcNotification, - timestamp: Date.now() - }); - + // Delegate task-related notification routing and JSONRPC building to TaskManager + const { queued, jsonrpcNotification } = await this._taskManager.processOutboundNotification(notification, options); + if (queued) { // Don't send through transport - queued messages are delivered via tasks/result only - // This prevents duplicate delivery for bidirectional transports return; } @@ -1417,54 +896,16 @@ export abstract class Protocol { return; } - let jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: '2.0' - }; - - // Augment with related task metadata if relatedTask is provided - if (options?.relatedTask) { - jsonrpcNotification = { - ...jsonrpcNotification, - params: { - ...jsonrpcNotification.params, - _meta: { - ...jsonrpcNotification.params?._meta, - [RELATED_TASK_META_KEY]: options.relatedTask - } - } - }; - } - // Send the notification, but don't await it here to avoid blocking. // Handle potential errors with a .catch(). - this._transport?.send(jsonrpcNotification, options).catch(error => this._onerror(error)); + this._transport?.send(jsonrpcNotification!, options).catch(error => this._onerror(error)); }); // Return immediately. return; } - let jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: '2.0' - }; - - // Augment with related task metadata if relatedTask is provided - if (options?.relatedTask) { - jsonrpcNotification = { - ...jsonrpcNotification, - params: { - ...jsonrpcNotification.params, - _meta: { - ...jsonrpcNotification.params?._meta, - [RELATED_TASK_META_KEY]: options.relatedTask - } - } - }; - } - - await this._transport.send(jsonrpcNotification, options); + await this._transport.send(jsonrpcNotification!, options); } /** @@ -1524,194 +965,6 @@ export abstract class Protocol { removeNotificationHandler(method: NotificationMethod): void { this._notificationHandlers.delete(method); } - - /** - * Cleans up the progress handler associated with a task. - * This should be called when a task reaches a terminal status. - */ - private _cleanupTaskProgressHandler(taskId: string): void { - const progressToken = this._taskProgressTokens.get(taskId); - if (progressToken !== undefined) { - this._progressHandlers.delete(progressToken); - this._taskProgressTokens.delete(taskId); - } - } - - /** - * Enqueues a task-related message for side-channel delivery via tasks/result. - * @param taskId The task ID to associate the message with - * @param message The message to enqueue - * @param sessionId Optional session ID for binding the operation to a specific session - * @throws Error if taskStore is not configured or if enqueue fails (e.g., queue overflow) - * - * Note: If enqueue fails, it's the TaskMessageQueue implementation's responsibility to handle - * the error appropriately (e.g., by failing the task, logging, etc.). The Protocol layer - * simply propagates the error. - */ - private async _enqueueTaskMessage(taskId: string, message: QueuedMessage, sessionId?: string): Promise { - // Task message queues are only used when taskStore is configured - if (!this._taskStore || !this._taskMessageQueue) { - throw new Error('Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); - } - - const maxQueueSize = this._options?.maxTaskQueueSize; - await this._taskMessageQueue.enqueue(taskId, message, sessionId, maxQueueSize); - } - - /** - * Clears the message queue for a task and rejects any pending request resolvers. - * @param taskId The task ID whose queue should be cleared - * @param sessionId Optional session ID for binding the operation to a specific session - */ - private async _clearTaskQueue(taskId: string, sessionId?: string): Promise { - if (this._taskMessageQueue) { - // Reject any pending request resolvers - const messages = await this._taskMessageQueue.dequeueAll(taskId, sessionId); - for (const message of messages) { - if (message.type === 'request' && isJSONRPCRequest(message.message)) { - // Extract request ID from the message - const requestId = message.message.id as RequestId; - const resolver = this._requestResolvers.get(requestId); - if (resolver) { - resolver(new ProtocolError(ProtocolErrorCode.InternalError, 'Task cancelled or completed')); - this._requestResolvers.delete(requestId); - } else { - // Log error when resolver is missing during cleanup for better observability - this._onerror(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); - } - } - } - } - } - - /** - * Waits for a task update (new messages or status change) with abort signal support. - * Uses polling to check for updates at the task's configured poll interval. - * @param taskId The task ID to wait for - * @param signal Abort signal to cancel the wait - * @returns Promise that resolves when an update occurs or rejects if aborted - */ - private async _waitForTaskUpdate(taskId: string, signal: AbortSignal): Promise { - // Get the task's poll interval, falling back to default - let interval = this._options?.defaultTaskPollInterval ?? 1000; - try { - const task = await this._taskStore?.getTask(taskId); - if (task?.pollInterval) { - interval = task.pollInterval; - } - } catch { - // Use default interval if task lookup fails - } - - return new Promise((resolve, reject) => { - if (signal.aborted) { - reject(new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Request cancelled')); - return; - } - - // Wait for the poll interval, then resolve so caller can check for updates - const timeoutId = setTimeout(resolve, interval); - - // Clean up timeout and reject if aborted - signal.addEventListener( - 'abort', - () => { - clearTimeout(timeoutId); - reject(new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Request cancelled')); - }, - { once: true } - ); - }); - } - - private requestTaskStore(request?: JSONRPCRequest, sessionId?: string): RequestTaskStore { - const taskStore = this._taskStore; - if (!taskStore) { - throw new Error('No task store configured'); - } - - return { - createTask: async taskParams => { - if (!request) { - throw new Error('No request provided'); - } - - return await taskStore.createTask( - taskParams, - request.id, - { - method: request.method, - params: request.params - }, - sessionId - ); - }, - getTask: async taskId => { - const task = await taskStore.getTask(taskId, sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); - } - - return task; - }, - storeTaskResult: async (taskId, status, result) => { - await taskStore.storeTaskResult(taskId, status, result, sessionId); - - // Get updated task state and send notification - const task = await taskStore.getTask(taskId, sessionId); - if (task) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: task - }); - await this.notification(notification as Notification); - - if (isTerminal(task.status)) { - this._cleanupTaskProgressHandler(taskId); - // Don't clear queue here - it will be cleared after delivery via tasks/result - } - } - }, - getTaskResult: taskId => { - return taskStore.getTaskResult(taskId, sessionId); - }, - updateTaskStatus: async (taskId, status, statusMessage) => { - // Check if task exists - const task = await taskStore.getTask(taskId, sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task "${taskId}" not found - it may have been cleaned up`); - } - - // Don't allow transitions from terminal states - if (isTerminal(task.status)) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` - ); - } - - await taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); - - // Get updated task state and send notification - const updatedTask = await taskStore.getTask(taskId, sessionId); - if (updatedTask) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: updatedTask - }); - await this.notification(notification as Notification); - - if (isTerminal(updatedTask.status)) { - this._cleanupTaskProgressHandler(taskId); - // Don't clear queue here - it will be cleared after delivery via tasks/result - } - } - }, - listTasks: cursor => { - return taskStore.listTasks(cursor, sessionId); - } - }; - } } function isPlainObject(value: unknown): value is Record { diff --git a/packages/core/src/shared/taskManager.ts b/packages/core/src/shared/taskManager.ts new file mode 100644 index 000000000..1eac79df5 --- /dev/null +++ b/packages/core/src/shared/taskManager.ts @@ -0,0 +1,855 @@ +import type { CreateTaskOptions, QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; +import { isTerminal } from '../experimental/tasks/interfaces.js'; +import type { + GetTaskPayloadRequest, + GetTaskRequest, + GetTaskResult, + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + Notification, + Request, + RequestId, + Result, + Task, + TaskCreationParams, + TaskStatusNotification +} from '../types/types.js'; +import { + CancelTaskResultSchema, + CreateTaskResultSchema, + GetTaskResultSchema, + isJSONRPCErrorResponse, + isJSONRPCRequest, + isJSONRPCResultResponse, + isTaskAugmentedRequestParams, + ListTasksResultSchema, + ProtocolError, + ProtocolErrorCode, + RELATED_TASK_META_KEY, + TaskStatusNotificationSchema +} from '../types/types.js'; +import type { AnyObjectSchema, AnySchema, SchemaOutput } from '../util/schema.js'; +import type { BaseContext, NotificationOptions, RequestOptions } from './protocol.js'; +import type { ResponseMessage } from './responseMessage.js'; + +/** + * Options that can be given per request. + */ +// relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. +export type TaskRequestOptions = Omit; + +/** + * Request-scoped TaskStore interface. + */ +export interface RequestTaskStore { + /** + * Creates a new task with the given creation parameters. + * The implementation generates a unique taskId and createdAt timestamp. + * + * @param taskParams - The task creation parameters from the request + * @returns The created task object + */ + createTask(taskParams: CreateTaskOptions): Promise; + + /** + * Gets the current status of a task. + * + * @param taskId - The task identifier + * @returns The task object + * @throws If the task does not exist + */ + getTask(taskId: string): Promise; + + /** + * Stores the result of a task and sets its final status. + * + * @param taskId - The task identifier + * @param status - The final status: 'completed' for success, 'failed' for errors + * @param result - The result to store + */ + storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; + + /** + * Retrieves the stored result of a task. + * + * @param taskId - The task identifier + * @returns The stored result + */ + getTaskResult(taskId: string): Promise; + + /** + * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). + * + * @param taskId - The task identifier + * @param status - The new status + * @param statusMessage - Optional diagnostic message for failed tasks or other status information + */ + updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; + + /** + * Lists tasks, optionally starting from a pagination cursor. + * + * @param cursor - Optional cursor for pagination + * @returns An object containing the tasks array and an optional nextCursor + */ + listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; +} + +/** + * Task context provided to request handlers when task storage is configured. + */ +export type TaskContext = { + id?: string; + store: RequestTaskStore; + requestedTtl?: number | null; +}; + +export type TaskManagerOptions = { + /** + * Task storage implementation. Required for handling incoming task requests (server-side). + * Not required for sending task requests (client-side outbound API). + */ + taskStore?: TaskStore; + /** + * Optional task message queue implementation for managing server-initiated messages + * that will be delivered through the tasks/result response stream. + */ + taskMessageQueue?: TaskMessageQueue; + /** + * Default polling interval (in milliseconds) for task status checks when no pollInterval + * is provided by the server. Defaults to 1000ms if not specified. + */ + defaultTaskPollInterval?: number; + /** + * Maximum number of messages that can be queued per task for side-channel delivery. + * If undefined, the queue size is unbounded. + */ + maxTaskQueueSize?: number; +}; + +export interface InboundRequestContext { + sessionId?: string; + sendNotification: (notification: Notification, options?: NotificationOptions) => Promise; + sendRequest: (request: Request, resultSchema: U, options?: RequestOptions) => Promise>; +} + +export interface InboundRequestResult { + taskContext?: TaskContext; + sendNotification: (notification: Notification) => Promise; + sendRequest: (request: Request, resultSchema: U, options?: TaskRequestOptions) => Promise>; + routeResponse: (message: JSONRPCResponse | JSONRPCErrorResponse) => Promise; + hasTaskCreationParams: boolean; +} + +/** @internal */ +export interface TaskManagerHost { + request(request: Request, resultSchema: T, options?: RequestOptions): Promise>; + notification(notification: Notification, options?: NotificationOptions): Promise; + reportError(error: Error): void; + removeProgressHandler(token: number): void; + registerHandler(method: string, handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise): void; + sendOnResponseStream(message: JSONRPCNotification | JSONRPCRequest, relatedRequestId: RequestId): Promise; +} + +/** + * Manages task orchestration: state, message queuing, and polling. + */ +export class TaskManager { + private _taskStore?: TaskStore; + private _taskMessageQueue?: TaskMessageQueue; + private _taskProgressTokens: Map = new Map(); + private _requestResolvers: Map void> = new Map(); + private _options: TaskManagerOptions; + private _host?: TaskManagerHost; + + constructor(options: TaskManagerOptions) { + this._options = options; + this._taskStore = options.taskStore; + this._taskMessageQueue = options.taskMessageQueue; + } + + bind(host: TaskManagerHost): void { + this._host = host; + + if (this._taskStore) { + host.registerHandler('tasks/get', async (request, ctx) => { + const params = request.params as { taskId: string }; + const task = await this.handleGetTask(params.taskId, ctx.sessionId); + // Per spec: tasks/get responses SHALL NOT include related-task metadata + // as the taskId parameter is the source of truth + return { + ...task + } as Result; + }); + + host.registerHandler('tasks/result', async (request, ctx) => { + const params = request.params as { taskId: string }; + return await this.handleGetTaskPayload(params.taskId, ctx.sessionId, ctx.mcpReq.signal, async message => { + // Send the message on the response stream by passing the relatedRequestId + // This tells the transport to write the message to the tasks/result response stream + await host.sendOnResponseStream(message, ctx.mcpReq.id); + }); + }); + + host.registerHandler('tasks/list', async (request, ctx) => { + const params = request.params as { cursor?: string } | undefined; + return (await this.handleListTasks(params?.cursor, ctx.sessionId)) as Result; + }); + + host.registerHandler('tasks/cancel', async (request, ctx) => { + const params = request.params as { taskId: string }; + return await this.handleCancelTask(params.taskId, ctx.sessionId); + }); + } + } + + private get _requireHost(): TaskManagerHost { + if (!this._host) { + throw new ProtocolError(ProtocolErrorCode.InternalError, 'TaskManager is not bound to a Protocol host — call bind() first'); + } + return this._host; + } + + get taskStore(): TaskStore | undefined { + return this._taskStore; + } + + private get _requireTaskStore(): TaskStore { + if (!this._taskStore) { + throw new ProtocolError(ProtocolErrorCode.InternalError, 'TaskStore is not configured'); + } + return this._taskStore; + } + + get taskMessageQueue(): TaskMessageQueue | undefined { + return this._taskMessageQueue; + } + + // -- Public API (client-facing) -- + async *requestStream( + request: Request, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + const host = this._requireHost; + const { task } = options ?? {}; + + if (!task) { + try { + const result = await host.request(request, resultSchema, options); + yield { type: 'result', result }; + } catch (error) { + yield { + type: 'error', + error: error instanceof Error ? error : new Error(String(error)) + }; + } + return; + } + + let taskId: string | undefined; + try { + const createResult = await host.request(request, CreateTaskResultSchema, options); + + if (createResult.task) { + taskId = createResult.task.taskId; + yield { type: 'taskCreated', task: createResult.task }; + } else { + throw new ProtocolError(ProtocolErrorCode.InternalError, 'Task creation did not return a task'); + } + + while (true) { + const task = await this.getTask({ taskId }, options); + yield { type: 'taskStatus', task }; + + if (isTerminal(task.status)) { + switch (task.status) { + case 'completed': { + const result = await this.getTaskResult({ taskId }, resultSchema, options); + yield { type: 'result', result }; + break; + } + case 'failed': { + yield { type: 'error', error: new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} failed`) }; + break; + } + case 'cancelled': { + yield { + type: 'error', + error: new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} was cancelled`) + }; + break; + } + } + return; + } + + if (task.status === 'input_required') { + const result = await this.getTaskResult({ taskId }, resultSchema, options); + yield { type: 'result', result }; + return; + } + + const pollInterval = task.pollInterval ?? this._options.defaultTaskPollInterval ?? 1000; + await new Promise(resolve => setTimeout(resolve, pollInterval)); + options?.signal?.throwIfAborted(); + } + } catch (error) { + yield { + type: 'error', + error: error instanceof Error ? error : new Error(String(error)) + }; + } + } + + async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { + return this._requireHost.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); + } + + async getTaskResult( + params: GetTaskPayloadRequest['params'], + resultSchema: T, + options?: RequestOptions + ): Promise> { + return this._requireHost.request({ method: 'tasks/result', params }, resultSchema, options); + } + + async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { + return this._requireHost.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); + } + + async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { + return this._requireHost.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); + } + + // -- Handler bodies (delegated from Protocol's registered handlers) -- + + private async handleGetTask(taskId: string, sessionId?: string): Promise { + const task = await this._requireTaskStore.getTask(taskId, sessionId); + if (!task) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + return task; + } + + private async handleGetTaskPayload( + taskId: string, + sessionId: string | undefined, + signal: AbortSignal, + sendOnResponseStream: (message: JSONRPCNotification | JSONRPCRequest) => Promise + ): Promise { + const handleTaskResult = async (): Promise => { + if (this._taskMessageQueue) { + let queuedMessage: QueuedMessage | undefined; + while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, sessionId))) { + if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { + const message = queuedMessage.message; + const requestId = message.id; + const resolver = this._requestResolvers.get(requestId as RequestId); + + if (resolver) { + this._requestResolvers.delete(requestId as RequestId); + if (queuedMessage.type === 'response') { + resolver(message as JSONRPCResultResponse); + } else { + const errorMessage = message as JSONRPCErrorResponse; + resolver(new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data)); + } + } else { + const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; + this._host?.reportError(new Error(`${messageType} handler missing for request ${requestId}`)); + } + continue; + } + + await sendOnResponseStream(queuedMessage.message as JSONRPCNotification | JSONRPCRequest); + } + } + + const task = await this._requireTaskStore.getTask(taskId, sessionId); + if (!task) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${taskId}`); + } + + if (!isTerminal(task.status)) { + await this._waitForTaskUpdate(taskId, signal); + return await handleTaskResult(); + } + + const result = await this._requireTaskStore.getTaskResult(taskId, sessionId); + await this._clearTaskQueue(taskId); + + return { + ...result, + _meta: { + ...result._meta, + [RELATED_TASK_META_KEY]: { taskId } + } + }; + }; + + return await handleTaskResult(); + } + + private async handleListTasks( + cursor: string | undefined, + sessionId?: string + ): Promise<{ tasks: Task[]; nextCursor?: string; _meta: Record }> { + try { + const { tasks, nextCursor } = await this._requireTaskStore.listTasks(cursor, sessionId); + return { tasks, nextCursor, _meta: {} }; + } catch (error) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + + private async handleCancelTask(taskId: string, sessionId?: string): Promise { + try { + const task = await this._requireTaskStore.getTask(taskId, sessionId); + if (!task) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${taskId}`); + } + + if (isTerminal(task.status)) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); + } + + await this._requireTaskStore.updateTaskStatus(taskId, 'cancelled', 'Client cancelled task execution.', sessionId); + await this._clearTaskQueue(taskId); + + const cancelledTask = await this._requireTaskStore.getTask(taskId, sessionId); + if (!cancelledTask) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found after cancellation: ${taskId}`); + } + + return { _meta: {}, ...cancelledTask }; + } catch (error) { + if (error instanceof ProtocolError) throw error; + throw new ProtocolError( + ProtocolErrorCode.InvalidRequest, + `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + + // -- Internal delegation methods -- + + private prepareOutboundRequest( + jsonrpcRequest: JSONRPCRequest, + options: RequestOptions | undefined, + messageId: number, + responseHandler: (response: JSONRPCResultResponse | Error) => void, + onError: (error: unknown) => void + ): boolean { + const { task, relatedTask } = options ?? {}; + + if (task) { + jsonrpcRequest.params = { + ...jsonrpcRequest.params, + task: task + }; + } + + if (relatedTask) { + jsonrpcRequest.params = { + ...jsonrpcRequest.params, + _meta: { + ...jsonrpcRequest.params?._meta, + [RELATED_TASK_META_KEY]: relatedTask + } + }; + } + + const relatedTaskId = relatedTask?.taskId; + if (relatedTaskId) { + this._requestResolvers.set(messageId, responseHandler); + + this._enqueueTaskMessage(relatedTaskId, { + type: 'request', + message: jsonrpcRequest, + timestamp: Date.now() + }).catch(error => { + onError(error); + }); + + return true; + } + + return false; + } + + private extractInboundTaskContext( + request: JSONRPCRequest, + sessionId?: string + ): { + relatedTaskId?: string; + taskCreationParams?: TaskCreationParams; + taskContext?: TaskContext; + } { + const relatedTaskId = (request.params?._meta as Record | undefined)?.[RELATED_TASK_META_KEY]?.taskId; + const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; + + // Provide task context whenever a task store is configured, + // not just for task-related requests — tools need ctx.task.store + let taskContext: TaskContext | undefined; + if (this._taskStore) { + const store = this.createRequestTaskStore(request, sessionId); + taskContext = { + id: relatedTaskId, + store, + requestedTtl: taskCreationParams?.ttl + }; + } + + if (!relatedTaskId && !taskCreationParams && !taskContext) { + return {}; + } + + return { + relatedTaskId, + taskCreationParams, + taskContext + }; + } + + private wrapSendNotification( + relatedTaskId: string, + originalSendNotification: (notification: Notification, options?: NotificationOptions) => Promise + ): (notification: Notification) => Promise { + return async (notification: Notification) => { + const notificationOptions: NotificationOptions = { relatedTask: { taskId: relatedTaskId } }; + await originalSendNotification(notification, notificationOptions); + }; + } + + private wrapSendRequest( + relatedTaskId: string, + taskStore: RequestTaskStore | undefined, + originalSendRequest: (request: Request, resultSchema: V, options?: RequestOptions) => Promise> + ): (request: Request, resultSchema: V, options?: TaskRequestOptions) => Promise> { + return async (request: Request, resultSchema: V, options?: TaskRequestOptions) => { + const requestOptions: RequestOptions = { ...options }; + if (!requestOptions.relatedTask) { + requestOptions.relatedTask = { taskId: relatedTaskId }; + } + + const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; + if (effectiveTaskId && taskStore) { + await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); + } + + return await originalSendRequest(request, resultSchema, requestOptions); + }; + } + + private handleResponse(response: JSONRPCResponse | JSONRPCErrorResponse): boolean { + const messageId = Number(response.id); + const resolver = this._requestResolvers.get(messageId); + if (resolver) { + this._requestResolvers.delete(messageId); + if (isJSONRPCResultResponse(response)) { + resolver(response); + } else { + resolver(new ProtocolError(response.error.code, response.error.message, response.error.data)); + } + return true; + } + return false; + } + + private shouldPreserveProgressHandler(response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number): boolean { + if (isJSONRPCResultResponse(response) && response.result && typeof response.result === 'object') { + const result = response.result as Record; + if (result.task && typeof result.task === 'object') { + const task = result.task as Record; + if (typeof task.taskId === 'string') { + this._taskProgressTokens.set(task.taskId, messageId); + return true; + } + } + } + return false; + } + + private async routeNotification(notification: Notification, options?: NotificationOptions): Promise { + const relatedTaskId = options?.relatedTask?.taskId; + if (!relatedTaskId) return false; + + const jsonrpcNotification: JSONRPCNotification = { + ...notification, + jsonrpc: '2.0', + params: { + ...notification.params, + _meta: { + ...notification.params?._meta, + [RELATED_TASK_META_KEY]: options!.relatedTask + } + } + }; + + await this._enqueueTaskMessage(relatedTaskId, { + type: 'notification', + message: jsonrpcNotification, + timestamp: Date.now() + }); + + return true; + } + + private async routeResponse( + relatedTaskId: string | undefined, + message: JSONRPCResponse | JSONRPCErrorResponse, + sessionId?: string + ): Promise { + if (!relatedTaskId || !this._taskMessageQueue) return false; + + await (isJSONRPCErrorResponse(message) + ? this._enqueueTaskMessage(relatedTaskId, { type: 'error', message, timestamp: Date.now() }, sessionId) + : this._enqueueTaskMessage( + relatedTaskId, + { type: 'response', message: message as JSONRPCResultResponse, timestamp: Date.now() }, + sessionId + )); + return true; + } + + private createRequestTaskStore(request?: JSONRPCRequest, sessionId?: string): RequestTaskStore { + const taskStore = this._requireTaskStore; + const host = this._host; + + return { + createTask: async taskParams => { + if (!request) throw new Error('No request provided'); + return await taskStore.createTask(taskParams, request.id, { method: request.method, params: request.params }, sessionId); + }, + getTask: async taskId => { + const task = await taskStore.getTask(taskId, sessionId); + if (!task) throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + return task; + }, + storeTaskResult: async (taskId, status, result) => { + await taskStore.storeTaskResult(taskId, status, result, sessionId); + const task = await taskStore.getTask(taskId, sessionId); + if (task) { + const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ + method: 'notifications/tasks/status', + params: task + }); + await host?.notification(notification as Notification); + if (isTerminal(task.status)) { + this._cleanupTaskProgressHandler(taskId); + } + } + }, + getTaskResult: taskId => taskStore.getTaskResult(taskId, sessionId), + updateTaskStatus: async (taskId, status, statusMessage) => { + const task = await taskStore.getTask(taskId, sessionId); + if (!task) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task "${taskId}" not found - it may have been cleaned up`); + } + if (isTerminal(task.status)) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` + ); + } + await taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); + const updatedTask = await taskStore.getTask(taskId, sessionId); + if (updatedTask) { + const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ + method: 'notifications/tasks/status', + params: updatedTask + }); + await host?.notification(notification as Notification); + if (isTerminal(updatedTask.status)) { + this._cleanupTaskProgressHandler(taskId); + } + } + }, + listTasks: cursor => taskStore.listTasks(cursor, sessionId) + }; + } + + // -- Consolidated lifecycle methods (called by Protocol) -- + + processInboundRequest(request: JSONRPCRequest, ctx: InboundRequestContext): InboundRequestResult { + const taskInfo = this.extractInboundTaskContext(request, ctx.sessionId); + const relatedTaskId = taskInfo?.relatedTaskId; + + const sendNotification = relatedTaskId + ? this.wrapSendNotification(relatedTaskId, ctx.sendNotification) + : (notification: Notification) => ctx.sendNotification(notification); + + const sendRequest = taskInfo?.taskContext + ? this.wrapSendRequest(relatedTaskId ?? '', taskInfo.taskContext.store, ctx.sendRequest) + : ctx.sendRequest; + + return { + taskContext: taskInfo?.taskContext, + sendNotification, + sendRequest, + routeResponse: async (message: JSONRPCResponse | JSONRPCErrorResponse) => { + if (relatedTaskId) { + return this.routeResponse(relatedTaskId, message, ctx.sessionId); + } + return false; + }, + hasTaskCreationParams: !!taskInfo?.taskCreationParams + }; + } + + processOutboundRequest( + jsonrpcRequest: JSONRPCRequest, + options: RequestOptions | undefined, + messageId: number, + responseHandler: (response: JSONRPCResultResponse | Error) => void, + onError: (error: unknown) => void + ): { queued: boolean } { + const queued = this.prepareOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, onError); + return { queued }; + } + + processInboundResponse( + response: JSONRPCResponse | JSONRPCErrorResponse, + messageId: number + ): { consumed: boolean; preserveProgress: boolean } { + const consumed = this.handleResponse(response); + if (consumed) { + return { consumed: true, preserveProgress: false }; + } + const preserveProgress = this.shouldPreserveProgressHandler(response, messageId); + return { consumed: false, preserveProgress }; + } + + async processOutboundNotification( + notification: Notification, + options?: NotificationOptions + ): Promise<{ queued: boolean; jsonrpcNotification?: JSONRPCNotification }> { + // Try queuing first + const queued = await this.routeNotification(notification, options); + if (queued) return { queued: true }; + + // Build JSONRPC notification with optional relatedTask metadata + let jsonrpcNotification: JSONRPCNotification = { ...notification, jsonrpc: '2.0' }; + if (options?.relatedTask) { + jsonrpcNotification = { + ...jsonrpcNotification, + params: { + ...jsonrpcNotification.params, + _meta: { + ...jsonrpcNotification.params?._meta, + [RELATED_TASK_META_KEY]: options.relatedTask + } + } + }; + } + return { queued: false, jsonrpcNotification }; + } + + onClose(): void { + this._taskProgressTokens.clear(); + } + + // -- Private helpers -- + + private async _enqueueTaskMessage(taskId: string, message: QueuedMessage, sessionId?: string): Promise { + if (!this._taskStore || !this._taskMessageQueue) { + throw new Error('Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); + } + await this._taskMessageQueue.enqueue(taskId, message, sessionId, this._options.maxTaskQueueSize); + } + + private async _clearTaskQueue(taskId: string, sessionId?: string): Promise { + if (this._taskMessageQueue) { + const messages = await this._taskMessageQueue.dequeueAll(taskId, sessionId); + for (const message of messages) { + if (message.type === 'request' && isJSONRPCRequest(message.message)) { + const requestId = message.message.id as RequestId; + const resolver = this._requestResolvers.get(requestId); + if (resolver) { + resolver(new ProtocolError(ProtocolErrorCode.InternalError, 'Task cancelled or completed')); + this._requestResolvers.delete(requestId); + } else { + this._host?.reportError(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); + } + } + } + } + } + + private async _waitForTaskUpdate(taskId: string, signal: AbortSignal): Promise { + let interval = this._options.defaultTaskPollInterval ?? 1000; + try { + const task = await this._requireTaskStore.getTask(taskId); + if (task?.pollInterval) interval = task.pollInterval; + } catch { + // Use default interval + } + + return new Promise((resolve, reject) => { + if (signal.aborted) { + reject(new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Request cancelled')); + return; + } + const timeoutId = setTimeout(resolve, interval); + signal.addEventListener( + 'abort', + () => { + clearTimeout(timeoutId); + reject(new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Request cancelled')); + }, + { once: true } + ); + }); + } + + private _cleanupTaskProgressHandler(taskId: string): void { + const progressToken = this._taskProgressTokens.get(taskId); + if (progressToken !== undefined) { + this._host?.removeProgressHandler(progressToken); + this._taskProgressTokens.delete(taskId); + } + } +} + +/** + * No-op TaskManager used when task support is not configured. + */ +export class NullTaskManager extends TaskManager { + constructor() { + super({}); + } + + override processInboundRequest(_request: JSONRPCRequest, ctx: InboundRequestContext): InboundRequestResult { + return { + taskContext: undefined, + sendNotification: (notification: Notification) => ctx.sendNotification(notification), + sendRequest: ctx.sendRequest, + routeResponse: async () => false, + hasTaskCreationParams: false + }; + } + + // processOutboundRequest is inherited - it handles task/relatedTask augmentation + // and only queues if relatedTask is set (which won't happen without a task store) + + // processInboundResponse is inherited - it checks _requestResolvers (empty for NullTaskManager) + // and _taskProgressTokens (empty for NullTaskManager) + + override async processOutboundNotification( + notification: Notification, + _options?: NotificationOptions + ): Promise<{ queued: boolean; jsonrpcNotification?: JSONRPCNotification }> { + return { queued: false, jsonrpcNotification: { ...notification, jsonrpc: '2.0' } }; + } + + override onClose(): void { + // No-op + } +} diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index 4d8f1830f..2df944a15 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -35,16 +35,18 @@ import { SdkError, SdkErrorCode } from '../../src/errors/sdkErrors.js'; // Type helper for accessing private/protected Protocol properties in tests interface TestProtocol { - _taskMessageQueue?: TaskMessageQueue; - _requestResolvers: Map void>; _responseHandlers: Map void>; - _taskProgressTokens: Map; - _clearTaskQueue: (taskId: string, sessionId?: string) => Promise; - requestTaskStore: (request: Request, authInfo: unknown) => TaskStore; - // Protected task methods (exposed for testing) - listTasks: (params?: { cursor?: string }) => Promise<{ tasks: Task[]; nextCursor?: string }>; - cancelTask: (params: { taskId: string }) => Promise; - requestStream: (request: Request, schema: ZodType, options?: unknown) => AsyncGenerator>; + // TaskManager is accessible via the `tasks` getter + tasks: { + _taskMessageQueue?: TaskMessageQueue; + _requestResolvers: Map void>; + _taskProgressTokens: Map; + _clearTaskQueue: (taskId: string, sessionId?: string) => Promise; + // createRequestTaskStore is private; use inbound request handler flow instead + listTasks: (params?: { cursor?: string }) => Promise<{ tasks: Task[]; nextCursor?: string }>; + cancelTask: (params: { taskId: string }) => Promise; + requestStream: (request: Request, schema: ZodType, options?: unknown) => AsyncGenerator>; + }; } // Mock Transport class @@ -162,10 +164,10 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); }); @@ -639,10 +641,10 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced_with_params'] }); await protocol.connect(transport); @@ -665,10 +667,10 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced_with_options'] }); await protocol.connect(transport); @@ -689,10 +691,10 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -718,10 +720,10 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -750,10 +752,10 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -780,10 +782,10 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method await protocol.connect(transport); @@ -818,10 +820,10 @@ describe('protocol tests', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -1040,11 +1042,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() } }); }); describe('request with task metadata', () => { @@ -1182,7 +1184,7 @@ describe('Task-based execution', () => { expect(sendSpy).not.toHaveBeenCalled(); // Verify the message was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); }); @@ -1206,7 +1208,7 @@ describe('Task-based execution', () => { expect(sendSpy).not.toHaveBeenCalled(); // Verify the message was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('parent-task-456'); @@ -1253,7 +1255,7 @@ describe('Task-based execution', () => { expect(sendSpy).not.toHaveBeenCalled(); // Verify the message was queued with all metadata combined - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('parent-task'); @@ -1290,11 +1292,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); await protocol.connect(transport); @@ -1362,11 +1364,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); await protocol.connect(transport); @@ -1426,11 +1428,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); await protocol.connect(transport); @@ -1475,11 +1477,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); await protocol.connect(transport); @@ -1511,11 +1513,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); await protocol.connect(transport); @@ -1544,7 +1546,7 @@ describe('Task-based execution', () => { it('should call listTasks method from client side', async () => { await protocol.connect(transport); - const listTasksPromise = (protocol as unknown as TestProtocol).listTasks(); + const listTasksPromise = (protocol as unknown as TestProtocol).tasks!.listTasks(); // Simulate server response setTimeout(() => { @@ -1584,7 +1586,7 @@ describe('Task-based execution', () => { it('should call listTasks with cursor from client side', async () => { await protocol.connect(transport); - const listTasksPromise = (protocol as unknown as TestProtocol).listTasks({ cursor: 'task-10' }); + const listTasksPromise = (protocol as unknown as TestProtocol).tasks!.listTasks({ cursor: 'task-10' }); // Simulate server response setTimeout(() => { @@ -1648,11 +1650,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1693,11 +1695,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1744,11 +1746,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1779,7 +1781,7 @@ describe('Task-based execution', () => { it('should call cancelTask method from client side', async () => { await protocol.connect(transport); - const deleteTaskPromise = (protocol as unknown as TestProtocol).cancelTask({ taskId: 'task-to-delete' }); + const deleteTaskPromise = (protocol as unknown as TestProtocol).tasks!.cancelTask({ taskId: 'task-to-delete' }); // Simulate server response - per MCP spec, CancelTaskResult is Result & Task setTimeout(() => { @@ -1829,11 +1831,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); const serverTransport = new MockTransport(); await serverProtocol.connect(serverTransport); @@ -1882,11 +1884,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1934,11 +1936,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1974,11 +1976,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -2022,11 +2024,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -2068,11 +2070,11 @@ describe('Task-based execution', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -2124,7 +2126,7 @@ describe('Task-based execution', () => { // Verify the notification was QUEUED (not sent via transport) // Messages with relatedTask metadata should be queued for delivery via tasks/result // to prevent duplicate delivery for bidirectional transports - const queue = (serverProtocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (serverProtocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('parent-task-123'); @@ -2154,11 +2156,11 @@ describe('Request Cancellation vs Task Cancellation', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + })({ tasks: { taskStore } }); }); describe('notifications/cancelled behavior', () => { @@ -2442,10 +2444,10 @@ describe('Progress notification support for tasks', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); }); @@ -2456,11 +2458,11 @@ describe('Progress notification support for tasks', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + })({ tasks: { taskStore } }); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2552,11 +2554,11 @@ describe('Progress notification support for tasks', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + })({ tasks: { taskStore } }); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2648,15 +2650,22 @@ describe('Progress notification support for tasks', () => { expect(progressCallback).toHaveBeenCalledTimes(1); // Verify the task-progress association was created - const taskProgressTokens = (protocol as unknown as TestProtocol)._taskProgressTokens as Map; + const taskProgressTokens = (protocol as unknown as TestProtocol).tasks!._taskProgressTokens as Map; expect(taskProgressTokens.has(taskId)).toBe(true); expect(taskProgressTokens.get(taskId)).toBe(progressToken); - // Simulate task completion by calling through the protocol's task store - // This will trigger the cleanup logic - const mockRequest = { jsonrpc: '2.0' as const, id: 999, method: 'test', params: {} }; - const requestTaskStore = (protocol as unknown as TestProtocol).requestTaskStore(mockRequest, undefined); - await requestTaskStore.storeTaskResult(taskId, 'completed', { content: [] }); + // Simulate task completion by triggering an inbound request whose handler + // calls storeTaskResult through the task context (the public RequestTaskStore API). + // This is equivalent to how a real server handler would complete a task. + protocol.setRequestHandler('ping', async (_request, ctx) => { + if (ctx.task?.store) { + await ctx.task.store.storeTaskResult(taskId, 'completed', { content: [] }); + } + return {}; + }); + if (transport.onmessage) { + transport.onmessage({ jsonrpc: '2.0', id: 999, method: 'ping', params: {} }); + } // Wait for all async operations including notification sending to complete await new Promise(resolve => setTimeout(resolve, 50)); @@ -2691,11 +2700,11 @@ describe('Progress notification support for tasks', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + })({ tasks: { taskStore } }); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2792,11 +2801,11 @@ describe('Progress notification support for tasks', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + })({ tasks: { taskStore } }); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2890,11 +2899,11 @@ describe('Progress notification support for tasks', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + })({ tasks: { taskStore } }); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -3166,7 +3175,7 @@ describe('Message interception for task-related notifications', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); await server.connect(transport); @@ -3185,7 +3194,7 @@ describe('Message interception for task-related notifications', () => { ); // Access the private queue to verify the message was queued - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(task.taskId); @@ -3206,7 +3215,7 @@ describe('Message interception for task-related notifications', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); await server.connect(transport); @@ -3236,7 +3245,7 @@ describe('Message interception for task-related notifications', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); + })({ tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 } }); await server.connect(transport); @@ -3286,7 +3295,7 @@ describe('Message interception for task-related notifications', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); await server.connect(transport); @@ -3304,7 +3313,7 @@ describe('Message interception for task-related notifications', () => { ); // Verify the message was queued under the correct task ID - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(taskId); expect(queuedMessage).toBeDefined(); @@ -3322,7 +3331,7 @@ describe('Message interception for task-related notifications', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); await server.connect(transport); @@ -3343,7 +3352,7 @@ describe('Message interception for task-related notifications', () => { } // Verify messages are in FIFO order - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); for (let i = 0; i < 5; i++) { @@ -3367,7 +3376,7 @@ describe('Message interception for task-related requests', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); await server.connect(transport); @@ -3387,7 +3396,7 @@ describe('Message interception for task-related requests', () => { ); // Access the private queue to verify the message was queued - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(task.taskId); @@ -3397,7 +3406,7 @@ describe('Message interception for task-related requests', () => { // Verify resolver is stored in _requestResolvers map (not in the message) const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; - const resolvers = (server as unknown as TestProtocol)._requestResolvers; + const resolvers = (server as unknown as TestProtocol).tasks!._requestResolvers; expect(resolvers.has(requestId)).toBe(true); // Clean up - send a response to prevent hanging promise @@ -3422,7 +3431,7 @@ describe('Message interception for task-related requests', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); await server.connect(transport); @@ -3436,7 +3445,7 @@ describe('Message interception for task-related requests', () => { ); // Verify queue exists (but we don't track size in the new API) - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Clean up - send a response @@ -3464,7 +3473,7 @@ describe('Message interception for task-related requests', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); await server.connect(transport); @@ -3484,11 +3493,11 @@ describe('Message interception for task-related requests', () => { ); // Verify the resolver was stored - const resolvers = (server as unknown as TestProtocol)._requestResolvers; + const resolvers = (server as unknown as TestProtocol).tasks!._requestResolvers; expect(resolvers.size).toBe(1); // Get the request ID from the queue - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocol).tasks!._taskMessageQueue; const queuedMessage = await queue!.dequeue(task.taskId); const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; @@ -3520,7 +3529,7 @@ describe('Message interception for task-related requests', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: queue }); + })({ tasks: { taskStore, taskMessageQueue: queue } }); await server.connect(transport); @@ -3588,7 +3597,7 @@ describe('Message interception for task-related requests', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); const errors: Error[] = []; server.onerror = (error: Error) => { @@ -3613,12 +3622,12 @@ describe('Message interception for task-related requests', () => { ); // Get the request ID from the queue - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocol).tasks!._taskMessageQueue; const queuedMessage = await queue!.dequeue(task.taskId); const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; // Manually delete the resolver to simulate missing resolver - (server as unknown as TestProtocol)._requestResolvers.delete(requestId); + (server as unknown as TestProtocol).tasks!._requestResolvers.delete(requestId); // Enqueue a response message - this should trigger the error logging when processed await queue!.enqueue(task.taskId, { @@ -3667,7 +3676,7 @@ describe('Message interception for task-related requests', () => { return ctx; } protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); + })({ tasks: { taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 } }); await server.connect(transport); @@ -3727,11 +3736,11 @@ describe('Message Interception', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); }); describe('messages with relatedTask metadata are queued', () => { @@ -3752,7 +3761,7 @@ describe('Message Interception', () => { ); // Access the private _taskMessageQueue to verify the message was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('task-123'); @@ -3780,7 +3789,7 @@ describe('Message Interception', () => { ); // Access the private _taskMessageQueue to verify the message was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('task-456'); @@ -3789,7 +3798,7 @@ describe('Message Interception', () => { // Verify resolver is stored in _requestResolvers map (not in the message) const requestId = queuedMessage.message.id as RequestId; - const resolvers = (protocol as unknown as TestProtocol)._requestResolvers; + const resolvers = (protocol as unknown as TestProtocol).tasks!._requestResolvers; expect(resolvers.has(requestId)).toBe(true); // Clean up the pending request @@ -3829,7 +3838,7 @@ describe('Message Interception', () => { await new Promise(resolve => setTimeout(resolve, 50)); // Verify the response was queued instead of sent directly - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(taskId); @@ -3867,7 +3876,7 @@ describe('Message Interception', () => { await new Promise(resolve => setTimeout(resolve, 50)); // Verify the error was queued instead of sent directly - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(taskId); @@ -3901,7 +3910,7 @@ describe('Message Interception', () => { await new Promise(resolve => setTimeout(resolve, 50)); // Verify the error was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(taskId); @@ -3958,7 +3967,7 @@ describe('Message Interception', () => { // Access the private _taskMessageQueue to verify no messages were queued // Since we can't check if queues exist without messages, we verify that // attempting to dequeue returns undefined (no messages queued) - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); }); @@ -3980,7 +3989,7 @@ describe('Message Interception', () => { // Access the private _taskMessageQueue to verify no messages were queued // Since we can't check if queues exist without messages, we verify that // attempting to dequeue returns undefined (no messages queued) - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Clean up the pending request @@ -4014,7 +4023,7 @@ describe('Message Interception', () => { ); // Verify the message was queued under the correct task ID - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Verify a message was queued for this task @@ -4044,7 +4053,7 @@ describe('Message Interception', () => { ); // Verify the message was queued under the correct task ID - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Clean up the pending request @@ -4068,7 +4077,7 @@ describe('Message Interception', () => { await protocol.notification({ method: 'test3', params: {} }, { relatedTask: { taskId: 'task-A' } }); // Verify messages are queued under correct task IDs - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Verify two messages for task-A @@ -4091,7 +4100,7 @@ describe('Message Interception', () => { it('should queue messages for a task', async () => { await protocol.connect(transport); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Send first message for a task @@ -4106,7 +4115,7 @@ describe('Message Interception', () => { it('should queue multiple messages for the same task', async () => { await protocol.connect(transport); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Send first message @@ -4127,7 +4136,7 @@ describe('Message Interception', () => { it('should queue messages for different tasks separately', async () => { await protocol.connect(transport); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Send messages for different tasks @@ -4158,7 +4167,7 @@ describe('Message Interception', () => { { relatedTask } ); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; const queuedMessage = await queue!.dequeue('task-meta-123'); // Verify the metadata is preserved in the queued message @@ -4183,7 +4192,7 @@ describe('Message Interception', () => { { relatedTask } ); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; const queuedMessage = await queue!.dequeue('task-meta-456'); // Verify the metadata is preserved in the queued message @@ -4220,7 +4229,7 @@ describe('Message Interception', () => { } ); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; const queuedMessage = await queue!.dequeue('task-preserve-meta'); // Verify both existing and new metadata are preserved @@ -4248,11 +4257,11 @@ describe('Queue lifecycle management', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + })({ tasks: { taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() } }); }); describe('queue cleanup on task completion', () => { @@ -4268,7 +4277,7 @@ describe('Queue lifecycle management', () => { await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); // Verify messages are queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Verify messages can be dequeued @@ -4278,7 +4287,7 @@ describe('Queue lifecycle management', () => { expect(msg2).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + (protocol as unknown as TestProtocol).tasks!._clearTaskQueue(taskId); // After cleanup, no more messages should be available const msg3 = await queue!.dequeue(taskId); @@ -4314,7 +4323,7 @@ describe('Queue lifecycle management', () => { await resultPromise; // Verify queue is cleared after delivery (no messages available) - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; const msg = await queue!.dequeue(taskId); expect(msg).toBeUndefined(); }); @@ -4332,7 +4341,7 @@ describe('Queue lifecycle management', () => { await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); // Verify message is queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; const msg1 = await queue!.dequeue(taskId); expect(msg1).toBeDefined(); @@ -4373,7 +4382,7 @@ describe('Queue lifecycle management', () => { .catch(err => err); // Verify request is queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Mock task as non-terminal @@ -4414,7 +4423,7 @@ describe('Queue lifecycle management', () => { await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); // Verify messages are queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Verify messages can be dequeued @@ -4424,7 +4433,7 @@ describe('Queue lifecycle management', () => { expect(msg2).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + (protocol as unknown as TestProtocol).tasks!._clearTaskQueue(taskId); // After cleanup, no more messages should be available const msg3 = await queue!.dequeue(taskId); @@ -4446,11 +4455,11 @@ describe('Queue lifecycle management', () => { .catch(err => err); // Verify request is queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + (protocol as unknown as TestProtocol).tasks!._clearTaskQueue(taskId); // Verify the request promise is rejected const result = await requestPromise; @@ -4491,11 +4500,11 @@ describe('Queue lifecycle management', () => { .catch(err => err); // Verify requests are queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocol).tasks!._taskMessageQueue; expect(queue).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + (protocol as unknown as TestProtocol).tasks!._clearTaskQueue(taskId); // Verify all request promises are rejected const result1 = await request1Promise; @@ -4529,7 +4538,7 @@ describe('Queue lifecycle management', () => { .catch(err => err); // Get the request ID that was sent - const requestResolvers = (protocol as unknown as TestProtocol)._requestResolvers; + const requestResolvers = (protocol as unknown as TestProtocol).tasks!._requestResolvers; const initialResolverCount = requestResolvers.size; expect(initialResolverCount).toBeGreaterThan(0); @@ -4538,7 +4547,7 @@ describe('Queue lifecycle management', () => { mockTaskStore.getTask.mockResolvedValue(completedTask); // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + (protocol as unknown as TestProtocol).tasks!._clearTaskQueue(taskId); // Verify request promise is rejected const result = await requestPromise; @@ -4565,17 +4574,17 @@ describe('requestStream() method', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); await protocol.connect(transport); // Start the request stream const streamPromise = (async () => { const messages = []; - const stream = (protocol as unknown as TestProtocol).requestStream( + const stream = (protocol as unknown as TestProtocol).tasks!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ); @@ -4611,17 +4620,17 @@ describe('requestStream() method', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); await protocol.connect(transport); // Start the request stream const streamPromise = (async () => { const messages = []; - const stream = (protocol as unknown as TestProtocol).requestStream( + const stream = (protocol as unknown as TestProtocol).tasks!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ); @@ -4660,10 +4669,10 @@ describe('requestStream() method', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); await protocol.connect(transport); @@ -4674,7 +4683,7 @@ describe('requestStream() method', () => { // Start the request stream with already-aborted signal const messages = []; - const stream = (protocol as unknown as TestProtocol).requestStream( + const stream = (protocol as unknown as TestProtocol).tasks!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { @@ -4701,15 +4710,15 @@ describe('requestStream() method', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); await protocol.connect(transport); const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol).requestStream( + (protocol as unknown as TestProtocol).tasks!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ) @@ -4746,15 +4755,15 @@ describe('requestStream() method', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); await protocol.connect(transport); const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol).requestStream( + (protocol as unknown as TestProtocol).tasks!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { @@ -4788,10 +4797,10 @@ describe('requestStream() method', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); await protocol.connect(transport); @@ -4800,7 +4809,7 @@ describe('requestStream() method', () => { // Collect messages const messages = await toArrayAsync( - (protocol as unknown as TestProtocol).requestStream( + (protocol as unknown as TestProtocol).tasks!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { @@ -4824,15 +4833,15 @@ describe('requestStream() method', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); await protocol.connect(transport); const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol).requestStream( + (protocol as unknown as TestProtocol).tasks!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ) @@ -4879,15 +4888,15 @@ describe('requestStream() method', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + })({ tasks: { taskStore: mockTaskStore } }); await protocol.connect(transport); const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol).requestStream( + (protocol as unknown as TestProtocol).tasks!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ) @@ -4942,10 +4951,10 @@ describe('requestStream() method', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); await protocol.connect(transport); @@ -4953,7 +4962,7 @@ describe('requestStream() method', () => { transport.send = vi.fn().mockRejectedValue(new Error('Network error')); const messages = await toArrayAsync( - (protocol as unknown as TestProtocol).requestStream( + (protocol as unknown as TestProtocol).tasks!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ) @@ -4973,15 +4982,15 @@ describe('requestStream() method', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); await protocol.connect(transport); const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol).requestStream( + (protocol as unknown as TestProtocol).tasks!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ) @@ -5037,9 +5046,11 @@ describe('Error handling for missing resolvers', () => { } protected assertTaskHandlerCapability(_method: string): void {} })({ - taskStore, - taskMessageQueue, - defaultTaskPollInterval: 100 + tasks: { + taskStore, + taskMessageQueue, + defaultTaskPollInterval: 100 + } }); // @ts-expect-error deliberately overriding error handler with mock @@ -5077,7 +5088,7 @@ describe('Error handling for missing resolvers', () => { if (queuedMessage && queuedMessage.type === 'response') { const responseMessage = queuedMessage.message as JSONRPCResultResponse; const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol.tasks!._requestResolvers.get(requestId); if (!resolver) { // This simulates what happens in the actual handler @@ -5154,7 +5165,7 @@ describe('Error handling for missing resolvers', () => { // Clear the task queue (simulating cancellation) const testProtocol = protocol as unknown as TestProtocol; - await testProtocol._clearTaskQueue(task.taskId); + await testProtocol.tasks!._clearTaskQueue(task.taskId); // Verify error was logged for missing resolver expect(errorHandler).toHaveBeenCalledWith( @@ -5175,7 +5186,7 @@ describe('Error handling for missing resolvers', () => { // Store a resolver const testProtocol = protocol as unknown as TestProtocol; - testProtocol._requestResolvers.set(requestId, resolverMock); + testProtocol.tasks!._requestResolvers.set(requestId, resolverMock); // Enqueue a request await taskMessageQueue.enqueue(task.taskId, { @@ -5190,7 +5201,7 @@ describe('Error handling for missing resolvers', () => { }); // Clear the task queue - await testProtocol._clearTaskQueue(task.taskId); + await testProtocol.tasks!._clearTaskQueue(task.taskId); // Verify resolver was called with cancellation error expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); @@ -5201,7 +5212,7 @@ describe('Error handling for missing resolvers', () => { expect(calledError.message).toContain('Task cancelled or completed'); // Verify resolver was removed - expect(testProtocol._requestResolvers.has(requestId)).toBe(false); + expect(testProtocol.tasks!._requestResolvers.has(requestId)).toBe(false); }); it('should handle mixed messages during cleanup', async () => { @@ -5215,7 +5226,7 @@ describe('Error handling for missing resolvers', () => { // Enqueue multiple messages: request with resolver, request without, notification const requestId1 = 42; const resolverMock = vi.fn(); - testProtocol._requestResolvers.set(requestId1, resolverMock); + testProtocol.tasks!._requestResolvers.set(requestId1, resolverMock); await taskMessageQueue.enqueue(task.taskId, { type: 'request', @@ -5250,7 +5261,7 @@ describe('Error handling for missing resolvers', () => { }); // Clear the task queue - await testProtocol._clearTaskQueue(task.taskId); + await testProtocol.tasks!._clearTaskQueue(task.taskId); // Verify resolver was called for first request expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); @@ -5331,7 +5342,7 @@ describe('Error handling for missing resolvers', () => { const testProtocol = protocol as unknown as TestProtocol; const responseMessage = msg.message as JSONRPCResultResponse; const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol.tasks!._requestResolvers.get(requestId); if (!resolver) { protocol.onerror?.(new Error(`Response handler missing for request ${requestId}`)); } @@ -5360,7 +5371,7 @@ describe('Error handling for missing resolvers', () => { const testProtocol = protocol as unknown as TestProtocol; // This should not throw - await expect(testProtocol._clearTaskQueue(task.taskId)).resolves.not.toThrow(); + await expect(testProtocol.tasks!._clearTaskQueue(task.taskId)).resolves.not.toThrow(); }); }); @@ -5374,7 +5385,7 @@ describe('Error handling for missing resolvers', () => { // Store a resolver const testProtocol = protocol as unknown as TestProtocol; - testProtocol._requestResolvers.set(requestId, resolverMock); + testProtocol.tasks!._requestResolvers.set(requestId, resolverMock); // Enqueue an error message await taskMessageQueue.enqueue(task.taskId, { @@ -5399,10 +5410,10 @@ describe('Error handling for missing resolvers', () => { if (queuedMessage && queuedMessage.type === 'error') { const errorMessage = queuedMessage.message as JSONRPCErrorResponse; const reqId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(reqId); + const resolver = testProtocol.tasks!._requestResolvers.get(reqId); if (resolver) { - testProtocol._requestResolvers.delete(reqId); + testProtocol.tasks!._requestResolvers.delete(reqId); const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } @@ -5415,7 +5426,7 @@ describe('Error handling for missing resolvers', () => { expect(calledError.message).toContain('Invalid request parameters'); // Verify resolver was removed from map - expect(testProtocol._requestResolvers.has(requestId)).toBe(false); + expect(testProtocol.tasks!._requestResolvers.has(requestId)).toBe(false); }); it('should log error for unknown request ID in error messages', async () => { @@ -5447,7 +5458,7 @@ describe('Error handling for missing resolvers', () => { const testProtocol = protocol as unknown as TestProtocol; const errorMessage = queuedMessage.message as JSONRPCErrorResponse; const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol.tasks!._requestResolvers.get(requestId); if (!resolver) { protocol.onerror?.(new Error(`Error handler missing for request ${requestId}`)); @@ -5471,7 +5482,7 @@ describe('Error handling for missing resolvers', () => { // Store a resolver const testProtocol = protocol as unknown as TestProtocol; - testProtocol._requestResolvers.set(requestId, resolverMock); + testProtocol.tasks!._requestResolvers.set(requestId, resolverMock); // Enqueue an error message with data field await taskMessageQueue.enqueue(task.taskId, { @@ -5494,10 +5505,10 @@ describe('Error handling for missing resolvers', () => { if (queuedMessage && queuedMessage.type === 'error') { const errorMessage = queuedMessage.message as JSONRPCErrorResponse; const reqId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(reqId); + const resolver = testProtocol.tasks!._requestResolvers.get(reqId); if (resolver) { - testProtocol._requestResolvers.delete(reqId); + testProtocol.tasks!._requestResolvers.delete(reqId); const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } @@ -5536,7 +5547,7 @@ describe('Error handling for missing resolvers', () => { const testProtocol = protocol as unknown as TestProtocol; const errorMessage = msg.message as JSONRPCErrorResponse; const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol.tasks!._requestResolvers.get(requestId); if (!resolver) { protocol.onerror?.(new Error(`Error handler missing for request ${requestId}`)); } @@ -5559,9 +5570,9 @@ describe('Error handling for missing resolvers', () => { const resolver2 = vi.fn(); const resolver3 = vi.fn(); - testProtocol._requestResolvers.set(1, resolver1); - testProtocol._requestResolvers.set(2, resolver2); - testProtocol._requestResolvers.set(3, resolver3); + testProtocol.tasks!._requestResolvers.set(1, resolver1); + testProtocol.tasks!._requestResolvers.set(2, resolver2); + testProtocol.tasks!._requestResolvers.set(3, resolver3); // Enqueue mixed messages: response, error, response await taskMessageQueue.enqueue(task.taskId, { @@ -5603,17 +5614,17 @@ describe('Error handling for missing resolvers', () => { if (msg.type === 'response') { const responseMessage = msg.message as JSONRPCResultResponse; const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol.tasks!._requestResolvers.get(requestId); if (resolver) { - testProtocol._requestResolvers.delete(requestId); + testProtocol.tasks!._requestResolvers.delete(requestId); resolver(responseMessage); } } else if (msg.type === 'error') { const errorMessage = msg.message as JSONRPCErrorResponse; const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol.tasks!._requestResolvers.get(requestId); if (resolver) { - testProtocol._requestResolvers.delete(requestId); + testProtocol.tasks!._requestResolvers.delete(requestId); const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } @@ -5631,7 +5642,7 @@ describe('Error handling for missing resolvers', () => { expect(error.message).toContain('Request failed'); // Verify all resolvers were removed - expect(testProtocol._requestResolvers.size).toBe(0); + expect(testProtocol.tasks!._requestResolvers.size).toBe(0); }); it('should maintain FIFO order when processing responses and errors', async () => { @@ -5645,9 +5656,9 @@ describe('Error handling for missing resolvers', () => { const resolver2 = vi.fn(() => callOrder.push(2)); const resolver3 = vi.fn(() => callOrder.push(3)); - testProtocol._requestResolvers.set(1, resolver1); - testProtocol._requestResolvers.set(2, resolver2); - testProtocol._requestResolvers.set(3, resolver3); + testProtocol.tasks!._requestResolvers.set(1, resolver1); + testProtocol.tasks!._requestResolvers.set(2, resolver2); + testProtocol.tasks!._requestResolvers.set(3, resolver3); // Enqueue in specific order await taskMessageQueue.enqueue(task.taskId, { @@ -5678,17 +5689,17 @@ describe('Error handling for missing resolvers', () => { if (msg.type === 'response') { const responseMessage = msg.message as JSONRPCResultResponse; const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol.tasks!._requestResolvers.get(requestId); if (resolver) { - testProtocol._requestResolvers.delete(requestId); + testProtocol.tasks!._requestResolvers.delete(requestId); resolver(responseMessage); } } else if (msg.type === 'error') { const errorMessage = msg.message as JSONRPCErrorResponse; const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol.tasks!._requestResolvers.get(requestId); if (resolver) { - testProtocol._requestResolvers.delete(requestId); + testProtocol.tasks!._requestResolvers.delete(requestId); const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } diff --git a/packages/core/test/shared/protocolTransportHandling.test.ts b/packages/core/test/shared/protocolTransportHandling.test.ts index adc7e2234..d62f70a63 100644 --- a/packages/core/test/shared/protocolTransportHandling.test.ts +++ b/packages/core/test/shared/protocolTransportHandling.test.ts @@ -39,10 +39,10 @@ describe('Protocol transport handling bug', () => { protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); transportA = new MockTransport('A'); diff --git a/packages/server/src/experimental/tasks/server.ts b/packages/server/src/experimental/tasks/server.ts index 813f4cf28..23dfa7452 100644 --- a/packages/server/src/experimental/tasks/server.ts +++ b/packages/server/src/experimental/tasks/server.ts @@ -6,6 +6,7 @@ */ import type { + AnyObjectSchema, AnySchema, CancelTaskResult, GetTaskResult, @@ -48,20 +49,16 @@ export class ExperimentalServerTasks { * * @experimental */ - requestStream( + requestStream( request: Request, resultSchema: T, options?: RequestOptions ): AsyncGenerator & Result>, void, void> { - // Delegate to the server's underlying Protocol method - type ServerWithRequestStream = { - requestStream( - request: Request, - resultSchema: U, - options?: RequestOptions - ): AsyncGenerator & Result>, void, void>; - }; - return (this._server as unknown as ServerWithRequestStream).requestStream(request, resultSchema, options); + return this._server.tasks.requestStream(request, resultSchema, options) as AsyncGenerator< + ResponseMessage & Result>, + void, + void + >; } /** @@ -74,8 +71,7 @@ export class ExperimentalServerTasks { * @experimental */ async getTask(taskId: string, options?: RequestOptions): Promise { - type ServerWithGetTask = { getTask(params: { taskId: string }, options?: RequestOptions): Promise }; - return (this._server as unknown as ServerWithGetTask).getTask({ taskId }, options); + return this._server.tasks.getTask({ taskId }, options); } /** @@ -89,15 +85,7 @@ export class ExperimentalServerTasks { * @experimental */ async getTaskResult(taskId: string, resultSchema?: T, options?: RequestOptions): Promise> { - return ( - this._server as unknown as { - getTaskResult: ( - params: { taskId: string }, - resultSchema?: U, - options?: RequestOptions - ) => Promise>; - } - ).getTaskResult({ taskId }, resultSchema, options); + return this._server.tasks.getTaskResult({ taskId }, resultSchema!, options); } /** @@ -110,11 +98,7 @@ export class ExperimentalServerTasks { * @experimental */ async listTasks(cursor?: string, options?: RequestOptions): Promise { - return ( - this._server as unknown as { - listTasks: (params?: { cursor?: string }, options?: RequestOptions) => Promise; - } - ).listTasks(cursor ? { cursor } : undefined, options); + return this._server.tasks.listTasks(cursor ? { cursor } : undefined, options); } /** @@ -126,10 +110,6 @@ export class ExperimentalServerTasks { * @experimental */ async cancelTask(taskId: string, options?: RequestOptions): Promise { - return ( - this._server as unknown as { - cancelTask: (params: { taskId: string }, options?: RequestOptions) => Promise; - } - ).cancelTask({ taskId }, options); + return this._server.tasks.cancelTask({ taskId }, options); } } diff --git a/test/integration/test/client/client.test.ts b/test/integration/test/client/client.test.ts index 3c57fc542..723610bb2 100644 --- a/test/integration/test/client/client.test.ts +++ b/test/integration/test/client/client.test.ts @@ -2258,7 +2258,7 @@ describe('Task-based execution', () => { } } }, - taskStore: serverTaskStore + tasks: { taskStore: serverTaskStore } } ); @@ -2334,7 +2334,7 @@ describe('Task-based execution', () => { } } }, - taskStore: serverTaskStore + tasks: { taskStore: serverTaskStore } } ); @@ -2411,7 +2411,7 @@ describe('Task-based execution', () => { } } }, - taskStore: serverTaskStore + tasks: { taskStore: serverTaskStore } } ); @@ -2495,7 +2495,7 @@ describe('Task-based execution', () => { } } }, - taskStore: serverTaskStore + tasks: { taskStore: serverTaskStore } } ); @@ -2599,7 +2599,7 @@ describe('Task-based execution', () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); @@ -2692,7 +2692,7 @@ describe('Task-based execution', () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); @@ -2784,7 +2784,7 @@ describe('Task-based execution', () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); @@ -2875,7 +2875,7 @@ describe('Task-based execution', () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); @@ -2978,7 +2978,7 @@ describe('Task-based execution', () => { } } }, - taskStore: serverTaskStore + tasks: { taskStore: serverTaskStore } } ); @@ -3100,7 +3100,7 @@ describe('Task-based execution', () => { } } }, - taskStore: serverTaskStore + tasks: { taskStore: serverTaskStore } } ); @@ -3147,7 +3147,7 @@ describe('Task-based execution', () => { } } }, - taskStore: serverTaskStore + tasks: { taskStore: serverTaskStore } } ); @@ -3194,7 +3194,7 @@ describe('Task-based execution', () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); @@ -3248,7 +3248,7 @@ test('should respect server task capabilities', async () => { } } }, - taskStore: serverTaskStore + tasks: { taskStore: serverTaskStore } } ); diff --git a/test/integration/test/helpers/mcp.ts b/test/integration/test/helpers/mcp.ts index 5c53c7a92..5d7898372 100644 --- a/test/integration/test/helpers/mcp.ts +++ b/test/integration/test/helpers/mcp.ts @@ -53,8 +53,10 @@ export async function createInMemoryTaskEnvironment(options?: { } } }, - taskStore, - taskMessageQueue: new InMemoryTaskMessageQueue() + tasks: { + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() + } } ); diff --git a/test/integration/test/server.test.ts b/test/integration/test/server.test.ts index 436b4427a..1e81692c8 100644 --- a/test/integration/test/server.test.ts +++ b/test/integration/test/server.test.ts @@ -2134,7 +2134,7 @@ describe('Task-based execution', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -2314,7 +2314,7 @@ describe('Task-based execution', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -2494,7 +2494,7 @@ describe('Task-based execution', () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); @@ -2575,7 +2575,7 @@ describe('Task-based execution', () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); @@ -2654,7 +2654,7 @@ describe('Task-based execution', () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); @@ -2735,7 +2735,7 @@ describe('Task-based execution', () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); @@ -2838,7 +2838,7 @@ describe('Task-based execution', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -2974,7 +2974,7 @@ describe('Task-based execution', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -3021,7 +3021,7 @@ describe('Task-based execution', () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); @@ -3078,7 +3078,7 @@ test('should respect client task capabilities', async () => { } } }, - taskStore: clientTaskStore + tasks: { taskStore: clientTaskStore } } ); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index 58e677922..8be6c8dcf 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1920,7 +1920,7 @@ describe('Zod v4', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -1989,7 +1989,7 @@ describe('Zod v4', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -6337,7 +6337,7 @@ describe('Zod v4', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -6442,7 +6442,7 @@ describe('Zod v4', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -6550,7 +6550,7 @@ describe('Zod v4', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -6670,7 +6670,7 @@ describe('Zod v4', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -6776,7 +6776,7 @@ describe('Zod v4', () => { } } }, - taskStore + tasks: { taskStore } } ); @@ -6877,7 +6877,7 @@ describe('Zod v4', () => { } } }, - taskStore + tasks: { taskStore } } ); diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts index 622cf748c..8081384a9 100644 --- a/test/integration/test/taskLifecycle.test.ts +++ b/test/integration/test/taskLifecycle.test.ts @@ -46,8 +46,10 @@ describe('Task Lifecycle Integration Tests', () => { cancel: {} } }, - taskStore, - taskMessageQueue: new InMemoryTaskMessageQueue() + tasks: { + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() + } } );