Compare commits

...

1 Commits

Author SHA1 Message Date
Cursor Agent
9f30287eb9 fix(mcp): tighten resilience pipeline behavior 2026-03-10 00:17:49 +00:00
6 changed files with 462 additions and 8 deletions

View File

@@ -0,0 +1,143 @@
import { createLogger } from '@sim/logger'
import type { McpToolResult } from '@/lib/mcp/types'
import type { McpExecutionContext, McpMiddleware, McpMiddlewareNext } from './types'
// Upper bound on the number of per-server breaker states kept in the registry.
// Once the registry grows past this, CircuitBreakerMiddleware evicts an entry
// so that many short-lived server ids cannot grow memory without limit.
const MAX_SERVER_STATES = 1000
/**
 * Breaker phases: CLOSED lets calls through, OPEN fast-fails them, and
 * HALF-OPEN admits a single probe call while fast-failing the rest.
 */
export type CircuitState = 'CLOSED' | 'OPEN' | 'HALF-OPEN'
/**
 * Tunables for CircuitBreakerMiddleware. The constructor accepts a partial
 * config and fills in defaults for any omitted field.
 */
export interface CircuitBreakerConfig {
/** Number of failures counted while CLOSED (reset by a success) before tripping to OPEN */
failureThreshold: number
/** How long to wait in OPEN before transitioning to HALF-OPEN (ms) */
resetTimeoutMs: number
}
/** Mutable breaker bookkeeping kept per server id. */
interface ServerState {
// Current phase of the breaker for this server.
state: CircuitState
// Failures counted while CLOSED; zeroed on success.
failures: number
// Epoch milliseconds (Date.now() scale) after which an OPEN breaker may move to HALF-OPEN.
nextAttemptMs: number
// True while a HALF-OPEN probe call is in flight; concurrent calls are rejected meanwhile.
isHalfOpenProbing: boolean
}
// Logger shared by the circuit breaker for state-transition messages.
const logger = createLogger('mcp:resilience:circuit-breaker')
/**
 * Middleware that stops forwarding calls to an MCP server after repeated failures.
 *
 * Per server id the breaker moves through three phases:
 * - CLOSED: calls pass through; failures are counted and a success resets the count.
 * - OPEN: calls are rejected immediately until `resetTimeoutMs` has elapsed.
 * - HALF-OPEN: exactly one probe call is let through; its outcome alone decides
 *   whether the breaker closes again or re-opens.
 *
 * A result with `isError` set, or any exception thrown by the next layer,
 * counts as a failure.
 */
export class CircuitBreakerMiddleware implements McpMiddleware {
  // Map iteration order doubles as recency order: getState() re-inserts an entry on
  // every access, so the first key is always the least recently used one.
  // The size is capped to prevent memory leaks if thousands of ephemeral servers connect.
  private registry = new Map<string, ServerState>()
  private config: CircuitBreakerConfig

  /**
   * @param config Optional overrides; defaults are 5 failures and a 30000 ms reset timeout.
   */
  constructor(config: Partial<CircuitBreakerConfig> = {}) {
    this.config = {
      failureThreshold: config.failureThreshold ?? 5,
      resetTimeoutMs: config.resetTimeoutMs ?? 30000,
    }
  }

  /**
   * Look up (or lazily create) the breaker state for a server and mark it as
   * most recently used.
   */
  private getState(serverId: string): ServerState {
    const existing = this.registry.get(serverId)
    if (existing) {
      // Re-insert so active servers move to the end of the eviction order.
      // Without this an actively used (possibly OPEN) breaker could be evicted
      // just because it was registered early, silently resetting it to CLOSED.
      this.registry.delete(serverId)
      this.registry.set(serverId, existing)
      return existing
    }

    const created: ServerState = {
      state: 'CLOSED',
      failures: 0,
      nextAttemptMs: 0,
      isHalfOpenProbing: false,
    }
    this.registry.set(serverId, created)
    this.evictIfNecessary()
    return created
  }

  /** Drop the least recently used entry once the registry exceeds its cap. */
  private evictIfNecessary(): void {
    if (this.registry.size <= MAX_SERVER_STATES) {
      return
    }
    const oldestKey = this.registry.keys().next().value
    // Compare against undefined rather than truthiness so an empty-string id is still evictable.
    if (oldestKey !== undefined) {
      this.registry.delete(oldestKey)
    }
  }

  /**
   * Gate the call on the breaker state, forward it, and record the outcome.
   *
   * @param context The current execution context; `serverId` selects the breaker.
   * @param next The next middleware/tool in the chain.
   * @returns The result produced by `next` (including `isError` results).
   * @throws Error when the breaker is OPEN, or HALF-OPEN with a probe already in
   *   flight; otherwise re-throws whatever `next` throws.
   */
  async execute(context: McpExecutionContext, next: McpMiddlewareNext): Promise<McpToolResult> {
    const { serverId, toolCall } = context
    const serverState = this.getState(serverId)

    // 1. Check current state and evaluate timeouts
    if (serverState.state === 'OPEN') {
      if (Date.now() > serverState.nextAttemptMs) {
        // Time to try again, enter HALF-OPEN
        logger.info(`Circuit breaker entering HALF-OPEN for server ${serverId}`)
        serverState.state = 'HALF-OPEN'
        serverState.isHalfOpenProbing = false
      } else {
        // Fast-fail
        throw new Error(
          `Circuit breaker is OPEN for server ${serverId}. Fast-failing request to ${toolCall.name}.`
        )
      }
    }

    // Remember whether THIS call is the probe. Calls admitted earlier (while CLOSED)
    // may still be in flight and must not be mistaken for the probe when they settle.
    let isProbe = false
    if (serverState.state === 'HALF-OPEN') {
      if (serverState.isHalfOpenProbing) {
        // Another request is already probing. Fast-fail concurrent requests.
        throw new Error(
          `Circuit breaker is HALF-OPEN for server ${serverId}. A probe request is currently executing. Fast-failing concurrent request to ${toolCall.name}.`
        )
      }
      // Claim the single probe slot.
      serverState.isHalfOpenProbing = true
      isProbe = true
    }

    // 2. Invoke the next layer
    let result: McpToolResult
    try {
      result = await next(context)
    } catch (error) {
      // Any exception from the downstream layers counts as a failure
      this.recordFailure(serverId, serverState, isProbe)
      throw error // Re-throw to caller
    }

    // 3. Handle result parsing (isError = true counts as failure for us)
    if (result.isError) {
      this.recordFailure(serverId, serverState, isProbe)
    } else {
      this.recordSuccess(serverId, serverState, isProbe)
    }
    return result
  }

  /**
   * Record a successful call. Only the HALF-OPEN probe may close the breaker;
   * a non-probe success merely resets the failure count while CLOSED.
   */
  private recordSuccess(serverId: string, state: ServerState, isProbe: boolean): void {
    if (isProbe) {
      logger.info(`Circuit breaker reset to CLOSED for server ${serverId}`)
      state.state = 'CLOSED'
      state.failures = 0
      state.isHalfOpenProbing = false
      return
    }
    if (state.state === 'CLOSED') {
      state.failures = 0
    }
    // Otherwise this is a stale result from a call admitted before the breaker
    // tripped. Ignore it so it cannot bypass the reset timeout or the probe.
  }

  /**
   * Record a failed call. A failed probe re-opens the breaker immediately; a
   * non-probe failure counts toward the threshold only while CLOSED.
   */
  private recordFailure(serverId: string, state: ServerState, isProbe: boolean): void {
    if (isProbe) {
      // The probe failed! Trip immediately back to OPEN.
      logger.warn(`Circuit breaker probe failed. Tripping back to OPEN for server ${serverId}`)
      this.tripToOpen(state)
      return
    }
    if (state.state !== 'CLOSED') {
      // Stale failure from a call admitted before the trip: the breaker already
      // reacted, and it must not release the probe lock or extend the timeout.
      return
    }
    state.failures++
    if (state.failures >= this.config.failureThreshold) {
      logger.error(
        `Circuit breaker failure threshold reached (${state.failures}/${this.config.failureThreshold}). Tripping to OPEN for server ${serverId}`
      )
      this.tripToOpen(state)
    }
  }

  /** Move to OPEN and schedule the earliest time a probe may be attempted. */
  private tripToOpen(state: ServerState): void {
    state.state = 'OPEN'
    state.isHalfOpenProbing = false
    state.nextAttemptMs = Date.now() + this.config.resetTimeoutMs
  }
}

View File

@@ -0,0 +1,42 @@
import type { McpToolResult } from '@/lib/mcp/types'
import type { McpExecutionContext, McpMiddleware, McpMiddlewareNext } from './types'
/**
 * Ordered chain of middlewares wrapped around a terminal handler.
 * Middlewares run in registration order; each one decides when (and with which
 * context) the rest of the chain is invoked.
 */
export class ResiliencePipeline {
  private middlewares: McpMiddleware[] = []

  /**
   * Register a middleware at the end of the chain.
   * @returns This pipeline, so registrations can be chained.
   */
  use(middleware: McpMiddleware): this {
    this.middlewares.push(middleware)
    return this
  }

  /**
   * Run the context through every registered middleware and then the terminal handler.
   *
   * @param context Context handed to the first middleware.
   * @param finalHandler Invoked once the last middleware calls its `next`.
   * @returns Whatever the outermost middleware (or the handler, for an empty chain) resolves to.
   * @throws Error if a position in the chain is dispatched again during the same run.
   */
  async execute(
    context: McpExecutionContext,
    finalHandler: McpMiddlewareNext
  ): Promise<McpToolResult> {
    // Deepest chain position dispatched so far during this run.
    let furthest = -1

    const runFrom = async (
      position: number,
      activeContext: McpExecutionContext
    ): Promise<McpToolResult> => {
      // Positions may only advance; revisiting one means a middleware invoked next() again.
      if (position <= furthest) {
        throw new Error('next() called multiple times')
      }
      furthest = position

      if (position !== this.middlewares.length) {
        const current = this.middlewares[position]
        const proceed: McpMiddlewareNext = (forwarded) => runFrom(position + 1, forwarded)
        return current.execute(activeContext, proceed)
      }

      // Past the last middleware: hand over to the terminal handler.
      return finalHandler(activeContext)
    }

    return runFrom(0, context)
  }
}

View File

@@ -0,0 +1,155 @@
import { createLogger } from '@sim/logger'
import { z } from 'zod'
import { createMcpToolId } from '@/lib/mcp/shared'
import type { McpTool, McpToolResult, McpToolSchema, McpToolSchemaProperty } from '@/lib/mcp/types'
import type { McpExecutionContext, McpMiddleware, McpMiddlewareNext } from './types'
// Logger used to report schema validation failures.
const logger = createLogger('mcp:schema-validator')
/**
 * Optional lookup used by SchemaValidatorMiddleware when a tool's schema is not
 * cached yet. Receives the server id and tool name and may answer synchronously
 * or asynchronously; returning `undefined` means the tool is unknown, in which
 * case the call is forwarded without validation.
 */
export type ToolProvider = (
serverId: string,
toolName: string
) => McpTool | undefined | Promise<McpTool | undefined>
/**
 * Middleware that validates tool-call arguments against the tool's declared
 * input schema before the call is forwarded.
 *
 * Schemas are compiled to Zod once and cached per tool id. When no schema is
 * known for a tool (nothing cached and no provider result) the call is
 * forwarded unvalidated. Validation failures are returned as an `isError`
 * result rather than thrown.
 */
export class SchemaValidatorMiddleware implements McpMiddleware {
  private schemaCache = new Map<string, z.ZodTypeAny>()
  private toolProvider?: ToolProvider

  /**
   * @param options.toolProvider Optional lazy lookup used on a cache miss.
   */
  constructor(options?: { toolProvider?: ToolProvider }) {
    this.toolProvider = options?.toolProvider
  }

  /**
   * Cache a tool's schema explicitly (e.g. during server discovery)
   */
  cacheTool(tool: McpTool) {
    const toolId = createMcpToolId(tool.serverId, tool.name)
    const zodSchema = this.compileSchema(tool.inputSchema)
    this.schemaCache.set(toolId, zodSchema)
  }

  /**
   * Clear caches, either for a specific tool or globally.
   */
  clearCache(toolId?: string) {
    if (toolId) {
      this.schemaCache.delete(toolId)
    } else {
      this.schemaCache.clear()
    }
  }

  /**
   * Validate the call's arguments and forward the call when they are acceptable.
   *
   * @param context The current execution context; on success its
   *   `toolCall.arguments` is replaced by the parsed (defaulted) value.
   * @param next The next middleware/tool in the chain.
   * @returns The downstream result, or an `isError` result describing the
   *   validation problems (in which case `next` is not called).
   */
  async execute(context: McpExecutionContext, next: McpMiddlewareNext): Promise<McpToolResult> {
    const { toolCall } = context
    const toolName = toolCall.name
    const toolId = createMcpToolId(context.serverId, toolName)

    let zodSchema = this.schemaCache.get(toolId)
    if (!zodSchema && this.toolProvider) {
      const tool = await this.toolProvider(context.serverId, toolName)
      if (tool) {
        zodSchema = this.compileSchema(tool.inputSchema)
        this.schemaCache.set(toolId, zodSchema)
      }
    }

    if (zodSchema) {
      // A call without an arguments payload is validated as an empty object so
      // tools with no required inputs are not rejected with a bare "Required".
      const parseResult = await zodSchema.safeParseAsync(toolCall.arguments ?? {})
      if (!parseResult.success) {
        // Return natively formatted error payload
        const errorDetails = parseResult.error.errors
          .map((e) => `${e.path.join('.') || 'root'}: ${e.message}`)
          .join(', ')
        logger.warn('Schema validation failed', { toolName, error: errorDetails })
        return {
          isError: true,
          content: [
            {
              type: 'text',
              text: `Schema validation failed: [${errorDetails}]`,
            },
          ],
        }
      }
      // Sync successfully parsed / defaulted arguments back to context
      context.toolCall.arguments = parseResult.data
    }

    return next(context)
  }

  /** Compile a tool's top-level input schema into a Zod object schema. */
  private compileSchema(schema: McpToolSchema): z.ZodObject<any> {
    return this.compileObject(schema.properties || {}, schema.required || []) as z.ZodObject<any>
  }

  /**
   * Build a Zod object from a JSON-Schema-style property map.
   *
   * @param properties Declared properties keyed by name.
   * @param required Names of properties that must be present.
   */
  private compileObject(
    properties: Record<string, McpToolSchemaProperty>,
    required: string[]
  ): z.ZodTypeAny {
    const shape: Record<string, z.ZodTypeAny> = {}
    for (const [key, prop] of Object.entries(properties)) {
      const zodType = this.compileProperty(prop)
      // A property carrying a default already accepts `undefined` (the default is
      // substituted). Wrapping it in optional() would short-circuit on a missing
      // key and the default would never be applied.
      const mustStayAsIs = required.includes(key) || prop.default !== undefined
      shape[key] = mustStayAsIs ? zodType : zodType.optional()
    }
    // Zod strips undeclared keys by default, and the parsed value replaces the
    // call's arguments. passthrough() keeps undeclared keys so free-form objects
    // and partially described schemas do not silently lose data.
    return z.object(shape).passthrough()
  }

  /** Translate a single property definition into the matching Zod type. */
  private compileProperty(prop: McpToolSchemaProperty): z.ZodTypeAny {
    // Unknown or unspecified types are accepted as-is.
    let baseType: z.ZodTypeAny = z.any()
    switch (prop.type) {
      case 'string':
        baseType = z.string()
        break
      case 'number':
        baseType = z.number()
        break
      case 'integer':
        // Reject fractional values for integer-typed properties.
        baseType = z.number().int()
        break
      case 'boolean':
        baseType = z.boolean()
        break
      case 'array':
        if (prop.items) {
          baseType = z.array(this.compileProperty(prop.items))
        } else {
          baseType = z.array(z.any())
        }
        break
      case 'object':
        baseType = this.compileObject(prop.properties || {}, prop.required || [])
        break
    }

    // An enum replaces the base type: only the listed literal values are accepted.
    if (prop.enum && prop.enum.length > 0) {
      if (prop.enum.length === 1) {
        baseType = z.literal(prop.enum[0])
      } else {
        const literals = prop.enum.map((e) => z.literal(e))
        // z.union requires a tuple of at least two members; length was checked above.
        baseType = z.union(literals as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
      }
    }

    if (prop.description) {
      baseType = baseType.describe(prop.description)
    }
    if (prop.default !== undefined) {
      baseType = baseType.default(prop.default)
    }
    return baseType
  }
}

View File

@@ -0,0 +1,53 @@
import { createLogger } from '@sim/logger'
import type { McpToolResult } from '@/lib/mcp/types'
import type { McpExecutionContext, McpMiddleware, McpMiddlewareNext } from './types'
// Logger that receives one structured record per tool execution (completed or failed).
const logger = createLogger('mcp:telemetry')
/**
 * Middleware that logs one structured record per tool execution, including
 * latency and a coarse failure reason. It never alters the result or the error.
 */
export class TelemetryMiddleware implements McpMiddleware {
  /**
   * Time the downstream call and log its outcome.
   *
   * @param context The current execution context (tool name, server id and workspace id are logged).
   * @param next The next middleware/tool in the chain.
   * @returns The downstream result, unchanged.
   * @throws Re-throws whatever `next` throws after logging it.
   */
  async execute(context: McpExecutionContext, next: McpMiddlewareNext): Promise<McpToolResult> {
    const startTime = performance.now()
    try {
      const result = await next(context)
      const latency_ms = Math.round(performance.now() - startTime)
      const isError = result.isError === true
      logger.info('MCP Tool Execution Completed', {
        toolName: context.toolCall.name,
        serverId: context.serverId,
        workspaceId: context.workspaceId,
        latency_ms,
        success: !isError,
        ...(isError && { failure_reason: 'TOOL_ERROR' }),
      })
      return result
    } catch (error) {
      const latency_ms = Math.round(performance.now() - startTime)
      const failure_reason = this.classifyFailure(error)
      logger.error('MCP Tool Execution Failed', {
        toolName: context.toolCall.name,
        serverId: context.serverId,
        workspaceId: context.workspaceId,
        latency_ms,
        failure_reason,
        err: error instanceof Error ? error.message : String(error),
      })
      throw error // Re-throw to allow upstream handling (e.g. circuit breaker)
    }
  }

  /**
   * Derive a coarse failure label from a thrown value.
   * Non-Error values and unrecognised errors fall back to 'API_500'.
   */
  private classifyFailure(error: unknown): string {
    if (!(error instanceof Error)) {
      return 'API_500' // General failure fallback
    }
    const lowerMsg = error.message.toLowerCase()
    // Circuit-breaker fast-fails never reached the server, so they must not be
    // reported as server errors. Checked first because the message embeds the
    // tool name, which could itself contain "timeout" or "validation".
    if (lowerMsg.startsWith('circuit breaker is ')) {
      return 'CIRCUIT_BREAKER'
    }
    if (
      error.name === 'TimeoutError' ||
      lowerMsg.includes('timeout') ||
      lowerMsg.includes('timed out')
    ) {
      return 'TIMEOUT'
    }
    if (lowerMsg.includes('validation') || error.name === 'ZodError') {
      return 'VALIDATION_ERROR'
    }
    return 'API_500'
  }
}

View File

@@ -0,0 +1,32 @@
import type { McpToolCall, McpToolResult } from '@/lib/mcp/types'
/**
* Context passed through the Resilience Pipeline.
* The same object is handed from middleware to middleware unless one of them
* forwards a replacement to `next`.
*/
export interface McpExecutionContext {
/** The tool invocation being executed (tool name and arguments). */
toolCall: McpToolCall
/** Identifier of the MCP server the call targets. */
serverId: string
/** Identifier of the user on whose behalf the call is made. */
userId: string
/** Identifier of the workspace the call belongs to. */
workspaceId: string
/**
* Extra headers supplied by the executeTool caller, carried along for middlewares to inspect
*/
extraHeaders?: Record<string, string>
}
/**
* Standardized function signature for invoking the NEXT component in the pipeline.
* Takes the context to forward and resolves to the downstream tool result.
*/
export type McpMiddlewareNext = (context: McpExecutionContext) => Promise<McpToolResult>
/**
* Interface that all Resilience Middlewares must implement
*/
export interface McpMiddleware {
/**
* Execute the middleware logic
* @param context The current execution context
* @param next The next middleware/tool in the chain; a middleware that does not call it short-circuits the pipeline
* @returns The tool result, either produced downstream or by the middleware itself
*/
execute(context: McpExecutionContext, next: McpMiddlewareNext): Promise<McpToolResult>
}

View File

@@ -11,6 +11,10 @@ import { generateRequestId } from '@/lib/core/utils/request'
import { McpClient } from '@/lib/mcp/client'
import { mcpConnectionManager } from '@/lib/mcp/connection-manager'
import { isMcpDomainAllowed, validateMcpDomain } from '@/lib/mcp/domain-check'
import { CircuitBreakerMiddleware } from '@/lib/mcp/resilience/circuit-breaker'
import { ResiliencePipeline } from '@/lib/mcp/resilience/pipeline'
import { SchemaValidatorMiddleware } from '@/lib/mcp/resilience/schema-validator'
import { TelemetryMiddleware } from '@/lib/mcp/resilience/telemetry'
import { resolveMcpConfigEnvVars } from '@/lib/mcp/resolve-config'
import {
createMcpCacheAdapter,
@@ -35,10 +39,23 @@ class McpService {
private readonly cacheTimeout = MCP_CONSTANTS.CACHE_TIMEOUT
private unsubscribeConnectionManager?: () => void
private pipeline: ResiliencePipeline
private schemaValidator: SchemaValidatorMiddleware
private circuitBreaker: CircuitBreakerMiddleware
private telemetry: TelemetryMiddleware
constructor() {
this.cacheAdapter = createMcpCacheAdapter()
logger.info(`MCP Service initialized with ${getMcpCacheType()} cache`)
this.schemaValidator = new SchemaValidatorMiddleware()
this.circuitBreaker = new CircuitBreakerMiddleware()
this.telemetry = new TelemetryMiddleware()
this.pipeline = new ResiliencePipeline()
.use(this.telemetry)
.use(this.schemaValidator)
.use(this.circuitBreaker)
if (mcpConnectionManager) {
this.unsubscribeConnectionManager = mcpConnectionManager.subscribe((event) => {
this.clearCache(event.workspaceId)
@@ -191,15 +208,23 @@ class McpService {
if (extraHeaders && Object.keys(extraHeaders).length > 0) {
resolvedConfig.headers = { ...resolvedConfig.headers, ...extraHeaders }
}
const client = await this.createClient(resolvedConfig)
try {
const result = await client.callTool(toolCall)
logger.info(`[${requestId}] Successfully executed tool ${toolCall.name}`)
return result
} finally {
await client.disconnect()
const context = {
serverId,
workspaceId,
userId,
toolCall,
extraHeaders,
}
const result = await this.pipeline.execute(context, async (ctx) => {
const client = await this.createClient(resolvedConfig)
try {
return await client.callTool(ctx.toolCall)
} finally {
await client.disconnect()
}
})
logger.info(`[${requestId}] Successfully executed tool ${toolCall.name}`)
return result
} catch (error) {
if (this.isSessionError(error) && attempt < maxRetries - 1) {
logger.warn(
@@ -322,6 +347,7 @@ class McpService {
try {
const cached = await this.cacheAdapter.get(cacheKey)
if (cached) {
cached.tools.forEach((t: McpTool) => this.schemaValidator.cacheTool(t))
return cached.tools
}
} catch (error) {
@@ -414,6 +440,7 @@ class McpService {
logger.info(
`[${requestId}] Discovered ${allTools.length} tools from ${servers.length - failedCount}/${servers.length} servers`
)
allTools.forEach((t: McpTool) => this.schemaValidator.cacheTool(t))
return allTools
} catch (error) {
logger.error(`[${requestId}] Failed to discover MCP tools for user ${userId}:`, error)
@@ -450,6 +477,7 @@ class McpService {
try {
const tools = await client.listTools()
logger.info(`[${requestId}] Discovered ${tools.length} tools from server ${config.name}`)
tools.forEach((t: McpTool) => this.schemaValidator.cacheTool(t))
return tools
} finally {
await client.disconnect()
@@ -533,6 +561,7 @@ class McpService {
await this.cacheAdapter.clear()
logger.debug('Cleared all MCP tool cache')
}
this.schemaValidator.clearCache()
} catch (error) {
logger.warn('Failed to clear cache:', error)
}