This commit is contained in:
Siddharth Ganesan
2026-02-06 13:54:08 -08:00
parent 0f5eb9d351
commit f63ed61bc8

View File

@@ -1,17 +1,13 @@
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'
import {
CallToolRequestSchema,
type CallToolResult,
ErrorCode,
isJSONRPCNotification,
isJSONRPCRequest,
type JSONRPCError,
type JSONRPCMessage,
type ListToolsResult,
ListToolsRequestSchema,
McpError,
type MessageExtraInfo,
type RequestId,
} from '@modelcontextprotocol/sdk/types.js'
import { db } from '@sim/db'
@@ -35,6 +31,7 @@ import { resolveWorkflowIdForUser } from '@/lib/workflows/utils'
const logger = createLogger('CopilotMcpAPI')
export const dynamic = 'force-dynamic'
export const runtime = 'nodejs'
/**
* MCP Server instructions that guide LLMs on how to use the Sim copilot tools.
@@ -78,77 +75,7 @@ When the user refers to a workflow by name or description ("the email one", "my
- Variable syntax: \`<blockname.field>\` for block outputs, \`{{ENV_VAR}}\` for env vars.
`
class SingleRequestTransport implements Transport {
private started = false
private outgoing: JSONRPCMessage[] = []
private waitingResolvers: Array<(message: JSONRPCMessage) => void> = []
onclose?: () => void
onerror?: (error: Error) => void
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void
sessionId?: string
async start(): Promise<void> {
if (this.started) {
throw new Error('Transport already started')
}
this.started = true
}
async send(message: JSONRPCMessage): Promise<void> {
this.outgoing.push(message)
const resolver = this.waitingResolvers.shift()
if (resolver) {
resolver(message)
}
}
async close(): Promise<void> {
this.onclose?.()
}
async dispatch(message: JSONRPCMessage, extra?: MessageExtraInfo): Promise<void> {
if (!this.onmessage) {
throw new Error('Transport is not connected to an MCP server')
}
await Promise.resolve(this.onmessage(message, extra))
}
consumeResponse(): JSONRPCMessage | null {
if (this.outgoing.length === 0) {
return null
}
const [firstResponse] = this.outgoing
this.outgoing = []
return firstResponse
}
async waitForResponse(timeoutMs = 5000): Promise<JSONRPCMessage | null> {
const immediate = this.consumeResponse()
if (immediate) {
return immediate
}
return new Promise((resolve) => {
const timeout = setTimeout(() => {
const index = this.waitingResolvers.indexOf(resolver)
if (index >= 0) {
this.waitingResolvers.splice(index, 1)
}
resolve(null)
}, timeoutMs)
const resolver = (message: JSONRPCMessage) => {
clearTimeout(timeout)
resolve(message)
}
this.waitingResolvers.push(resolver)
})
}
}
type HeaderMap = Record<string, string | string[] | undefined>
function createError(id: RequestId, code: ErrorCode | number, message: string): JSONRPCError {
return {
@@ -158,7 +85,140 @@ function createError(id: RequestId, code: ErrorCode | number, message: string):
}
}
function buildMcpServer(userId?: string): Server {
function normalizeRequestHeaders(request: NextRequest): HeaderMap {
const headers: HeaderMap = {}
request.headers.forEach((value, key) => {
headers[key.toLowerCase()] = value
})
return headers
}
function readHeader(headers: HeaderMap | undefined, name: string): string | undefined {
if (!headers) return undefined
const value = headers[name.toLowerCase()]
if (Array.isArray(value)) {
return value[0]
}
return value
}
class NextResponseCapture {
private _status = 200
private _headers = new Headers()
private _chunks: Buffer[] = []
private _closeHandlers: Array<() => void> = []
private _errorHandlers: Array<(error: Error) => void> = []
private _ended = false
private _endedPromise: Promise<void>
private _resolveEnded: (() => void) | null = null
constructor() {
this._endedPromise = new Promise<void>((resolve) => {
this._resolveEnded = resolve
})
}
writeHead(status: number, headers?: Record<string, string | number | string[]>): this {
this._status = status
if (headers) {
Object.entries(headers).forEach(([key, value]) => {
if (Array.isArray(value)) {
this._headers.set(key, value.join(', '))
} else {
this._headers.set(key, String(value))
}
})
}
return this
}
flushHeaders(): this {
return this
}
write(chunk: unknown): boolean {
if (typeof chunk === 'string') {
this._chunks.push(Buffer.from(chunk))
return true
}
if (chunk instanceof Uint8Array) {
this._chunks.push(Buffer.from(chunk))
return true
}
if (chunk !== undefined && chunk !== null) {
this._chunks.push(Buffer.from(String(chunk)))
}
return true
}
end(chunk?: unknown): this {
if (chunk !== undefined) {
this.write(chunk)
}
this._ended = true
this._resolveEnded?.()
this._closeHandlers.forEach((handler) => {
try {
handler()
} catch (error) {
this._errorHandlers.forEach((errorHandler) => {
errorHandler(error instanceof Error ? error : new Error(String(error)))
})
}
})
return this
}
async waitForEnd(timeoutMs = 30000): Promise<void> {
if (this._ended) return
await Promise.race([
this._endedPromise,
new Promise<void>((resolve) => {
setTimeout(resolve, timeoutMs)
}),
])
}
on(event: 'close' | 'error', handler: (() => void) | ((error: Error) => void)): this {
if (event === 'close') {
this._closeHandlers.push(handler as () => void)
}
if (event === 'error') {
this._errorHandlers.push(handler as (error: Error) => void)
}
return this
}
toNextResponse(): NextResponse {
if (this._chunks.length === 0) {
return new NextResponse(null, {
status: this._status,
headers: this._headers,
})
}
const body = Buffer.concat(this._chunks)
return new NextResponse(body, {
status: this._status,
headers: this._headers,
})
}
}
function buildMcpServer(): Server {
const server = new Server(
{
name: 'sim-copilot',
@@ -190,43 +250,88 @@ function buildMcpServer(userId?: string): Server {
return result
})
server.setRequestHandler(CallToolRequestSchema, async (request) => {
if (!userId) {
server.setRequestHandler(CallToolRequestSchema, async (request, extra) => {
const headers = (extra.requestInfo?.headers || {}) as HeaderMap
const apiKeyHeader = readHeader(headers, 'x-api-key')
if (!apiKeyHeader) {
throw new McpError(
ErrorCode.InvalidRequest,
-32000,
'API key required. Set the x-api-key header with a valid Sim API key.'
)
}
const authResult = await authenticateApiKeyFromHeader(apiKeyHeader)
if (!authResult.success || !authResult.userId) {
logger.warn('MCP auth failed', {
error: authResult.error,
method: request.method,
})
throw new McpError(-32000, authResult.error || 'Invalid API key')
}
if (authResult.keyId) {
updateApiKeyLastUsed(authResult.keyId).catch((error) => {
logger.warn('Failed to update API key last-used timestamp', {
keyId: authResult.keyId,
error: error instanceof Error ? error.message : String(error),
})
})
}
const usageCheck = await checkServerSideUsageLimits(authResult.userId)
if (usageCheck.isExceeded) {
throw new McpError(
-32000,
`Usage limit exceeded: ${usageCheck.message || 'Upgrade your plan.'}`
)
}
const params = request.params as { name?: string; arguments?: Record<string, unknown> } | undefined
if (!params?.name) {
throw new McpError(ErrorCode.InvalidParams, 'Tool name required')
}
return handleToolsCall(
const result = await handleToolsCall(
{
name: params.name,
arguments: params.arguments,
},
userId
authResult.userId
)
trackMcpCopilotCall(authResult.userId)
return result
})
return server
}
async function handleMcpRequestWithSdk(
message: JSONRPCMessage,
userId?: string
): Promise<JSONRPCMessage | null> {
const server = buildMcpServer(userId)
const transport = new SingleRequestTransport()
request: NextRequest,
parsedBody: unknown
): Promise<NextResponse> {
const server = buildMcpServer()
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: undefined,
enableJsonResponse: true,
})
const responseCapture = new NextResponseCapture()
const requestAdapter = {
method: request.method,
headers: normalizeRequestHeaders(request),
}
await server.connect(transport)
try {
await transport.dispatch(message)
return transport.waitForResponse()
await transport.handleRequest(requestAdapter as any, responseCapture as any, parsedBody)
await responseCapture.waitForEnd()
return responseCapture.toNextResponse()
} finally {
await server.close().catch(() => {})
await transport.close().catch(() => {})
@@ -243,98 +348,21 @@ export async function GET() {
}
export async function POST(request: NextRequest) {
let requestId: RequestId = 0
try {
let body: JSONRPCMessage
let parsedBody: unknown
try {
body = (await request.json()) as JSONRPCMessage
parsedBody = await request.json()
} catch {
return NextResponse.json(createError(0, ErrorCode.ParseError, 'Invalid JSON body'), {
status: 400,
})
}
if (isJSONRPCNotification(body)) {
return new NextResponse(null, { status: 202 })
}
if (!isJSONRPCRequest(body)) {
return NextResponse.json(
createError(0, ErrorCode.InvalidRequest, 'Invalid JSON-RPC message'),
{ status: 400 }
)
}
requestId = body.id
let userId: string | undefined
if (body.method === 'tools/call') {
const apiKeyHeader = request.headers.get('x-api-key')
if (!apiKeyHeader) {
return NextResponse.json(
createError(
requestId,
-32000,
'API key required. Set the x-api-key header with a valid Sim API key.'
),
{ status: 401 }
)
}
const authResult = await authenticateApiKeyFromHeader(apiKeyHeader)
if (!authResult.success || !authResult.userId) {
logger.warn('MCP auth failed', {
error: authResult.error,
method: body.method,
})
return NextResponse.json(
createError(requestId, -32000, authResult.error || 'Invalid API key'),
{ status: 401 }
)
}
userId = authResult.userId
if (authResult.keyId) {
updateApiKeyLastUsed(authResult.keyId).catch((error) => {
logger.warn('Failed to update API key last-used timestamp', {
keyId: authResult.keyId,
error: error instanceof Error ? error.message : String(error),
})
})
}
const usageCheck = await checkServerSideUsageLimits(userId)
if (usageCheck.isExceeded) {
return NextResponse.json(
createError(
requestId,
-32000,
`Usage limit exceeded: ${usageCheck.message || 'Upgrade your plan.'}`
),
{ status: 402 }
)
}
}
const responseMessage = await handleMcpRequestWithSdk(body, userId)
if (body.method === 'tools/call' && userId) {
trackMcpCopilotCall(userId)
}
if (!responseMessage) {
return new NextResponse(null, { status: 202 })
}
return NextResponse.json(responseMessage)
return await handleMcpRequestWithSdk(request, parsedBody)
} catch (error) {
logger.error('Error handling MCP request', { error })
return NextResponse.json(createError(requestId, ErrorCode.InternalError, 'Internal error'), {
return NextResponse.json(createError(0, ErrorCode.InternalError, 'Internal error'), {
status: 500,
})
}