mirror of
https://github.com/simstudioai/sim.git
synced 2026-02-07 05:05:15 -05:00
Fix
This commit is contained in:
@@ -1,17 +1,13 @@
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'
|
||||
import {
|
||||
CallToolRequestSchema,
|
||||
type CallToolResult,
|
||||
ErrorCode,
|
||||
isJSONRPCNotification,
|
||||
isJSONRPCRequest,
|
||||
type JSONRPCError,
|
||||
type JSONRPCMessage,
|
||||
type ListToolsResult,
|
||||
ListToolsRequestSchema,
|
||||
McpError,
|
||||
type MessageExtraInfo,
|
||||
type RequestId,
|
||||
} from '@modelcontextprotocol/sdk/types.js'
|
||||
import { db } from '@sim/db'
|
||||
@@ -35,6 +31,7 @@ import { resolveWorkflowIdForUser } from '@/lib/workflows/utils'
|
||||
const logger = createLogger('CopilotMcpAPI')
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
export const runtime = 'nodejs'
|
||||
|
||||
/**
|
||||
* MCP Server instructions that guide LLMs on how to use the Sim copilot tools.
|
||||
@@ -78,77 +75,7 @@ When the user refers to a workflow by name or description ("the email one", "my
|
||||
- Variable syntax: \`<blockname.field>\` for block outputs, \`{{ENV_VAR}}\` for env vars.
|
||||
`
|
||||
|
||||
class SingleRequestTransport implements Transport {
|
||||
private started = false
|
||||
private outgoing: JSONRPCMessage[] = []
|
||||
private waitingResolvers: Array<(message: JSONRPCMessage) => void> = []
|
||||
|
||||
onclose?: () => void
|
||||
onerror?: (error: Error) => void
|
||||
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void
|
||||
sessionId?: string
|
||||
|
||||
async start(): Promise<void> {
|
||||
if (this.started) {
|
||||
throw new Error('Transport already started')
|
||||
}
|
||||
this.started = true
|
||||
}
|
||||
|
||||
async send(message: JSONRPCMessage): Promise<void> {
|
||||
this.outgoing.push(message)
|
||||
const resolver = this.waitingResolvers.shift()
|
||||
if (resolver) {
|
||||
resolver(message)
|
||||
}
|
||||
}
|
||||
|
||||
async close(): Promise<void> {
|
||||
this.onclose?.()
|
||||
}
|
||||
|
||||
async dispatch(message: JSONRPCMessage, extra?: MessageExtraInfo): Promise<void> {
|
||||
if (!this.onmessage) {
|
||||
throw new Error('Transport is not connected to an MCP server')
|
||||
}
|
||||
|
||||
await Promise.resolve(this.onmessage(message, extra))
|
||||
}
|
||||
|
||||
consumeResponse(): JSONRPCMessage | null {
|
||||
if (this.outgoing.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
const [firstResponse] = this.outgoing
|
||||
this.outgoing = []
|
||||
return firstResponse
|
||||
}
|
||||
|
||||
async waitForResponse(timeoutMs = 5000): Promise<JSONRPCMessage | null> {
|
||||
const immediate = this.consumeResponse()
|
||||
if (immediate) {
|
||||
return immediate
|
||||
}
|
||||
|
||||
return new Promise((resolve) => {
|
||||
const timeout = setTimeout(() => {
|
||||
const index = this.waitingResolvers.indexOf(resolver)
|
||||
if (index >= 0) {
|
||||
this.waitingResolvers.splice(index, 1)
|
||||
}
|
||||
resolve(null)
|
||||
}, timeoutMs)
|
||||
|
||||
const resolver = (message: JSONRPCMessage) => {
|
||||
clearTimeout(timeout)
|
||||
resolve(message)
|
||||
}
|
||||
|
||||
this.waitingResolvers.push(resolver)
|
||||
})
|
||||
}
|
||||
}
|
||||
type HeaderMap = Record<string, string | string[] | undefined>
|
||||
|
||||
function createError(id: RequestId, code: ErrorCode | number, message: string): JSONRPCError {
|
||||
return {
|
||||
@@ -158,7 +85,140 @@ function createError(id: RequestId, code: ErrorCode | number, message: string):
|
||||
}
|
||||
}
|
||||
|
||||
function buildMcpServer(userId?: string): Server {
|
||||
function normalizeRequestHeaders(request: NextRequest): HeaderMap {
|
||||
const headers: HeaderMap = {}
|
||||
|
||||
request.headers.forEach((value, key) => {
|
||||
headers[key.toLowerCase()] = value
|
||||
})
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
function readHeader(headers: HeaderMap | undefined, name: string): string | undefined {
|
||||
if (!headers) return undefined
|
||||
const value = headers[name.toLowerCase()]
|
||||
if (Array.isArray(value)) {
|
||||
return value[0]
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
class NextResponseCapture {
|
||||
private _status = 200
|
||||
private _headers = new Headers()
|
||||
private _chunks: Buffer[] = []
|
||||
private _closeHandlers: Array<() => void> = []
|
||||
private _errorHandlers: Array<(error: Error) => void> = []
|
||||
private _ended = false
|
||||
private _endedPromise: Promise<void>
|
||||
private _resolveEnded: (() => void) | null = null
|
||||
|
||||
constructor() {
|
||||
this._endedPromise = new Promise<void>((resolve) => {
|
||||
this._resolveEnded = resolve
|
||||
})
|
||||
}
|
||||
|
||||
writeHead(status: number, headers?: Record<string, string | number | string[]>): this {
|
||||
this._status = status
|
||||
|
||||
if (headers) {
|
||||
Object.entries(headers).forEach(([key, value]) => {
|
||||
if (Array.isArray(value)) {
|
||||
this._headers.set(key, value.join(', '))
|
||||
} else {
|
||||
this._headers.set(key, String(value))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
flushHeaders(): this {
|
||||
return this
|
||||
}
|
||||
|
||||
write(chunk: unknown): boolean {
|
||||
if (typeof chunk === 'string') {
|
||||
this._chunks.push(Buffer.from(chunk))
|
||||
return true
|
||||
}
|
||||
|
||||
if (chunk instanceof Uint8Array) {
|
||||
this._chunks.push(Buffer.from(chunk))
|
||||
return true
|
||||
}
|
||||
|
||||
if (chunk !== undefined && chunk !== null) {
|
||||
this._chunks.push(Buffer.from(String(chunk)))
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
end(chunk?: unknown): this {
|
||||
if (chunk !== undefined) {
|
||||
this.write(chunk)
|
||||
}
|
||||
|
||||
this._ended = true
|
||||
this._resolveEnded?.()
|
||||
|
||||
this._closeHandlers.forEach((handler) => {
|
||||
try {
|
||||
handler()
|
||||
} catch (error) {
|
||||
this._errorHandlers.forEach((errorHandler) => {
|
||||
errorHandler(error instanceof Error ? error : new Error(String(error)))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
async waitForEnd(timeoutMs = 30000): Promise<void> {
|
||||
if (this._ended) return
|
||||
|
||||
await Promise.race([
|
||||
this._endedPromise,
|
||||
new Promise<void>((resolve) => {
|
||||
setTimeout(resolve, timeoutMs)
|
||||
}),
|
||||
])
|
||||
}
|
||||
|
||||
on(event: 'close' | 'error', handler: (() => void) | ((error: Error) => void)): this {
|
||||
if (event === 'close') {
|
||||
this._closeHandlers.push(handler as () => void)
|
||||
}
|
||||
|
||||
if (event === 'error') {
|
||||
this._errorHandlers.push(handler as (error: Error) => void)
|
||||
}
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
toNextResponse(): NextResponse {
|
||||
if (this._chunks.length === 0) {
|
||||
return new NextResponse(null, {
|
||||
status: this._status,
|
||||
headers: this._headers,
|
||||
})
|
||||
}
|
||||
|
||||
const body = Buffer.concat(this._chunks)
|
||||
return new NextResponse(body, {
|
||||
status: this._status,
|
||||
headers: this._headers,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function buildMcpServer(): Server {
|
||||
const server = new Server(
|
||||
{
|
||||
name: 'sim-copilot',
|
||||
@@ -190,43 +250,88 @@ function buildMcpServer(userId?: string): Server {
|
||||
return result
|
||||
})
|
||||
|
||||
server.setRequestHandler(CallToolRequestSchema, async (request) => {
|
||||
if (!userId) {
|
||||
server.setRequestHandler(CallToolRequestSchema, async (request, extra) => {
|
||||
const headers = (extra.requestInfo?.headers || {}) as HeaderMap
|
||||
const apiKeyHeader = readHeader(headers, 'x-api-key')
|
||||
|
||||
if (!apiKeyHeader) {
|
||||
throw new McpError(
|
||||
ErrorCode.InvalidRequest,
|
||||
-32000,
|
||||
'API key required. Set the x-api-key header with a valid Sim API key.'
|
||||
)
|
||||
}
|
||||
|
||||
const authResult = await authenticateApiKeyFromHeader(apiKeyHeader)
|
||||
if (!authResult.success || !authResult.userId) {
|
||||
logger.warn('MCP auth failed', {
|
||||
error: authResult.error,
|
||||
method: request.method,
|
||||
})
|
||||
|
||||
throw new McpError(-32000, authResult.error || 'Invalid API key')
|
||||
}
|
||||
|
||||
if (authResult.keyId) {
|
||||
updateApiKeyLastUsed(authResult.keyId).catch((error) => {
|
||||
logger.warn('Failed to update API key last-used timestamp', {
|
||||
keyId: authResult.keyId,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
const usageCheck = await checkServerSideUsageLimits(authResult.userId)
|
||||
if (usageCheck.isExceeded) {
|
||||
throw new McpError(
|
||||
-32000,
|
||||
`Usage limit exceeded: ${usageCheck.message || 'Upgrade your plan.'}`
|
||||
)
|
||||
}
|
||||
|
||||
const params = request.params as { name?: string; arguments?: Record<string, unknown> } | undefined
|
||||
if (!params?.name) {
|
||||
throw new McpError(ErrorCode.InvalidParams, 'Tool name required')
|
||||
}
|
||||
|
||||
return handleToolsCall(
|
||||
const result = await handleToolsCall(
|
||||
{
|
||||
name: params.name,
|
||||
arguments: params.arguments,
|
||||
},
|
||||
userId
|
||||
authResult.userId
|
||||
)
|
||||
|
||||
trackMcpCopilotCall(authResult.userId)
|
||||
|
||||
return result
|
||||
})
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
async function handleMcpRequestWithSdk(
|
||||
message: JSONRPCMessage,
|
||||
userId?: string
|
||||
): Promise<JSONRPCMessage | null> {
|
||||
const server = buildMcpServer(userId)
|
||||
const transport = new SingleRequestTransport()
|
||||
request: NextRequest,
|
||||
parsedBody: unknown
|
||||
): Promise<NextResponse> {
|
||||
const server = buildMcpServer()
|
||||
const transport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: undefined,
|
||||
enableJsonResponse: true,
|
||||
})
|
||||
|
||||
const responseCapture = new NextResponseCapture()
|
||||
|
||||
const requestAdapter = {
|
||||
method: request.method,
|
||||
headers: normalizeRequestHeaders(request),
|
||||
}
|
||||
|
||||
await server.connect(transport)
|
||||
|
||||
try {
|
||||
await transport.dispatch(message)
|
||||
return transport.waitForResponse()
|
||||
await transport.handleRequest(requestAdapter as any, responseCapture as any, parsedBody)
|
||||
await responseCapture.waitForEnd()
|
||||
return responseCapture.toNextResponse()
|
||||
} finally {
|
||||
await server.close().catch(() => {})
|
||||
await transport.close().catch(() => {})
|
||||
@@ -243,98 +348,21 @@ export async function GET() {
|
||||
}
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
let requestId: RequestId = 0
|
||||
|
||||
try {
|
||||
let body: JSONRPCMessage
|
||||
let parsedBody: unknown
|
||||
|
||||
try {
|
||||
body = (await request.json()) as JSONRPCMessage
|
||||
parsedBody = await request.json()
|
||||
} catch {
|
||||
return NextResponse.json(createError(0, ErrorCode.ParseError, 'Invalid JSON body'), {
|
||||
status: 400,
|
||||
})
|
||||
}
|
||||
|
||||
if (isJSONRPCNotification(body)) {
|
||||
return new NextResponse(null, { status: 202 })
|
||||
}
|
||||
|
||||
if (!isJSONRPCRequest(body)) {
|
||||
return NextResponse.json(
|
||||
createError(0, ErrorCode.InvalidRequest, 'Invalid JSON-RPC message'),
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
requestId = body.id
|
||||
|
||||
let userId: string | undefined
|
||||
|
||||
if (body.method === 'tools/call') {
|
||||
const apiKeyHeader = request.headers.get('x-api-key')
|
||||
if (!apiKeyHeader) {
|
||||
return NextResponse.json(
|
||||
createError(
|
||||
requestId,
|
||||
-32000,
|
||||
'API key required. Set the x-api-key header with a valid Sim API key.'
|
||||
),
|
||||
{ status: 401 }
|
||||
)
|
||||
}
|
||||
|
||||
const authResult = await authenticateApiKeyFromHeader(apiKeyHeader)
|
||||
if (!authResult.success || !authResult.userId) {
|
||||
logger.warn('MCP auth failed', {
|
||||
error: authResult.error,
|
||||
method: body.method,
|
||||
})
|
||||
|
||||
return NextResponse.json(
|
||||
createError(requestId, -32000, authResult.error || 'Invalid API key'),
|
||||
{ status: 401 }
|
||||
)
|
||||
}
|
||||
|
||||
userId = authResult.userId
|
||||
|
||||
if (authResult.keyId) {
|
||||
updateApiKeyLastUsed(authResult.keyId).catch((error) => {
|
||||
logger.warn('Failed to update API key last-used timestamp', {
|
||||
keyId: authResult.keyId,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
const usageCheck = await checkServerSideUsageLimits(userId)
|
||||
if (usageCheck.isExceeded) {
|
||||
return NextResponse.json(
|
||||
createError(
|
||||
requestId,
|
||||
-32000,
|
||||
`Usage limit exceeded: ${usageCheck.message || 'Upgrade your plan.'}`
|
||||
),
|
||||
{ status: 402 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const responseMessage = await handleMcpRequestWithSdk(body, userId)
|
||||
|
||||
if (body.method === 'tools/call' && userId) {
|
||||
trackMcpCopilotCall(userId)
|
||||
}
|
||||
|
||||
if (!responseMessage) {
|
||||
return new NextResponse(null, { status: 202 })
|
||||
}
|
||||
|
||||
return NextResponse.json(responseMessage)
|
||||
return await handleMcpRequestWithSdk(request, parsedBody)
|
||||
} catch (error) {
|
||||
logger.error('Error handling MCP request', { error })
|
||||
return NextResponse.json(createError(requestId, ErrorCode.InternalError, 'Internal error'), {
|
||||
return NextResponse.json(createError(0, ErrorCode.InternalError, 'Internal error'), {
|
||||
status: 500,
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user