Mirror of https://github.com/simstudioai/sim.git (synced 2026-02-08 21:54:57 -05:00)
fix(mcp): harden notification system against race conditions
- Guard concurrent connect() calls in the connection manager with a connectingServers Set (sketched just below)
- Suppress post-disconnect notification handler firing in the MCP client
- Clean up Redis event listeners in pub/sub dispose()
- Add tests for all three hardening fixes (11 new tests)
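The guard named in the first bullet boils down to the following pattern. This is a simplified sketch only; connectOnce and doConnect are illustrative names, and the actual implementation lives in connection-manager.ts further down in this commit.

const connectingServers = new Set<string>()

async function connectOnce(serverId: string, doConnect: () => Promise<void>): Promise<void> {
  // Without the set, two overlapping connect() calls for the same serverId could
  // each pass the "already connected?" check and create two persistent clients.
  if (connectingServers.has(serverId)) return
  connectingServers.add(serverId)
  try {
    await doConnect()
  } finally {
    // Cleared even when doConnect() throws, so a failed attempt never blocks a retry.
    connectingServers.delete(serverId)
  }
}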
apps/sim/lib/mcp/client.test.ts (new file, 111 lines)
@@ -0,0 +1,111 @@
/**
 * @vitest-environment node
 */
import { loggerMock } from '@sim/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('@sim/logger', () => loggerMock)

/**
 * Capture the notification handler registered via `client.setNotificationHandler()`.
 * This lets us simulate the MCP SDK delivering a `tools/list_changed` notification.
 */
let capturedNotificationHandler: (() => Promise<void>) | null = null

vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
  Client: vi.fn().mockImplementation(() => ({
    connect: vi.fn().mockResolvedValue(undefined),
    close: vi.fn().mockResolvedValue(undefined),
    getServerVersion: vi.fn().mockReturnValue('2025-06-18'),
    getServerCapabilities: vi.fn().mockReturnValue({ tools: { listChanged: true } }),
    setNotificationHandler: vi
      .fn()
      .mockImplementation((_schema: unknown, handler: () => Promise<void>) => {
        capturedNotificationHandler = handler
      }),
    listTools: vi.fn().mockResolvedValue({ tools: [] }),
  })),
}))

vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({
  StreamableHTTPClientTransport: vi.fn().mockImplementation(() => ({
    onclose: null,
    sessionId: 'test-session',
  })),
}))

vi.mock('@modelcontextprotocol/sdk/types.js', () => ({
  ToolListChangedNotificationSchema: { method: 'notifications/tools/list_changed' },
}))

vi.mock('@/lib/core/execution-limits', () => ({
  getMaxExecutionTimeout: vi.fn().mockReturnValue(30000),
}))

import { McpClient } from './client'
import type { McpServerConfig } from './types'

function createConfig(): McpServerConfig {
  return {
    id: 'server-1',
    name: 'Test Server',
    transport: 'streamable-http',
    url: 'https://test.example.com/mcp',
  }
}

describe('McpClient notification handler', () => {
  beforeEach(() => {
    capturedNotificationHandler = null
  })

  it('fires onToolsChanged when a notification arrives while connected', async () => {
    const onToolsChanged = vi.fn()

    const client = new McpClient({
      config: createConfig(),
      securityPolicy: { requireConsent: false, auditLevel: 'basic' },
      onToolsChanged,
    })

    await client.connect()

    expect(capturedNotificationHandler).not.toBeNull()

    await capturedNotificationHandler!()

    expect(onToolsChanged).toHaveBeenCalledTimes(1)
    expect(onToolsChanged).toHaveBeenCalledWith('server-1')
  })

  it('suppresses notifications after disconnect', async () => {
    const onToolsChanged = vi.fn()

    const client = new McpClient({
      config: createConfig(),
      securityPolicy: { requireConsent: false, auditLevel: 'basic' },
      onToolsChanged,
    })

    await client.connect()
    expect(capturedNotificationHandler).not.toBeNull()

    await client.disconnect()

    // Simulate a late notification arriving after disconnect
    await capturedNotificationHandler!()

    expect(onToolsChanged).not.toHaveBeenCalled()
  })

  it('does not register a notification handler when onToolsChanged is not provided', async () => {
    const client = new McpClient({
      config: createConfig(),
      securityPolicy: { requireConsent: false, auditLevel: 'basic' },
    })

    await client.connect()

    expect(capturedNotificationHandler).toBeNull()
  })
})
apps/sim/lib/mcp/client.ts (modified)
@@ -10,10 +10,15 @@

import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
import type { ListToolsResult, Tool } from '@modelcontextprotocol/sdk/types.js'
import {
  type ListToolsResult,
  type Tool,
  ToolListChangedNotificationSchema,
} from '@modelcontextprotocol/sdk/types.js'
import { createLogger } from '@sim/logger'
import { getMaxExecutionTimeout } from '@/lib/core/execution-limits'
import {
  type McpClientOptions,
  McpConnectionError,
  type McpConnectionStatus,
  type McpConsentRequest,
@@ -24,6 +29,7 @@ import {
  type McpTool,
  type McpToolCall,
  type McpToolResult,
  type McpToolsChangedCallback,
  type McpVersionInfo,
} from '@/lib/mcp/types'

@@ -35,6 +41,7 @@ export class McpClient {
  private config: McpServerConfig
  private connectionStatus: McpConnectionStatus
  private securityPolicy: McpSecurityPolicy
  private onToolsChanged?: McpToolsChangedCallback
  private isConnected = false

  private static readonly SUPPORTED_VERSIONS = [
@@ -44,23 +51,36 @@ export class McpClient {
  ]

  /**
   * Creates a new MCP client
   * Creates a new MCP client.
   *
   * No session ID parameter (we disconnect after each operation).
   * The SDK handles session management automatically via Mcp-Session-Id header.
   *
   * @param config - Server configuration
   * @param securityPolicy - Optional security policy
   * Accepts either the legacy (config, securityPolicy?) signature
   * or a single McpClientOptions object with an optional onToolsChanged callback.
   */
  constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy) {
    this.config = config
    this.connectionStatus = { connected: false }
    this.securityPolicy = securityPolicy ?? {
      requireConsent: true,
      auditLevel: 'basic',
      maxToolExecutionsPerHour: 1000,
  constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy)
  constructor(options: McpClientOptions)
  constructor(
    configOrOptions: McpServerConfig | McpClientOptions,
    securityPolicy?: McpSecurityPolicy
  ) {
    if ('config' in configOrOptions) {
      this.config = configOrOptions.config
      this.securityPolicy = configOrOptions.securityPolicy ?? {
        requireConsent: true,
        auditLevel: 'basic',
        maxToolExecutionsPerHour: 1000,
      }
      this.onToolsChanged = configOrOptions.onToolsChanged
    } else {
      this.config = configOrOptions
      this.securityPolicy = securityPolicy ?? {
        requireConsent: true,
        auditLevel: 'basic',
        maxToolExecutionsPerHour: 1000,
      }
    }

    this.connectionStatus = { connected: false }

    if (!this.config.url) {
      throw new McpError('URL required for Streamable HTTP transport')
    }
@@ -79,16 +99,15 @@ export class McpClient {
      {
        capabilities: {
          tools: {},
          // Resources and prompts can be added later
          // resources: {},
          // prompts: {},
        },
      }
    )
  }

  /**
   * Initialize connection to MCP server
   * Initialize connection to MCP server.
   * If an `onToolsChanged` callback was provided, registers a notification handler
   * for `notifications/tools/list_changed` after connecting.
   */
  async connect(): Promise<void> {
    logger.info(`Connecting to MCP server: ${this.config.name} (${this.config.transport})`)
@@ -100,6 +119,15 @@ export class McpClient {
    this.connectionStatus.connected = true
    this.connectionStatus.lastConnected = new Date()

    if (this.onToolsChanged) {
      this.client.setNotificationHandler(ToolListChangedNotificationSchema, async () => {
        if (!this.isConnected) return
        logger.info(`[${this.config.name}] Received tools/list_changed notification`)
        this.onToolsChanged?.(this.config.id)
      })
      logger.info(`[${this.config.name}] Registered tools/list_changed notification handler`)
    }

    const serverVersion = this.client.getServerVersion()
    logger.info(`Successfully connected to MCP server: ${this.config.name}`, {
      protocolVersion: serverVersion,
@@ -241,6 +269,23 @@ export class McpClient {
    return !!serverCapabilities?.[capability]
  }

  /**
   * Check if the server declared `capabilities.tools.listChanged: true` during initialization.
   */
  hasListChangedCapability(): boolean {
    const caps = this.client.getServerCapabilities()
    const toolsCap = caps?.tools as Record<string, unknown> | undefined
    return !!toolsCap?.listChanged
  }

  /**
   * Register a callback to be invoked when the underlying transport closes.
   * Used by the connection manager for reconnection logic.
   */
  onClose(callback: () => void): void {
    this.transport.onclose = callback
  }

  /**
   * Get server configuration
   */
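To illustrate the overloads documented in the hunk above, both call forms type-check against the new constructor. This is a hypothetical caller; the config values are examples only.

import { McpClient } from '@/lib/mcp/client'
import type { McpServerConfig } from '@/lib/mcp/types'

const config: McpServerConfig = {
  id: 'server-1',
  name: 'Example Server',
  transport: 'streamable-http',
  url: 'https://example.com/mcp',
}

// Legacy positional signature keeps working for existing call sites
const legacyClient = new McpClient(config, { requireConsent: true, auditLevel: 'basic' })

// Options-object signature enables the tools/list_changed callback
const notifyingClient = new McpClient({
  config,
  securityPolicy: { requireConsent: false, auditLevel: 'basic' },
  onToolsChanged: (serverId) => console.log(`tools changed on ${serverId}`),
})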
apps/sim/lib/mcp/connection-manager.test.ts (new file, 184 lines)
@@ -0,0 +1,184 @@
/**
 * @vitest-environment node
 */
import { loggerMock } from '@sim/testing'
import { afterEach, describe, expect, it, vi } from 'vitest'

interface MockMcpClient {
  connect: ReturnType<typeof vi.fn>
  disconnect: ReturnType<typeof vi.fn>
  hasListChangedCapability: ReturnType<typeof vi.fn>
  onClose: ReturnType<typeof vi.fn>
}

/** Deferred promise to control when `client.connect()` resolves. */
function createDeferred<T = void>() {
  let resolve!: (value: T) => void
  const promise = new Promise<T>((res) => {
    resolve = res
  })
  return { promise, resolve }
}

function serverConfig(id: string, name = `Server ${id}`) {
  return {
    id,
    name,
    transport: 'streamable-http' as const,
    url: `https://${id}.example.com/mcp`,
  }
}

/** Shared setup: resets modules and applies base mocks. */
function setupBaseMocks() {
  vi.resetModules()
  vi.doMock('@sim/logger', () => loggerMock)
  vi.doMock('@/lib/core/config/feature-flags', () => ({ isTest: false }))
  vi.doMock('@/lib/mcp/pubsub', () => ({
    mcpPubSub: { onToolsChanged: vi.fn(() => vi.fn()), publishToolsChanged: vi.fn() },
  }))
}

describe('McpConnectionManager', () => {
  let manager: { connect: Function; dispose: Function } | null = null

  afterEach(() => {
    manager?.dispose()
    manager = null
  })

  describe('concurrent connect() guard', () => {
    it('creates only one client when two connect() calls race for the same serverId', async () => {
      setupBaseMocks()

      const deferred = createDeferred()
      const instances: MockMcpClient[] = []

      vi.doMock('./client', () => ({
        McpClient: vi.fn().mockImplementation(() => {
          const instance: MockMcpClient = {
            connect: vi.fn().mockImplementation(() => deferred.promise),
            disconnect: vi.fn().mockResolvedValue(undefined),
            hasListChangedCapability: vi.fn().mockReturnValue(true),
            onClose: vi.fn(),
          }
          instances.push(instance)
          return instance
        }),
      }))

      const { mcpConnectionManager: mgr } = await import('./connection-manager')
      manager = mgr

      const config = serverConfig('server-1')

      // Fire two concurrent connect() calls for the same server
      const p1 = mgr.connect(config, 'user-1', 'ws-1')
      const p2 = mgr.connect(config, 'user-1', 'ws-1')

      deferred.resolve()
      const [r1, r2] = await Promise.all([p1, p2])

      // Only one McpClient should have been instantiated
      expect(instances).toHaveLength(1)
      expect(r1.supportsListChanged).toBe(true)
      // Second call hits the connectingServers guard and returns false
      expect(r2.supportsListChanged).toBe(false)
    })

    it('allows a new connect() after a previous one completes', async () => {
      setupBaseMocks()

      const instances: MockMcpClient[] = []

      vi.doMock('./client', () => ({
        McpClient: vi.fn().mockImplementation(() => {
          const instance: MockMcpClient = {
            connect: vi.fn().mockResolvedValue(undefined),
            disconnect: vi.fn().mockResolvedValue(undefined),
            hasListChangedCapability: vi.fn().mockReturnValue(false),
            onClose: vi.fn(),
          }
          instances.push(instance)
          return instance
        }),
      }))

      const { mcpConnectionManager: mgr } = await import('./connection-manager')
      manager = mgr

      const config = serverConfig('server-2')

      // First connect — server doesn't support listChanged, disconnects immediately
      const r1 = await mgr.connect(config, 'user-1', 'ws-1')
      expect(r1.supportsListChanged).toBe(false)

      // connectingServers cleaned up via finally, so second connect proceeds
      const r2 = await mgr.connect(config, 'user-1', 'ws-1')
      expect(r2.supportsListChanged).toBe(false)

      expect(instances).toHaveLength(2)
    })

    it('cleans up connectingServers when connect() throws', async () => {
      setupBaseMocks()

      let callCount = 0
      const instances: MockMcpClient[] = []

      vi.doMock('./client', () => ({
        McpClient: vi.fn().mockImplementation(() => {
          callCount++
          const instance: MockMcpClient = {
            connect:
              callCount === 1
                ? vi.fn().mockRejectedValue(new Error('Connection refused'))
                : vi.fn().mockResolvedValue(undefined),
            disconnect: vi.fn().mockResolvedValue(undefined),
            hasListChangedCapability: vi.fn().mockReturnValue(true),
            onClose: vi.fn(),
          }
          instances.push(instance)
          return instance
        }),
      }))

      const { mcpConnectionManager: mgr } = await import('./connection-manager')
      manager = mgr

      const config = serverConfig('server-3')

      // First connect fails
      const r1 = await mgr.connect(config, 'user-1', 'ws-1')
      expect(r1.supportsListChanged).toBe(false)

      // Second connect should NOT be blocked by a stale connectingServers entry
      const r2 = await mgr.connect(config, 'user-1', 'ws-1')
      expect(r2.supportsListChanged).toBe(true)
      expect(instances).toHaveLength(2)
    })
  })

  describe('dispose', () => {
    it('rejects new connections after dispose', async () => {
      setupBaseMocks()

      vi.doMock('./client', () => ({
        McpClient: vi.fn().mockImplementation(() => ({
          connect: vi.fn().mockResolvedValue(undefined),
          disconnect: vi.fn().mockResolvedValue(undefined),
          hasListChangedCapability: vi.fn().mockReturnValue(true),
          onClose: vi.fn(),
        })),
      }))

      const { mcpConnectionManager: mgr } = await import('./connection-manager')
      manager = mgr

      mgr.dispose()

      const result = await mgr.connect(serverConfig('server-4'), 'user-1', 'ws-1')
      expect(result.supportsListChanged).toBe(false)
    })
  })
})
apps/sim/lib/mcp/connection-manager.ts (new file, 361 lines)
@@ -0,0 +1,361 @@
/**
 * MCP Connection Manager
 *
 * Maintains persistent connections to MCP servers that support
 * `notifications/tools/list_changed`. When a notification arrives,
 * the manager invalidates the tools cache and emits a ToolsChangedEvent
 * so the frontend SSE endpoint can push updates to browsers.
 *
 * Servers that do not support `listChanged` fall back to the existing
 * stale-time cache approach — no persistent connection is kept.
 */

import { createLogger } from '@sim/logger'
import { isTest } from '@/lib/core/config/feature-flags'
import { McpClient } from '@/lib/mcp/client'
import { mcpPubSub } from '@/lib/mcp/pubsub'
import type {
  ManagedConnectionState,
  McpServerConfig,
  McpToolsChangedCallback,
  ToolsChangedEvent,
} from '@/lib/mcp/types'

const logger = createLogger('McpConnectionManager')

const MAX_CONNECTIONS = 50
const MAX_RECONNECT_ATTEMPTS = 10
const BASE_RECONNECT_DELAY_MS = 1000
const IDLE_TIMEOUT_MS = 30 * 60 * 1000 // 30 minutes
const IDLE_CHECK_INTERVAL_MS = 5 * 60 * 1000 // 5 minutes

type ToolsChangedListener = (event: ToolsChangedEvent) => void

class McpConnectionManager {
  private connections = new Map<string, McpClient>()
  private states = new Map<string, ManagedConnectionState>()
  private reconnectTimers = new Map<string, ReturnType<typeof setTimeout>>()
  private listeners = new Set<ToolsChangedListener>()
  private connectingServers = new Set<string>()
  private idleCheckTimer: ReturnType<typeof setInterval> | null = null
  private disposed = false
  private unsubscribePubSub?: () => void

  constructor() {
    if (mcpPubSub) {
      this.unsubscribePubSub = mcpPubSub.onToolsChanged((event) => {
        this.notifyLocalListeners(event)
      })
    }
  }

  /**
   * Subscribe to tools-changed events from any managed connection.
   * Returns an unsubscribe function.
   */
  subscribe(listener: ToolsChangedListener): () => void {
    this.listeners.add(listener)
    return () => {
      this.listeners.delete(listener)
    }
  }

  /**
   * Establish a persistent connection to an MCP server.
   * If the server supports `listChanged`, the connection is kept alive
   * and notifications are forwarded to subscribers.
   *
   * If the server does NOT support `listChanged`, the client is disconnected
   * immediately — there's nothing to listen for.
   */
  async connect(
    config: McpServerConfig,
    userId: string,
    workspaceId: string
  ): Promise<{ supportsListChanged: boolean }> {
    if (this.disposed) {
      logger.warn('Connection manager is disposed, ignoring connect request')
      return { supportsListChanged: false }
    }

    const serverId = config.id

    if (this.connections.has(serverId) || this.connectingServers.has(serverId)) {
      logger.info(`[${config.name}] Already has a managed connection or is connecting, skipping`)
      const state = this.states.get(serverId)
      return { supportsListChanged: state?.supportsListChanged ?? false }
    }

    if (this.connections.size >= MAX_CONNECTIONS) {
      logger.warn(`Max connections (${MAX_CONNECTIONS}) reached, cannot connect to ${config.name}`)
      return { supportsListChanged: false }
    }

    this.connectingServers.add(serverId)

    try {
      const onToolsChanged: McpToolsChangedCallback = (sid) => {
        this.handleToolsChanged(sid)
      }

      const client = new McpClient({
        config,
        securityPolicy: {
          requireConsent: false,
          auditLevel: 'basic',
          maxToolExecutionsPerHour: 1000,
        },
        onToolsChanged,
      })

      try {
        await client.connect()
      } catch (error) {
        logger.error(`[${config.name}] Failed to connect for persistent monitoring:`, error)
        return { supportsListChanged: false }
      }

      const supportsListChanged = client.hasListChangedCapability()

      if (!supportsListChanged) {
        logger.info(
          `[${config.name}] Server does not support listChanged — disconnecting (fallback to cache)`
        )
        await client.disconnect()
        return { supportsListChanged: false }
      }

      this.connections.set(serverId, client)
      this.states.set(serverId, {
        serverId,
        serverName: config.name,
        workspaceId,
        userId,
        connected: true,
        supportsListChanged: true,
        reconnectAttempts: 0,
        lastActivity: Date.now(),
      })

      client.onClose(() => {
        this.handleDisconnect(config, userId, workspaceId)
      })

      this.ensureIdleCheck()

      logger.info(`[${config.name}] Persistent connection established (listChanged supported)`)
      return { supportsListChanged: true }
    } finally {
      this.connectingServers.delete(serverId)
    }
  }

  /**
   * Disconnect a managed connection.
   */
  async disconnect(serverId: string): Promise<void> {
    this.clearReconnectTimer(serverId)

    const client = this.connections.get(serverId)
    if (client) {
      try {
        await client.disconnect()
      } catch (error) {
        logger.warn(`Error disconnecting managed client ${serverId}:`, error)
      }
      this.connections.delete(serverId)
    }

    this.states.delete(serverId)
    logger.info(`Managed connection removed: ${serverId}`)
  }

  /**
   * Check whether a managed connection exists for the given server.
   */
  hasConnection(serverId: string): boolean {
    return this.connections.has(serverId)
  }

  /**
   * Get connection state for a server.
   */
  getState(serverId: string): ManagedConnectionState | undefined {
    return this.states.get(serverId)
  }

  /**
   * Get all managed connection states (for diagnostics).
   */
  getAllStates(): ManagedConnectionState[] {
    return [...this.states.values()]
  }

  /**
   * Dispose all connections and timers.
   */
  dispose(): void {
    this.disposed = true

    this.unsubscribePubSub?.()

    for (const timer of this.reconnectTimers.values()) {
      clearTimeout(timer)
    }
    this.reconnectTimers.clear()

    if (this.idleCheckTimer) {
      clearInterval(this.idleCheckTimer)
      this.idleCheckTimer = null
    }

    const disconnects = [...this.connections.entries()].map(async ([id, client]) => {
      try {
        await client.disconnect()
      } catch (error) {
        logger.warn(`Error disconnecting ${id} during dispose:`, error)
      }
    })

    Promise.allSettled(disconnects).then(() => {
      logger.info('Connection manager disposed')
    })

    this.connections.clear()
    this.states.clear()
    this.listeners.clear()
    this.connectingServers.clear()
  }

  /**
   * Notify only process-local listeners.
   * Called by the pub/sub subscription (receives events from all processes).
   */
  private notifyLocalListeners(event: ToolsChangedEvent): void {
    for (const listener of this.listeners) {
      try {
        listener(event)
      } catch (error) {
        logger.error('Error in tools-changed listener:', error)
      }
    }
  }

  /**
   * Handle a tools/list_changed notification from an external MCP server.
   * Publishes to pub/sub so all processes are notified.
   */
  private handleToolsChanged(serverId: string): void {
    const state = this.states.get(serverId)
    if (!state) return

    state.lastActivity = Date.now()

    const event: ToolsChangedEvent = {
      serverId,
      serverName: state.serverName,
      workspaceId: state.workspaceId,
      timestamp: Date.now(),
    }

    logger.info(`[${state.serverName}] Tools changed — publishing to pub/sub`)

    mcpPubSub?.publishToolsChanged(event)
  }

  private handleDisconnect(config: McpServerConfig, userId: string, workspaceId: string): void {
    const serverId = config.id
    const state = this.states.get(serverId)

    if (!state || this.disposed) return

    state.connected = false
    this.connections.delete(serverId)

    logger.warn(`[${config.name}] Persistent connection lost, scheduling reconnect`)

    this.scheduleReconnect(config, userId, workspaceId)
  }

  private scheduleReconnect(config: McpServerConfig, userId: string, workspaceId: string): void {
    const serverId = config.id
    const state = this.states.get(serverId)

    if (!state || this.disposed) return

    if (state.reconnectAttempts >= MAX_RECONNECT_ATTEMPTS) {
      logger.error(
        `[${config.name}] Max reconnect attempts (${MAX_RECONNECT_ATTEMPTS}) reached — giving up`
      )
      this.states.delete(serverId)
      return
    }

    const delay = Math.min(BASE_RECONNECT_DELAY_MS * 2 ** state.reconnectAttempts, 60_000)
    state.reconnectAttempts++

    logger.info(
      `[${config.name}] Reconnecting in ${delay}ms (attempt ${state.reconnectAttempts}/${MAX_RECONNECT_ATTEMPTS})`
    )

    this.clearReconnectTimer(serverId)

    const timer = setTimeout(async () => {
      this.reconnectTimers.delete(serverId)

      if (this.disposed) return

      try {
        this.connections.delete(serverId)
        this.states.delete(serverId)

        const result = await this.connect(config, userId, workspaceId)
        if (result.supportsListChanged) {
          const newState = this.states.get(serverId)
          if (newState) {
            newState.reconnectAttempts = 0
          }
          logger.info(`[${config.name}] Reconnected successfully`)
        }
      } catch (error) {
        logger.error(`[${config.name}] Reconnect failed:`, error)
        this.scheduleReconnect(config, userId, workspaceId)
      }
    }, delay)

    this.reconnectTimers.set(serverId, timer)
  }

  private clearReconnectTimer(serverId: string): void {
    const timer = this.reconnectTimers.get(serverId)
    if (timer) {
      clearTimeout(timer)
      this.reconnectTimers.delete(serverId)
    }
  }

  private ensureIdleCheck(): void {
    if (this.idleCheckTimer) return

    this.idleCheckTimer = setInterval(() => {
      const now = Date.now()
      for (const [serverId, state] of this.states) {
        if (now - state.lastActivity > IDLE_TIMEOUT_MS) {
          logger.info(
            `[${state.serverName}] Idle timeout reached, disconnecting managed connection`
          )
          this.disconnect(serverId)
        }
      }

      if (this.states.size === 0 && this.idleCheckTimer) {
        clearInterval(this.idleCheckTimer)
        this.idleCheckTimer = null
      }
    }, IDLE_CHECK_INTERVAL_MS)
  }
}

export const mcpConnectionManager = isTest
  ? (null as unknown as McpConnectionManager)
  : new McpConnectionManager()
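As a usage sketch, a subscriber (for example the SSE endpoint mentioned in the file header) could consume the manager's events as below. pushToSseClients and currentWorkspaceId are placeholders, not part of this commit.

import { mcpConnectionManager } from '@/lib/mcp/connection-manager'
import type { ToolsChangedEvent } from '@/lib/mcp/types'

declare const currentWorkspaceId: string // provided by the consuming route (placeholder)
declare function pushToSseClients(event: ToolsChangedEvent): void // placeholder transport to browsers

// subscribe() returns an unsubscribe function; call it when the stream closes.
const unsubscribe = mcpConnectionManager.subscribe((event) => {
  if (event.workspaceId === currentWorkspaceId) {
    pushToSseClients(event)
  }
})

// ...later, on stream teardown:
unsubscribe()

With the constants defined above, reconnect delays grow as 1s, 2s, 4s, and so on, capped at 60s, for at most 10 attempts before the manager gives up on a server.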
apps/sim/lib/mcp/pubsub.test.ts (new file, 93 lines)
@@ -0,0 +1,93 @@
/**
 * @vitest-environment node
 */
import { createMockRedis, loggerMock, type MockRedis } from '@sim/testing'
import { describe, expect, it, vi } from 'vitest'

/** Extend the @sim/testing Redis mock with the methods RedisMcpPubSub uses. */
function createPubSubRedis(): MockRedis & { removeAllListeners: ReturnType<typeof vi.fn> } {
  const mock = createMockRedis()
  // ioredis subscribe invokes a callback as the last argument
  mock.subscribe.mockImplementation((...args: unknown[]) => {
    const cb = args[args.length - 1]
    if (typeof cb === 'function') (cb as (err: null) => void)(null)
  })
  // on() returns `this` for chaining in ioredis
  mock.on.mockReturnThis()
  return { ...mock, removeAllListeners: vi.fn().mockReturnThis() }
}

/** Shared setup: resets modules and applies base mocks. Returns the two Redis instances. */
async function setupPubSub() {
  const instances: ReturnType<typeof createPubSubRedis>[] = []

  vi.resetModules()
  vi.doMock('@sim/logger', () => loggerMock)
  vi.doMock('@/lib/core/config/env', () => ({ env: { REDIS_URL: 'redis://localhost:6379' } }))
  vi.doMock('ioredis', () => ({
    default: vi.fn().mockImplementation(() => {
      const instance = createPubSubRedis()
      instances.push(instance)
      return instance
    }),
  }))

  const { mcpPubSub } = await import('./pubsub')
  const [pub, sub] = instances

  return { mcpPubSub, pub, sub, instances }
}

describe('RedisMcpPubSub', () => {
  it('creates two Redis clients (pub and sub)', async () => {
    const { mcpPubSub, instances } = await setupPubSub()

    expect(instances).toHaveLength(2)
    mcpPubSub.dispose()
  })

  it('registers error, connect, and message listeners', async () => {
    const { mcpPubSub, pub, sub } = await setupPubSub()

    const pubEvents = pub.on.mock.calls.map((c: unknown[]) => c[0])
    const subEvents = sub.on.mock.calls.map((c: unknown[]) => c[0])

    expect(pubEvents).toContain('error')
    expect(pubEvents).toContain('connect')
    expect(subEvents).toContain('error')
    expect(subEvents).toContain('connect')
    expect(subEvents).toContain('message')

    mcpPubSub.dispose()
  })

  describe('dispose', () => {
    it('calls removeAllListeners on both pub and sub before quit', async () => {
      const { mcpPubSub, pub, sub } = await setupPubSub()

      mcpPubSub.dispose()

      expect(pub.removeAllListeners).toHaveBeenCalledTimes(1)
      expect(sub.removeAllListeners).toHaveBeenCalledTimes(1)
      expect(sub.unsubscribe).toHaveBeenCalledTimes(1)
      expect(pub.quit).toHaveBeenCalledTimes(1)
      expect(sub.quit).toHaveBeenCalledTimes(1)
    })

    it('drops publish calls after dispose', async () => {
      const { mcpPubSub, pub } = await setupPubSub()

      mcpPubSub.dispose()
      pub.publish.mockClear()

      mcpPubSub.publishToolsChanged({
        serverId: 'srv-1',
        serverName: 'Test',
        workspaceId: 'ws-1',
        timestamp: Date.now(),
      })

      expect(pub.publish).not.toHaveBeenCalled()
    })
  })
})
apps/sim/lib/mcp/pubsub.ts (new file, 209 lines)
@@ -0,0 +1,209 @@
/**
 * MCP Pub/Sub Adapter
 *
 * Broadcasts MCP notification events across processes using Redis Pub/Sub.
 * Gracefully falls back to process-local EventEmitter when Redis is unavailable.
 *
 * Two channels:
 * - `mcp:tools_changed` — external MCP server sent a listChanged notification
 *   (published by connection manager, consumed by events SSE endpoint)
 * - `mcp:workflow_tools_changed` — workflow CRUD modified a workflow MCP server's tools
 *   (published by serve route, consumed by serve route on other processes to push to local SSE clients)
 */

import { EventEmitter } from 'events'
import { createLogger } from '@sim/logger'
import Redis from 'ioredis'
import { env } from '@/lib/core/config/env'
import type { ToolsChangedEvent } from '@/lib/mcp/types'

const logger = createLogger('McpPubSub')

const CHANNEL_TOOLS_CHANGED = 'mcp:tools_changed'
const CHANNEL_WORKFLOW_TOOLS_CHANGED = 'mcp:workflow_tools_changed'

export interface WorkflowToolsChangedEvent {
  serverId: string
  workspaceId: string
}

type ToolsChangedHandler = (event: ToolsChangedEvent) => void
type WorkflowToolsChangedHandler = (event: WorkflowToolsChangedEvent) => void

interface McpPubSubAdapter {
  publishToolsChanged(event: ToolsChangedEvent): void
  publishWorkflowToolsChanged(event: WorkflowToolsChangedEvent): void
  onToolsChanged(handler: ToolsChangedHandler): () => void
  onWorkflowToolsChanged(handler: WorkflowToolsChangedHandler): () => void
  dispose(): void
}

/**
 * Redis-backed pub/sub adapter.
 * Uses dedicated pub and sub clients (ioredis requires separate connections for subscribers).
 */
class RedisMcpPubSub implements McpPubSubAdapter {
  private pub: Redis
  private sub: Redis
  private toolsChangedHandlers = new Set<ToolsChangedHandler>()
  private workflowToolsChangedHandlers = new Set<WorkflowToolsChangedHandler>()
  private disposed = false

  constructor(redisUrl: string) {
    const commonOpts = {
      keepAlive: 1000,
      connectTimeout: 10000,
      maxRetriesPerRequest: null as unknown as number,
      enableOfflineQueue: true,
      retryStrategy: (times: number) => {
        if (times > 10) return 30000
        return Math.min(times * 500, 5000)
      },
    }

    this.pub = new Redis(redisUrl, { ...commonOpts, connectionName: 'mcp-pubsub-pub' })
    this.sub = new Redis(redisUrl, { ...commonOpts, connectionName: 'mcp-pubsub-sub' })

    this.pub.on('error', (err) => logger.error('MCP pub/sub publish client error:', err.message))
    this.sub.on('error', (err) => logger.error('MCP pub/sub subscribe client error:', err.message))
    this.pub.on('connect', () => logger.info('MCP pub/sub publish client connected'))
    this.sub.on('connect', () => logger.info('MCP pub/sub subscribe client connected'))

    this.sub.subscribe(CHANNEL_TOOLS_CHANGED, CHANNEL_WORKFLOW_TOOLS_CHANGED, (err) => {
      if (err) {
        logger.error('Failed to subscribe to MCP pub/sub channels:', err)
      } else {
        logger.info('Subscribed to MCP pub/sub channels')
      }
    })

    this.sub.on('message', (channel: string, message: string) => {
      try {
        const parsed = JSON.parse(message)
        if (channel === CHANNEL_TOOLS_CHANGED) {
          for (const handler of this.toolsChangedHandlers) {
            try {
              handler(parsed as ToolsChangedEvent)
            } catch (err) {
              logger.error('Error in tools_changed handler:', err)
            }
          }
        } else if (channel === CHANNEL_WORKFLOW_TOOLS_CHANGED) {
          for (const handler of this.workflowToolsChangedHandlers) {
            try {
              handler(parsed as WorkflowToolsChangedEvent)
            } catch (err) {
              logger.error('Error in workflow_tools_changed handler:', err)
            }
          }
        }
      } catch (err) {
        logger.error('Failed to parse pub/sub message:', err)
      }
    })
  }

  publishToolsChanged(event: ToolsChangedEvent): void {
    if (this.disposed) return
    this.pub.publish(CHANNEL_TOOLS_CHANGED, JSON.stringify(event)).catch((err) => {
      logger.error('Failed to publish tools_changed:', err)
    })
  }

  publishWorkflowToolsChanged(event: WorkflowToolsChangedEvent): void {
    if (this.disposed) return
    this.pub.publish(CHANNEL_WORKFLOW_TOOLS_CHANGED, JSON.stringify(event)).catch((err) => {
      logger.error('Failed to publish workflow_tools_changed:', err)
    })
  }

  onToolsChanged(handler: ToolsChangedHandler): () => void {
    this.toolsChangedHandlers.add(handler)
    return () => {
      this.toolsChangedHandlers.delete(handler)
    }
  }

  onWorkflowToolsChanged(handler: WorkflowToolsChangedHandler): () => void {
    this.workflowToolsChangedHandlers.add(handler)
    return () => {
      this.workflowToolsChangedHandlers.delete(handler)
    }
  }

  dispose(): void {
    this.disposed = true
    this.toolsChangedHandlers.clear()
    this.workflowToolsChangedHandlers.clear()

    this.pub.removeAllListeners()
    this.sub.removeAllListeners()

    this.sub.unsubscribe().catch(() => {})
    this.pub.quit().catch(() => {})
    this.sub.quit().catch(() => {})
    logger.info('Redis MCP pub/sub disposed')
  }
}

/**
 * Process-local fallback using EventEmitter.
 * Used when Redis is not configured — notifications only reach listeners in the same process.
 */
class LocalMcpPubSub implements McpPubSubAdapter {
  private emitter = new EventEmitter()

  constructor() {
    this.emitter.setMaxListeners(100)
    logger.info('MCP pub/sub: Using process-local EventEmitter (Redis not configured)')
  }

  publishToolsChanged(event: ToolsChangedEvent): void {
    this.emitter.emit(CHANNEL_TOOLS_CHANGED, event)
  }

  publishWorkflowToolsChanged(event: WorkflowToolsChangedEvent): void {
    this.emitter.emit(CHANNEL_WORKFLOW_TOOLS_CHANGED, event)
  }

  onToolsChanged(handler: ToolsChangedHandler): () => void {
    this.emitter.on(CHANNEL_TOOLS_CHANGED, handler)
    return () => {
      this.emitter.off(CHANNEL_TOOLS_CHANGED, handler)
    }
  }

  onWorkflowToolsChanged(handler: WorkflowToolsChangedHandler): () => void {
    this.emitter.on(CHANNEL_WORKFLOW_TOOLS_CHANGED, handler)
    return () => {
      this.emitter.off(CHANNEL_WORKFLOW_TOOLS_CHANGED, handler)
    }
  }

  dispose(): void {
    this.emitter.removeAllListeners()
    logger.info('Local MCP pub/sub disposed')
  }
}

/**
 * Create the appropriate pub/sub adapter based on Redis availability.
 */
function createMcpPubSub(): McpPubSubAdapter {
  const redisUrl = env.REDIS_URL

  if (redisUrl) {
    try {
      logger.info('MCP pub/sub: Using Redis')
      return new RedisMcpPubSub(redisUrl)
    } catch (err) {
      logger.error('Failed to create Redis pub/sub, falling back to local:', err)
      return new LocalMcpPubSub()
    }
  }

  return new LocalMcpPubSub()
}

export const mcpPubSub: McpPubSubAdapter =
  typeof window !== 'undefined' ? (null as unknown as McpPubSubAdapter) : createMcpPubSub()
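A usage sketch for the adapter; the publishing and subscribing sides would normally run in different processes, and the route names here are illustrative only.

import { mcpPubSub } from '@/lib/mcp/pubsub'

// Process A (e.g. a workflow CRUD route) announces that a workflow MCP server's tools changed
mcpPubSub.publishWorkflowToolsChanged({ serverId: 'workflow-server-1', workspaceId: 'ws-1' })

// Process B reacts to events published from any process (or the same one, via the local fallback)
const stop = mcpPubSub.onWorkflowToolsChanged(({ serverId, workspaceId }) => {
  // e.g. push the update to this process's local SSE clients or refresh caches
  console.log(`workflow MCP server ${serverId} in workspace ${workspaceId} changed its tools`)
})

// Unregister the handler when it is no longer needed
stop()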
apps/sim/lib/mcp/types.ts (modified)
@@ -147,6 +147,44 @@ export interface McpServerSummary
  error?: string
}

/**
 * Callback invoked when an MCP server sends a `notifications/tools/list_changed` notification.
 */
export type McpToolsChangedCallback = (serverId: string) => void

/**
 * Options for creating an McpClient with notification support.
 */
export interface McpClientOptions {
  config: McpServerConfig
  securityPolicy?: McpSecurityPolicy
  onToolsChanged?: McpToolsChangedCallback
}

/**
 * Event emitted by the connection manager when a server's tools change.
 */
export interface ToolsChangedEvent {
  serverId: string
  serverName: string
  workspaceId: string
  timestamp: number
}

/**
 * State of a managed persistent connection.
 */
export interface ManagedConnectionState {
  serverId: string
  serverName: string
  workspaceId: string
  userId: string
  connected: boolean
  supportsListChanged: boolean
  reconnectAttempts: number
  lastActivity: number
}

export interface McpApiResponse<T = unknown> {
  success: boolean
  data?: T