diff --git a/apps/sim/app/api/mcp/copilot/route.ts b/apps/sim/app/api/mcp/copilot/route.ts index 5bc19858e..fdad4af43 100644 --- a/apps/sim/app/api/mcp/copilot/route.ts +++ b/apps/sim/app/api/mcp/copilot/route.ts @@ -1,17 +1,13 @@ import { Server } from '@modelcontextprotocol/sdk/server/index.js' -import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js' import { CallToolRequestSchema, type CallToolResult, ErrorCode, - isJSONRPCNotification, - isJSONRPCRequest, type JSONRPCError, - type JSONRPCMessage, type ListToolsResult, ListToolsRequestSchema, McpError, - type MessageExtraInfo, type RequestId, } from '@modelcontextprotocol/sdk/types.js' import { db } from '@sim/db' @@ -35,6 +31,7 @@ import { resolveWorkflowIdForUser } from '@/lib/workflows/utils' const logger = createLogger('CopilotMcpAPI') export const dynamic = 'force-dynamic' +export const runtime = 'nodejs' /** * MCP Server instructions that guide LLMs on how to use the Sim copilot tools. @@ -78,77 +75,7 @@ When the user refers to a workflow by name or description ("the email one", "my - Variable syntax: \`\` for block outputs, \`{{ENV_VAR}}\` for env vars. ` -class SingleRequestTransport implements Transport { - private started = false - private outgoing: JSONRPCMessage[] = [] - private waitingResolvers: Array<(message: JSONRPCMessage) => void> = [] - - onclose?: () => void - onerror?: (error: Error) => void - onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void - sessionId?: string - - async start(): Promise { - if (this.started) { - throw new Error('Transport already started') - } - this.started = true - } - - async send(message: JSONRPCMessage): Promise { - this.outgoing.push(message) - const resolver = this.waitingResolvers.shift() - if (resolver) { - resolver(message) - } - } - - async close(): Promise { - this.onclose?.() - } - - async dispatch(message: JSONRPCMessage, extra?: MessageExtraInfo): Promise { - if (!this.onmessage) { - throw new Error('Transport is not connected to an MCP server') - } - - await Promise.resolve(this.onmessage(message, extra)) - } - - consumeResponse(): JSONRPCMessage | null { - if (this.outgoing.length === 0) { - return null - } - - const [firstResponse] = this.outgoing - this.outgoing = [] - return firstResponse - } - - async waitForResponse(timeoutMs = 5000): Promise { - const immediate = this.consumeResponse() - if (immediate) { - return immediate - } - - return new Promise((resolve) => { - const timeout = setTimeout(() => { - const index = this.waitingResolvers.indexOf(resolver) - if (index >= 0) { - this.waitingResolvers.splice(index, 1) - } - resolve(null) - }, timeoutMs) - - const resolver = (message: JSONRPCMessage) => { - clearTimeout(timeout) - resolve(message) - } - - this.waitingResolvers.push(resolver) - }) - } -} +type HeaderMap = Record function createError(id: RequestId, code: ErrorCode | number, message: string): JSONRPCError { return { @@ -158,7 +85,140 @@ function createError(id: RequestId, code: ErrorCode | number, message: string): } } -function buildMcpServer(userId?: string): Server { +function normalizeRequestHeaders(request: NextRequest): HeaderMap { + const headers: HeaderMap = {} + + request.headers.forEach((value, key) => { + headers[key.toLowerCase()] = value + }) + + return headers +} + +function readHeader(headers: HeaderMap | undefined, name: string): string | undefined { + if (!headers) return undefined + const value = headers[name.toLowerCase()] + if (Array.isArray(value)) { + return value[0] + } + return value +} + +class NextResponseCapture { + private _status = 200 + private _headers = new Headers() + private _chunks: Buffer[] = [] + private _closeHandlers: Array<() => void> = [] + private _errorHandlers: Array<(error: Error) => void> = [] + private _ended = false + private _endedPromise: Promise + private _resolveEnded: (() => void) | null = null + + constructor() { + this._endedPromise = new Promise((resolve) => { + this._resolveEnded = resolve + }) + } + + writeHead(status: number, headers?: Record): this { + this._status = status + + if (headers) { + Object.entries(headers).forEach(([key, value]) => { + if (Array.isArray(value)) { + this._headers.set(key, value.join(', ')) + } else { + this._headers.set(key, String(value)) + } + }) + } + + return this + } + + flushHeaders(): this { + return this + } + + write(chunk: unknown): boolean { + if (typeof chunk === 'string') { + this._chunks.push(Buffer.from(chunk)) + return true + } + + if (chunk instanceof Uint8Array) { + this._chunks.push(Buffer.from(chunk)) + return true + } + + if (chunk !== undefined && chunk !== null) { + this._chunks.push(Buffer.from(String(chunk))) + } + + return true + } + + end(chunk?: unknown): this { + if (chunk !== undefined) { + this.write(chunk) + } + + this._ended = true + this._resolveEnded?.() + + this._closeHandlers.forEach((handler) => { + try { + handler() + } catch (error) { + this._errorHandlers.forEach((errorHandler) => { + errorHandler(error instanceof Error ? error : new Error(String(error))) + }) + } + }) + + return this + } + + async waitForEnd(timeoutMs = 30000): Promise { + if (this._ended) return + + await Promise.race([ + this._endedPromise, + new Promise((resolve) => { + setTimeout(resolve, timeoutMs) + }), + ]) + } + + on(event: 'close' | 'error', handler: (() => void) | ((error: Error) => void)): this { + if (event === 'close') { + this._closeHandlers.push(handler as () => void) + } + + if (event === 'error') { + this._errorHandlers.push(handler as (error: Error) => void) + } + + return this + } + + toNextResponse(): NextResponse { + if (this._chunks.length === 0) { + return new NextResponse(null, { + status: this._status, + headers: this._headers, + }) + } + + const body = Buffer.concat(this._chunks) + return new NextResponse(body, { + status: this._status, + headers: this._headers, + }) + } +} + +function buildMcpServer(): Server { const server = new Server( { name: 'sim-copilot', @@ -190,43 +250,88 @@ function buildMcpServer(userId?: string): Server { return result }) - server.setRequestHandler(CallToolRequestSchema, async (request) => { - if (!userId) { + server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + const headers = (extra.requestInfo?.headers || {}) as HeaderMap + const apiKeyHeader = readHeader(headers, 'x-api-key') + + if (!apiKeyHeader) { throw new McpError( - ErrorCode.InvalidRequest, + -32000, 'API key required. Set the x-api-key header with a valid Sim API key.' ) } + const authResult = await authenticateApiKeyFromHeader(apiKeyHeader) + if (!authResult.success || !authResult.userId) { + logger.warn('MCP auth failed', { + error: authResult.error, + method: request.method, + }) + + throw new McpError(-32000, authResult.error || 'Invalid API key') + } + + if (authResult.keyId) { + updateApiKeyLastUsed(authResult.keyId).catch((error) => { + logger.warn('Failed to update API key last-used timestamp', { + keyId: authResult.keyId, + error: error instanceof Error ? error.message : String(error), + }) + }) + } + + const usageCheck = await checkServerSideUsageLimits(authResult.userId) + if (usageCheck.isExceeded) { + throw new McpError( + -32000, + `Usage limit exceeded: ${usageCheck.message || 'Upgrade your plan.'}` + ) + } + const params = request.params as { name?: string; arguments?: Record } | undefined if (!params?.name) { throw new McpError(ErrorCode.InvalidParams, 'Tool name required') } - return handleToolsCall( + const result = await handleToolsCall( { name: params.name, arguments: params.arguments, }, - userId + authResult.userId ) + + trackMcpCopilotCall(authResult.userId) + + return result }) return server } async function handleMcpRequestWithSdk( - message: JSONRPCMessage, - userId?: string -): Promise { - const server = buildMcpServer(userId) - const transport = new SingleRequestTransport() + request: NextRequest, + parsedBody: unknown +): Promise { + const server = buildMcpServer() + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: undefined, + enableJsonResponse: true, + }) + + const responseCapture = new NextResponseCapture() + + const requestAdapter = { + method: request.method, + headers: normalizeRequestHeaders(request), + } await server.connect(transport) try { - await transport.dispatch(message) - return transport.waitForResponse() + await transport.handleRequest(requestAdapter as any, responseCapture as any, parsedBody) + await responseCapture.waitForEnd() + return responseCapture.toNextResponse() } finally { await server.close().catch(() => {}) await transport.close().catch(() => {}) @@ -243,98 +348,21 @@ export async function GET() { } export async function POST(request: NextRequest) { - let requestId: RequestId = 0 - try { - let body: JSONRPCMessage + let parsedBody: unknown try { - body = (await request.json()) as JSONRPCMessage + parsedBody = await request.json() } catch { return NextResponse.json(createError(0, ErrorCode.ParseError, 'Invalid JSON body'), { status: 400, }) } - if (isJSONRPCNotification(body)) { - return new NextResponse(null, { status: 202 }) - } - - if (!isJSONRPCRequest(body)) { - return NextResponse.json( - createError(0, ErrorCode.InvalidRequest, 'Invalid JSON-RPC message'), - { status: 400 } - ) - } - - requestId = body.id - - let userId: string | undefined - - if (body.method === 'tools/call') { - const apiKeyHeader = request.headers.get('x-api-key') - if (!apiKeyHeader) { - return NextResponse.json( - createError( - requestId, - -32000, - 'API key required. Set the x-api-key header with a valid Sim API key.' - ), - { status: 401 } - ) - } - - const authResult = await authenticateApiKeyFromHeader(apiKeyHeader) - if (!authResult.success || !authResult.userId) { - logger.warn('MCP auth failed', { - error: authResult.error, - method: body.method, - }) - - return NextResponse.json( - createError(requestId, -32000, authResult.error || 'Invalid API key'), - { status: 401 } - ) - } - - userId = authResult.userId - - if (authResult.keyId) { - updateApiKeyLastUsed(authResult.keyId).catch((error) => { - logger.warn('Failed to update API key last-used timestamp', { - keyId: authResult.keyId, - error: error instanceof Error ? error.message : String(error), - }) - }) - } - - const usageCheck = await checkServerSideUsageLimits(userId) - if (usageCheck.isExceeded) { - return NextResponse.json( - createError( - requestId, - -32000, - `Usage limit exceeded: ${usageCheck.message || 'Upgrade your plan.'}` - ), - { status: 402 } - ) - } - } - - const responseMessage = await handleMcpRequestWithSdk(body, userId) - - if (body.method === 'tools/call' && userId) { - trackMcpCopilotCall(userId) - } - - if (!responseMessage) { - return new NextResponse(null, { status: 202 }) - } - - return NextResponse.json(responseMessage) + return await handleMcpRequestWithSdk(request, parsedBody) } catch (error) { logger.error('Error handling MCP request', { error }) - return NextResponse.json(createError(requestId, ErrorCode.InternalError, 'Internal error'), { + return NextResponse.json(createError(0, ErrorCode.InternalError, 'Internal error'), { status: 500, }) }