diff --git a/backend/src/studio-mcp/__tests__/studio-mcp.service.spec.ts b/backend/src/studio-mcp/__tests__/studio-mcp.service.spec.ts index ad084462..60057234 100644 --- a/backend/src/studio-mcp/__tests__/studio-mcp.service.spec.ts +++ b/backend/src/studio-mcp/__tests__/studio-mcp.service.spec.ts @@ -4,9 +4,9 @@ import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import type { AuthContext } from '../../auth/types'; import type { WorkflowsService } from '../../workflows/workflows.service'; -// Helper to access private _registeredTools on McpServer (plain object at runtime) -type ToolHandler = (...args: unknown[]) => unknown; -type RegisteredToolsMap = Record; +// Helper to access private _registeredTools and experimental tasks on McpServer (plain object at runtime) +type RegisteredToolsMap = Record; + function getRegisteredTools(server: McpServer): RegisteredToolsMap { return (server as unknown as { _registeredTools: RegisteredToolsMap })._registeredTools; } @@ -60,12 +60,11 @@ describe('StudioMcpService Unit Tests', () => { expect(server).toBeInstanceOf(McpServer); }); - it('registers all 9 expected tools', () => { + it('registers all expected tools and tasks', () => { const server = service.createServer(mockAuthContext); const registeredTools = getRegisteredTools(server); expect(registeredTools).toBeDefined(); - expect(Object.keys(registeredTools).length).toBe(9); const toolNames = Object.keys(registeredTools).sort(); expect(toolNames).toEqual([ @@ -110,16 +109,27 @@ describe('StudioMcpService Unit Tests', () => { expect(workflowsService.findById).toHaveBeenCalledWith(workflowId, mockAuthContext); }); - it('run_workflow tool uses auth context passed at creation time', async () => { + it('run_workflow task uses auth context passed at creation time', async () => { const workflowId = '11111111-1111-4111-8111-111111111111'; const inputs = { key: 'value' }; const server = service.createServer(mockAuthContext); const registeredTools = getRegisteredTools(server); - const runWorkflowTool = registeredTools['run_workflow']; + const runWorkflowTask = registeredTools['run_workflow']; + + expect(runWorkflowTask).toBeDefined(); + + // Need to mock the extra params for the experimental tasks + const mockExtra = { + taskStore: { + createTask: jest.fn().mockResolvedValue({ taskId: 'mockTaskId', status: 'working' }), + getTask: jest.fn().mockResolvedValue({ taskId: 'mockTaskId', status: 'working' }), + updateTaskStatus: jest.fn().mockResolvedValue(true), + storeTaskResult: jest.fn().mockResolvedValue(true), + }, + }; - expect(runWorkflowTool).toBeDefined(); - await runWorkflowTool.handler({ workflowId, inputs }); + await runWorkflowTask.handler.createTask({ workflowId, inputs }, mockExtra); expect(workflowsService.run).toHaveBeenCalledWith( workflowId, @@ -129,7 +139,7 @@ describe('StudioMcpService Unit Tests', () => { trigger: { type: 'api', sourceId: mockAuthContext.userId, - label: 'Studio MCP', + label: 'Studio MCP Task', }, }, ); @@ -230,12 +240,21 @@ describe('StudioMcpService Unit Tests', () => { it('denies run_workflow when workflows.run is false', async () => { const server = service.createServer(restrictedAuth); - const tools = getRegisteredTools(server); - const result = (await tools['run_workflow'].handler({ - workflowId: '11111111-1111-4111-8111-111111111111', - })) as { isError?: boolean; content: { text: string }[] }; - expect(result.isError).toBe(true); - expect(result.content[0].text).toContain('workflows.run'); + const tasks = getRegisteredTools(server); + + let errorThrown = false; + try { + await tasks['run_workflow'].handler.createTask( + { + workflowId: '11111111-1111-4111-8111-111111111111', + }, + {} as any, + ); + } catch (_e: any) { + errorThrown = true; + expect(_e.message).toContain('workflows.run'); + } + expect(errorThrown).toBe(true); }); it('denies cancel_run when runs.cancel is false', async () => { @@ -260,15 +279,28 @@ describe('StudioMcpService Unit Tests', () => { it('allows all tools when no apiKeyPermissions (non-API-key auth)', async () => { const server = service.createServer(mockAuthContext); // no apiKeyPermissions const tools = getRegisteredTools(server); + const tasks = getRegisteredTools(server); // All workflow/run tools should work without permission errors const listResult = (await tools['list_workflows'].handler({})) as { isError?: boolean }; expect(listResult.isError).toBeUndefined(); - const runResult = (await tools['run_workflow'].handler({ - workflowId: '11111111-1111-4111-8111-111111111111', - })) as { isError?: boolean }; - expect(runResult.isError).toBeUndefined(); + const mockExtra = { + taskStore: { + createTask: jest.fn().mockResolvedValue({ taskId: 'mock', status: 'working' }), + getTask: jest.fn().mockResolvedValue({ taskId: 'mock', status: 'working' }), + updateTaskStatus: jest.fn().mockResolvedValue(true), + storeTaskResult: jest.fn().mockResolvedValue(true), + }, + }; + + const runResult = await tasks['run_workflow'].handler.createTask( + { + workflowId: '11111111-1111-4111-8111-111111111111', + }, + mockExtra, + ); + expect(runResult.task.taskId).toEqual('mock'); const cancelResult = (await tools['cancel_run'].handler({ runId: 'test-run-id', @@ -308,11 +340,11 @@ describe('StudioMcpService Unit Tests', () => { }; const server = service.createServer(noPermsAuth); const tools = getRegisteredTools(server); + const tasks = getRegisteredTools(server); const gatedTools = [ 'list_workflows', 'get_workflow', - 'run_workflow', 'list_runs', 'get_run_status', 'get_run_result', @@ -326,6 +358,20 @@ describe('StudioMcpService Unit Tests', () => { })) as { isError?: boolean }; expect(result.isError).toBe(true); } + + // Test run_workflow separately since it's a task now + let errorThrown = false; + try { + await tasks['run_workflow'].handler.createTask( + { + workflowId: '11111111-1111-4111-8111-111111111111', + }, + {} as any, + ); + } catch (_e: any) { + errorThrown = true; + } + expect(errorThrown).toBe(true); }); }); @@ -366,4 +412,91 @@ describe('StudioMcpService Unit Tests', () => { expect(workflowsService.list).toHaveBeenNthCalledWith(2, authContext2); }); }); + + describe('monitorWorkflowRun', () => { + it('polls status and saves result on completion', async () => { + const mockTaskStore = { + updateTaskStatus: jest.fn().mockResolvedValue(true), + storeTaskResult: jest.fn().mockResolvedValue(true), + }; + + const mockServer = {} as McpServer; + const taskId = 'test-task-id'; + const runId = 'test-run-id'; + + // Mock getRunStatus to return RUNNING first, then COMPLETED + let callCount = 0; + (workflowsService.getRunStatus as jest.Mock).mockImplementation(() => { + callCount++; + return Promise.resolve({ + status: callCount === 1 ? 'RUNNING' : 'COMPLETED', + }); + }); + + (workflowsService.getRunResult as jest.Mock).mockResolvedValue({ + output: 'test-output', + }); + + // We overwrite the 2000ms timeout temporarily for the test to avoid slow running loop + const originalSetTimeout = global.setTimeout; + (global as any).setTimeout = (fn: any) => originalSetTimeout(fn, 1); + + try { + await (service as any).monitorWorkflowRun( + runId, + undefined, + taskId, + mockTaskStore, + mockServer, + mockAuthContext, + ); + } finally { + global.setTimeout = originalSetTimeout as any; + } + + // updateTaskStatus is only called for non-terminal states (RUNNING → working). + // For COMPLETED, storeTaskResult handles the terminal transition directly. + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledTimes(1); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith(taskId, 'working', 'RUNNING'); + expect(mockTaskStore.updateTaskStatus).not.toHaveBeenCalledWith( + taskId, + 'completed', + 'COMPLETED', + ); + expect(workflowsService.getRunResult).toHaveBeenCalledWith(runId, undefined, mockAuthContext); + expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith(taskId, 'completed', { + content: [{ type: 'text', text: JSON.stringify({ output: 'test-output' }, null, 2) }], + }); + }); + + it('handles failures by storing the failure reason', async () => { + const mockTaskStore = { + updateTaskStatus: jest.fn().mockResolvedValue(true), + storeTaskResult: jest.fn().mockResolvedValue(true), + }; + + const taskId = 'test-task-id'; + const runId = 'test-run-id'; + + (workflowsService.getRunStatus as jest.Mock).mockResolvedValue({ + status: 'FAILED', + failure: { message: 'boom' }, + }); + + await (service as any).monitorWorkflowRun( + runId, + undefined, + taskId, + mockTaskStore, + {} as McpServer, + mockAuthContext, + ); + + // updateTaskStatus is NOT called for terminal states — storeTaskResult handles it. + expect(mockTaskStore.updateTaskStatus).not.toHaveBeenCalled(); + expect(mockTaskStore.storeTaskResult).toHaveBeenCalledWith(taskId, 'failed', { + content: [{ type: 'text', text: JSON.stringify({ message: 'boom' }, null, 2) }], + }); + }); + }); }); diff --git a/backend/src/studio-mcp/studio-mcp.service.ts b/backend/src/studio-mcp/studio-mcp.service.ts index 8ac4b92b..d58b6c2d 100644 --- a/backend/src/studio-mcp/studio-mcp.service.ts +++ b/backend/src/studio-mcp/studio-mcp.service.ts @@ -16,6 +16,10 @@ import { categorizeComponent } from '../components/utils/categorization'; import { WorkflowsService, type WorkflowRunSummary } from '../workflows/workflows.service'; import type { ServiceWorkflowResponse } from '../workflows/dto/workflow-graph.dto'; import type { AuthContext, ApiKeyPermissions } from '../auth/types'; +import { + InMemoryTaskStore, + InMemoryTaskMessageQueue, +} from '@modelcontextprotocol/sdk/experimental/index.js'; type PermissionPath = | 'workflows.list' @@ -27,6 +31,8 @@ type PermissionPath = @Injectable() export class StudioMcpService { private readonly logger = new Logger(StudioMcpService.name); + private readonly taskStore = new InMemoryTaskStore(); + private readonly taskMessageQueue = new InMemoryTaskMessageQueue(); constructor(private readonly workflowsService: WorkflowsService) {} @@ -67,10 +73,20 @@ export class StudioMcpService { * Uses Streamable HTTP transport only (no legacy SSE). */ createServer(auth: AuthContext): McpServer { - const server = new McpServer({ - name: 'shipsec-studio', - version: '1.0.0', - }); + const server = new McpServer( + { + name: 'shipsec-studio', + version: '1.0.0', + }, + { + capabilities: { + logging: {}, + tasks: { requests: { tools: { call: {} } } }, + }, + taskStore: this.taskStore, + taskMessageQueue: this.taskMessageQueue, + }, + ); this.registerTools(server, auth); @@ -138,57 +154,71 @@ export class StudioMcpService { }, ); - server.registerTool( + const runWorkflowSchema = { + workflowId: z.string().uuid(), + inputs: z.record(z.string(), z.unknown()).optional(), + versionId: z.string().uuid().optional(), + }; + + server.experimental.tasks.registerToolTask( 'run_workflow', { description: - 'Start a workflow execution. Returns the run ID and initial status. Use get_run_status to poll for completion.', - inputSchema: { - workflowId: z.string().uuid(), - inputs: z.record(z.string(), z.unknown()).optional(), - versionId: z.string().uuid().optional(), - }, + 'Start a workflow execution as a background task. The task handle can be monitored for status updates, and finally retrieved for the workflow result. Also supports legacy polling via get_run_status.', + inputSchema: runWorkflowSchema, + execution: { taskSupport: 'optional' }, }, - async (args: { - workflowId: string; - inputs?: Record; - versionId?: string; - }) => { - const gate = this.checkPermission(auth, 'workflows.run'); - if (!gate.allowed) return gate.error; - try { + { + createTask: async (args, extra) => { + const gate = this.checkPermission(auth, 'workflows.run'); + if (!gate.allowed) throw new Error(gate.error.content[0].text); + + const task = await extra.taskStore.createTask({ ttl: 12 * 60 * 60 * 1000 }); + const handle = await this.workflowsService.run( args.workflowId, - { inputs: args.inputs ?? {}, versionId: args.versionId }, + { + inputs: args.inputs ?? {}, + versionId: args.versionId, + }, auth, { trigger: { type: 'api', sourceId: auth.userId ?? 'api-key', - label: 'Studio MCP', + label: 'Studio MCP Task', }, }, ); - return { - content: [ - { - type: 'text' as const, - text: JSON.stringify( - { - runId: handle.runId, - workflowId: handle.workflowId, - status: handle.status, - workflowVersion: handle.workflowVersion, - }, - null, - 2, - ), - }, - ], - }; - } catch (error) { - return this.errorResult(error); - } + + this.monitorWorkflowRun( + handle.runId, + handle.temporalRunId, + task.taskId, + extra.taskStore, + server, + auth, + ).catch((err) => { + this.logger.error(`Error monitoring workflow run task for run ${handle.runId}: ${err}`); + }); + + return { task }; + }, + getTask: async (args, extra) => { + const gate = this.checkPermission(auth, 'runs.read'); + if (!gate.allowed) throw new Error(gate.error.content[0].text); + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + getTaskResult: async (args, extra) => { + const gate = this.checkPermission(auth, 'runs.read'); + if (!gate.allowed) throw new Error(gate.error.content[0].text); + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as any; + }, }, ); } @@ -397,6 +427,88 @@ export class StudioMcpService { ); } + private async monitorWorkflowRun( + runId: string, + temporalRunId: string | undefined, + taskId: string, + taskStore: any, + server: McpServer, + auth: AuthContext, + ): Promise { + const isTerminal = (status: string) => + ['COMPLETED', 'FAILED', 'CANCELLED', 'TERMINATED', 'TIMED_OUT'].includes(status); + + const mapStatus = (status: string): 'working' | 'completed' | 'cancelled' | 'failed' => { + switch (status) { + case 'RUNNING': + case 'QUEUED': + case 'AWAITING_INPUT': + return 'working'; + case 'COMPLETED': + return 'completed'; + case 'CANCELLED': + case 'TERMINATED': + case 'TIMED_OUT': + return 'cancelled'; + case 'FAILED': + return 'failed'; + default: + return 'working'; + } + }; + + while (true) { + try { + const runStatusPayload = await this.workflowsService.getRunStatus( + runId, + temporalRunId, + auth, + ); + const taskState = mapStatus(runStatusPayload.status); + + if (isTerminal(runStatusPayload.status)) { + // For terminal states, storeTaskResult sets the status itself. + // Do NOT call updateTaskStatus first — it would move the task into a terminal + // state and then storeTaskResult would refuse to update it again. + let resultData: any; + if (taskState === 'completed') { + try { + resultData = await this.workflowsService.getRunResult(runId, temporalRunId, auth); + } catch (err) { + resultData = { error: String(err) }; + } + } else { + resultData = runStatusPayload.failure || { reason: runStatusPayload.status }; + } + + const resultPayload = { + content: [{ type: 'text', text: JSON.stringify(resultData, null, 2) }], + }; + + const storeStatus = taskState === 'completed' ? 'completed' : 'failed'; + await taskStore.storeTaskResult(taskId, storeStatus, resultPayload); + break; + } + + // Non-terminal: just update status and keep polling + await taskStore.updateTaskStatus(taskId, taskState, runStatusPayload.status); + await new Promise((res) => setTimeout(res, 2000)); + } catch (err) { + this.logger.error(`Error monitoring task ${taskId} (run: ${runId}): ${err}`); + try { + // storeTaskResult sets the terminal status; don't call updateTaskStatus first + await taskStore.storeTaskResult(taskId, 'failed', { + content: [{ type: 'text', text: `Failed to monitor workflow run: ${String(err)}` }], + isError: true, + }); + } catch (_storeErr) { + // Ignore — task may already be in a terminal state + } + break; + } + } + } + // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- diff --git a/e2e-tests/studio-mcp/studio-mcp-agent.test.ts b/e2e-tests/studio-mcp/studio-mcp-agent.test.ts index 702d409f..cc88b272 100644 --- a/e2e-tests/studio-mcp/studio-mcp-agent.test.ts +++ b/e2e-tests/studio-mcp/studio-mcp-agent.test.ts @@ -37,6 +37,7 @@ e2eDescribe('Studio MCP: AI SDK Integration', () => { permissions: { workflows: { run: true, list: true, read: true }, runs: { read: true, cancel: true }, + audit: { read: true }, }, }), }); @@ -279,7 +280,7 @@ e2eDescribe('Studio MCP: AI SDK Integration', () => { content: [ { type: 'text', - text: `Run the workflow with ID "${workflowId}" using the input message "Hello from AI SDK test". Then check its status.`, + text: `Run the workflow with ID "${workflowId}" using the input message "Hello from AI SDK test". Then check its status and result, and report back the final message.`, }, ], }, diff --git a/e2e-tests/studio-mcp/studio-mcp-tasks.test.ts b/e2e-tests/studio-mcp/studio-mcp-tasks.test.ts new file mode 100644 index 00000000..53a864bb --- /dev/null +++ b/e2e-tests/studio-mcp/studio-mcp-tasks.test.ts @@ -0,0 +1,130 @@ +import { expect, beforeAll } from 'bun:test'; +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { API_BASE, HEADERS, e2eDescribe, e2eTest, createWorkflow } from '../helpers/e2e-harness'; + +interface ApiKeyResponse { + id: string; + plainKey: string; + name: string; +} + +e2eDescribe('Studio MCP: Task API Integration', () => { + let plainKey: string | null = null; + let workflowId: string | null = null; + + beforeAll(async () => { + // Create an API key with workflow permissions + const keyRes = await fetch(`${API_BASE}/api-keys`, { + method: 'POST', + headers: HEADERS, + body: JSON.stringify({ + name: `e2e-mcp-tasks-${Date.now()}`, + permissions: { + workflows: { run: true, list: true, read: true }, + runs: { read: true, cancel: true }, + audit: { read: true }, + }, + }), + }); + + if (!keyRes.ok) { + throw new Error(`Failed to create API key: ${keyRes.status} ${await keyRes.text()}`); + } + + const keyData = (await keyRes.json()) as ApiKeyResponse; + plainKey = keyData.plainKey; + + // Create a minimal workflow (entry point only — runs and completes immediately) + workflowId = await createWorkflow({ + name: `E2E Task API Test ${Date.now()}`, + nodes: [ + { + id: 'start', + type: 'core.workflow.entrypoint', + position: { x: 0, y: 0 }, + data: { + label: 'Start', + config: { + params: { + runtimeInputs: [{ id: 'message', label: 'Message', type: 'text' }], + }, + }, + }, + }, + ], + edges: [], + }); + }); + + e2eTest( + 'run_workflow via Task API streams taskCreated and result messages', + { timeout: 60000 }, + async () => { + expect(plainKey).toBeDefined(); + expect(workflowId).toBeDefined(); + + const transport = new StreamableHTTPClientTransport(new URL(`${API_BASE}/studio-mcp`), { + requestInit: { + headers: { Authorization: `Bearer ${plainKey}` }, + }, + }); + + const client = new Client( + { name: 'e2e-task-client', version: '1.0.0' }, + { + capabilities: { + tasks: { + requests: { + tasks: { get: {}, list: {}, result: {}, cancel: {} }, + }, + }, + }, + }, + ); + + await client.connect(transport); + + try { + // Cache tool metadata (required to detect task-capable tools) + await client.listTools(); + + const messages: any[] = []; + + const stream = client.experimental.tasks.callToolStream({ + name: 'run_workflow', + arguments: { + workflowId, + inputs: { message: 'hello from task api test' }, + }, + }); + + for await (const message of stream) { + messages.push(message); + console.log('[task stream]', JSON.stringify(message)); + if (message.type === 'result' || message.type === 'error') { + break; + } + } + + // Must have gotten at least a taskCreated and a result message + expect(messages.length).toBeGreaterThanOrEqual(2); + + // Verify taskCreated + const taskCreated = messages.find((m) => m.type === 'taskCreated'); + expect(taskCreated).toBeDefined(); + expect(typeof taskCreated.task.taskId).toBe('string'); + expect(taskCreated.task.status).toBe('working'); + + // Verify final result (not an error) + const result = messages.find((m) => m.type === 'result'); + expect(result).toBeDefined(); + expect(result.result.isError).toBeFalsy(); + expect(result.result.content).toBeArray(); + expect(result.result.content.length).toBeGreaterThan(0); + } finally { + await client.close(); + } + }, + ); +}); diff --git a/test-mcp-server.ts b/test-mcp-server.ts deleted file mode 100644 index a280060e..00000000 --- a/test-mcp-server.ts +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env node -/** - * Simple MCP Server for Testing - * Exposes a single "get_weather" tool - */ - -import { Server } from '@modelcontextprotocol/sdk/server/index.js'; -import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; -import { - CallToolRequestSchema, - ListToolsRequestSchema, - Tool, -} from '@modelcontextprotocol/sdk/types.js'; - -// Create server -const server = new Server( - { - name: 'test-weather-server', - version: '1.0.0', - }, - { - capabilities: { - tools: {}, - }, - }, -); - -// Define tools -const tools: Tool[] = [ - { - name: 'get_weather', - description: 'Get the current weather for a location', - inputSchema: { - type: 'object', - properties: { - location: { - type: 'string', - description: 'City name or zip code', - }, - units: { - type: 'string', - enum: ['celsius', 'fahrenheit'], - description: 'Temperature units (default: celsius)', - }, - }, - required: ['location'], - }, - }, -]; - -// Handle tool list requests -server.setRequestHandler(ListToolsRequestSchema, async () => { - console.error('[TEST-MCP] Received list_tools request'); - return { tools }; -}); - -// Handle tool calls -server.setRequestHandler(CallToolRequestSchema, async (request) => { - console.error(`[TEST-MCP] Received tool call: ${request.params.name}`); - - if (request.params.name === 'get_weather') { - const { location, units = 'celsius' } = request.params.arguments as { - location: string; - units?: string; - }; - - const weather = { - location: location, - temperature: 22, - condition: 'Sunny', - humidity: 65, - wind_speed: 10, - units: units, - timestamp: new Date().toISOString(), - mcp_server_response: true, - message: `This is a real response from the MCP server for ${location}!`, - }; - - console.error(`[TEST-MCP] Returning weather for: ${location}`); - return { - content: [ - { - type: 'text', - text: JSON.stringify(weather, null, 2), - }, - ], - }; - } - - return { - content: [ - { - type: 'text', - text: `Unknown tool: ${request.params.name}`, - }, - ], - isError: true, - }; -}); - -// Start server -async function main() { - const transport = new StdioServerTransport(); - await server.connect(transport); - console.error('[TEST-MCP] Server started on stdio'); -} - -main().catch(console.error);