mirror of
https://github.com/simstudioai/sim.git
synced 2026-02-09 14:14:57 -05:00
Compare commits
5 Commits
feat/the-c
...
sim-609
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f844e2570f | ||
|
|
7c6927d643 | ||
|
|
cf25fb0843 | ||
|
|
4193007ab7 | ||
|
|
f9b885f6d5 |
98
apps/sim/app/api/mcp/events/route.test.ts
Normal file
98
apps/sim/app/api/mcp/events/route.test.ts
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
/**
|
||||||
|
* Tests for MCP SSE events endpoint
|
||||||
|
*
|
||||||
|
* @vitest-environment node
|
||||||
|
*/
|
||||||
|
import { createMockRequest, mockAuth, mockConsoleLogger } from '@sim/testing'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
mockConsoleLogger()
|
||||||
|
const auth = mockAuth()
|
||||||
|
|
||||||
|
const mockGetUserEntityPermissions = vi.fn()
|
||||||
|
vi.doMock('@/lib/workspaces/permissions/utils', () => ({
|
||||||
|
getUserEntityPermissions: mockGetUserEntityPermissions,
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.doMock('@/lib/mcp/connection-manager', () => ({
|
||||||
|
mcpConnectionManager: null,
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.doMock('@/lib/mcp/pubsub', () => ({
|
||||||
|
mcpPubSub: null,
|
||||||
|
}))
|
||||||
|
|
||||||
|
const { GET } = await import('./route')
|
||||||
|
|
||||||
|
describe('MCP Events SSE Endpoint', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns 401 when session is missing', async () => {
|
||||||
|
auth.setUnauthenticated()
|
||||||
|
|
||||||
|
const request = createMockRequest(
|
||||||
|
'GET',
|
||||||
|
undefined,
|
||||||
|
{},
|
||||||
|
'http://localhost:3000/api/mcp/events?workspaceId=ws-123'
|
||||||
|
)
|
||||||
|
|
||||||
|
const response = await GET(request as any)
|
||||||
|
|
||||||
|
expect(response.status).toBe(401)
|
||||||
|
const text = await response.text()
|
||||||
|
expect(text).toBe('Unauthorized')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns 400 when workspaceId is missing', async () => {
|
||||||
|
auth.setAuthenticated()
|
||||||
|
|
||||||
|
const request = createMockRequest('GET', undefined, {}, 'http://localhost:3000/api/mcp/events')
|
||||||
|
|
||||||
|
const response = await GET(request as any)
|
||||||
|
|
||||||
|
expect(response.status).toBe(400)
|
||||||
|
const text = await response.text()
|
||||||
|
expect(text).toBe('Missing workspaceId query parameter')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns 403 when user lacks workspace access', async () => {
|
||||||
|
auth.setAuthenticated()
|
||||||
|
mockGetUserEntityPermissions.mockResolvedValue(null)
|
||||||
|
|
||||||
|
const request = createMockRequest(
|
||||||
|
'GET',
|
||||||
|
undefined,
|
||||||
|
{},
|
||||||
|
'http://localhost:3000/api/mcp/events?workspaceId=ws-123'
|
||||||
|
)
|
||||||
|
|
||||||
|
const response = await GET(request as any)
|
||||||
|
|
||||||
|
expect(response.status).toBe(403)
|
||||||
|
const text = await response.text()
|
||||||
|
expect(text).toBe('Access denied to workspace')
|
||||||
|
expect(mockGetUserEntityPermissions).toHaveBeenCalledWith('user-123', 'workspace', 'ws-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns SSE stream when authorized', async () => {
|
||||||
|
auth.setAuthenticated()
|
||||||
|
mockGetUserEntityPermissions.mockResolvedValue({ read: true })
|
||||||
|
|
||||||
|
const request = createMockRequest(
|
||||||
|
'GET',
|
||||||
|
undefined,
|
||||||
|
{},
|
||||||
|
'http://localhost:3000/api/mcp/events?workspaceId=ws-123'
|
||||||
|
)
|
||||||
|
|
||||||
|
const response = await GET(request as any)
|
||||||
|
|
||||||
|
expect(response.status).toBe(200)
|
||||||
|
expect(response.headers.get('Content-Type')).toBe('text/event-stream')
|
||||||
|
expect(response.headers.get('Cache-Control')).toBe('no-cache')
|
||||||
|
expect(response.headers.get('Connection')).toBe('keep-alive')
|
||||||
|
})
|
||||||
|
})
|
||||||
111
apps/sim/app/api/mcp/events/route.ts
Normal file
111
apps/sim/app/api/mcp/events/route.ts
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
/**
|
||||||
|
* SSE endpoint for MCP tool-change events.
|
||||||
|
*
|
||||||
|
* Pushes `tools_changed` events to the browser when:
|
||||||
|
* - An external MCP server sends `notifications/tools/list_changed` (via connection manager)
|
||||||
|
* - A workflow CRUD route modifies workflow MCP server tools (via pub/sub)
|
||||||
|
*
|
||||||
|
* Auth is handled via session cookies (EventSource sends cookies automatically).
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { createLogger } from '@sim/logger'
|
||||||
|
import type { NextRequest } from 'next/server'
|
||||||
|
import { getSession } from '@/lib/auth'
|
||||||
|
import { SSE_HEADERS } from '@/lib/core/utils/sse'
|
||||||
|
import { mcpConnectionManager } from '@/lib/mcp/connection-manager'
|
||||||
|
import { mcpPubSub } from '@/lib/mcp/pubsub'
|
||||||
|
import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils'
|
||||||
|
|
||||||
|
const logger = createLogger('McpEventsSSE')
|
||||||
|
|
||||||
|
export const dynamic = 'force-dynamic'
|
||||||
|
|
||||||
|
const HEARTBEAT_INTERVAL_MS = 30_000
|
||||||
|
|
||||||
|
export async function GET(request: NextRequest) {
|
||||||
|
const session = await getSession()
|
||||||
|
if (!session?.user?.id) {
|
||||||
|
return new Response('Unauthorized', { status: 401 })
|
||||||
|
}
|
||||||
|
|
||||||
|
const { searchParams } = new URL(request.url)
|
||||||
|
const workspaceId = searchParams.get('workspaceId')
|
||||||
|
if (!workspaceId) {
|
||||||
|
return new Response('Missing workspaceId query parameter', { status: 400 })
|
||||||
|
}
|
||||||
|
|
||||||
|
const permissions = await getUserEntityPermissions(session.user.id, 'workspace', workspaceId)
|
||||||
|
if (!permissions) {
|
||||||
|
return new Response('Access denied to workspace', { status: 403 })
|
||||||
|
}
|
||||||
|
|
||||||
|
const encoder = new TextEncoder()
|
||||||
|
const unsubscribers: Array<() => void> = []
|
||||||
|
|
||||||
|
const stream = new ReadableStream({
|
||||||
|
start(controller) {
|
||||||
|
const send = (eventName: string, data: Record<string, unknown>) => {
|
||||||
|
try {
|
||||||
|
controller.enqueue(
|
||||||
|
encoder.encode(`event: ${eventName}\ndata: ${JSON.stringify(data)}\n\n`)
|
||||||
|
)
|
||||||
|
} catch {
|
||||||
|
// Stream already closed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to external MCP server tool changes
|
||||||
|
if (mcpConnectionManager) {
|
||||||
|
const unsub = mcpConnectionManager.subscribe((event) => {
|
||||||
|
if (event.workspaceId !== workspaceId) return
|
||||||
|
send('tools_changed', {
|
||||||
|
source: 'external',
|
||||||
|
serverId: event.serverId,
|
||||||
|
timestamp: event.timestamp,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
unsubscribers.push(unsub)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to workflow CRUD tool changes
|
||||||
|
if (mcpPubSub) {
|
||||||
|
const unsub = mcpPubSub.onWorkflowToolsChanged((event) => {
|
||||||
|
if (event.workspaceId !== workspaceId) return
|
||||||
|
send('tools_changed', {
|
||||||
|
source: 'workflow',
|
||||||
|
serverId: event.serverId,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
unsubscribers.push(unsub)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heartbeat to keep the connection alive
|
||||||
|
const heartbeat = setInterval(() => {
|
||||||
|
try {
|
||||||
|
controller.enqueue(encoder.encode(': heartbeat\n\n'))
|
||||||
|
} catch {
|
||||||
|
clearInterval(heartbeat)
|
||||||
|
}
|
||||||
|
}, HEARTBEAT_INTERVAL_MS)
|
||||||
|
unsubscribers.push(() => clearInterval(heartbeat))
|
||||||
|
|
||||||
|
// Cleanup when client disconnects
|
||||||
|
request.signal.addEventListener('abort', () => {
|
||||||
|
for (const unsub of unsubscribers) {
|
||||||
|
unsub()
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
controller.close()
|
||||||
|
} catch {
|
||||||
|
// Already closed
|
||||||
|
}
|
||||||
|
logger.info(`SSE connection closed for workspace ${workspaceId}`)
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(`SSE connection opened for workspace ${workspaceId}`)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return new Response(stream, { headers: SSE_HEADERS })
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import { createLogger } from '@sim/logger'
|
|||||||
import { and, eq } from 'drizzle-orm'
|
import { and, eq } from 'drizzle-orm'
|
||||||
import type { NextRequest } from 'next/server'
|
import type { NextRequest } from 'next/server'
|
||||||
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
||||||
|
import { mcpPubSub } from '@/lib/mcp/pubsub'
|
||||||
import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
|
import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
|
||||||
|
|
||||||
const logger = createLogger('WorkflowMcpServerAPI')
|
const logger = createLogger('WorkflowMcpServerAPI')
|
||||||
@@ -146,6 +147,8 @@ export const DELETE = withMcpAuth<RouteParams>('admin')(
|
|||||||
|
|
||||||
logger.info(`[${requestId}] Successfully deleted workflow MCP server: ${serverId}`)
|
logger.info(`[${requestId}] Successfully deleted workflow MCP server: ${serverId}`)
|
||||||
|
|
||||||
|
mcpPubSub?.publishWorkflowToolsChanged({ serverId, workspaceId })
|
||||||
|
|
||||||
return createMcpSuccessResponse({ message: `Server ${serverId} deleted successfully` })
|
return createMcpSuccessResponse({ message: `Server ${serverId} deleted successfully` })
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`[${requestId}] Error deleting workflow MCP server:`, error)
|
logger.error(`[${requestId}] Error deleting workflow MCP server:`, error)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import { createLogger } from '@sim/logger'
|
|||||||
import { and, eq } from 'drizzle-orm'
|
import { and, eq } from 'drizzle-orm'
|
||||||
import type { NextRequest } from 'next/server'
|
import type { NextRequest } from 'next/server'
|
||||||
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
||||||
|
import { mcpPubSub } from '@/lib/mcp/pubsub'
|
||||||
import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
|
import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
|
||||||
import { sanitizeToolName } from '@/lib/mcp/workflow-tool-schema'
|
import { sanitizeToolName } from '@/lib/mcp/workflow-tool-schema'
|
||||||
|
|
||||||
@@ -115,6 +116,8 @@ export const PATCH = withMcpAuth<RouteParams>('write')(
|
|||||||
|
|
||||||
logger.info(`[${requestId}] Successfully updated tool ${toolId}`)
|
logger.info(`[${requestId}] Successfully updated tool ${toolId}`)
|
||||||
|
|
||||||
|
mcpPubSub?.publishWorkflowToolsChanged({ serverId, workspaceId })
|
||||||
|
|
||||||
return createMcpSuccessResponse({ tool: updatedTool })
|
return createMcpSuccessResponse({ tool: updatedTool })
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`[${requestId}] Error updating tool:`, error)
|
logger.error(`[${requestId}] Error updating tool:`, error)
|
||||||
@@ -160,6 +163,8 @@ export const DELETE = withMcpAuth<RouteParams>('write')(
|
|||||||
|
|
||||||
logger.info(`[${requestId}] Successfully deleted tool ${toolId}`)
|
logger.info(`[${requestId}] Successfully deleted tool ${toolId}`)
|
||||||
|
|
||||||
|
mcpPubSub?.publishWorkflowToolsChanged({ serverId, workspaceId })
|
||||||
|
|
||||||
return createMcpSuccessResponse({ message: `Tool ${toolId} deleted successfully` })
|
return createMcpSuccessResponse({ message: `Tool ${toolId} deleted successfully` })
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`[${requestId}] Error deleting tool:`, error)
|
logger.error(`[${requestId}] Error deleting tool:`, error)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import { createLogger } from '@sim/logger'
|
|||||||
import { and, eq } from 'drizzle-orm'
|
import { and, eq } from 'drizzle-orm'
|
||||||
import type { NextRequest } from 'next/server'
|
import type { NextRequest } from 'next/server'
|
||||||
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
||||||
|
import { mcpPubSub } from '@/lib/mcp/pubsub'
|
||||||
import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
|
import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
|
||||||
import { sanitizeToolName } from '@/lib/mcp/workflow-tool-schema'
|
import { sanitizeToolName } from '@/lib/mcp/workflow-tool-schema'
|
||||||
import { hasValidStartBlock } from '@/lib/workflows/triggers/trigger-utils.server'
|
import { hasValidStartBlock } from '@/lib/workflows/triggers/trigger-utils.server'
|
||||||
@@ -188,6 +189,8 @@ export const POST = withMcpAuth<RouteParams>('write')(
|
|||||||
`[${requestId}] Successfully added tool ${toolName} (workflow: ${body.workflowId}) to server ${serverId}`
|
`[${requestId}] Successfully added tool ${toolName} (workflow: ${body.workflowId}) to server ${serverId}`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mcpPubSub?.publishWorkflowToolsChanged({ serverId, workspaceId })
|
||||||
|
|
||||||
return createMcpSuccessResponse({ tool }, 201)
|
return createMcpSuccessResponse({ tool }, 201)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`[${requestId}] Error adding tool:`, error)
|
logger.error(`[${requestId}] Error adding tool:`, error)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import { createLogger } from '@sim/logger'
|
|||||||
import { eq, inArray, sql } from 'drizzle-orm'
|
import { eq, inArray, sql } from 'drizzle-orm'
|
||||||
import type { NextRequest } from 'next/server'
|
import type { NextRequest } from 'next/server'
|
||||||
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
||||||
|
import { mcpPubSub } from '@/lib/mcp/pubsub'
|
||||||
import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
|
import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
|
||||||
import { sanitizeToolName } from '@/lib/mcp/workflow-tool-schema'
|
import { sanitizeToolName } from '@/lib/mcp/workflow-tool-schema'
|
||||||
import { hasValidStartBlock } from '@/lib/workflows/triggers/trigger-utils.server'
|
import { hasValidStartBlock } from '@/lib/workflows/triggers/trigger-utils.server'
|
||||||
@@ -174,6 +175,10 @@ export const POST = withMcpAuth('write')(
|
|||||||
`[${requestId}] Added ${addedTools.length} tools to server ${serverId}:`,
|
`[${requestId}] Added ${addedTools.length} tools to server ${serverId}:`,
|
||||||
addedTools.map((t) => t.toolName)
|
addedTools.map((t) => t.toolName)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (addedTools.length > 0) {
|
||||||
|
mcpPubSub?.publishWorkflowToolsChanged({ serverId, workspaceId })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ interface ApiDeployProps {
|
|||||||
deploymentInfo: WorkflowDeploymentInfo | null
|
deploymentInfo: WorkflowDeploymentInfo | null
|
||||||
isLoading: boolean
|
isLoading: boolean
|
||||||
needsRedeployment: boolean
|
needsRedeployment: boolean
|
||||||
apiDeployError: string | null
|
|
||||||
getInputFormatExample: (includeStreaming?: boolean) => string
|
getInputFormatExample: (includeStreaming?: boolean) => string
|
||||||
selectedStreamingOutputs: string[]
|
selectedStreamingOutputs: string[]
|
||||||
onSelectedStreamingOutputsChange: (outputs: string[]) => void
|
onSelectedStreamingOutputsChange: (outputs: string[]) => void
|
||||||
@@ -63,7 +62,6 @@ export function ApiDeploy({
|
|||||||
deploymentInfo,
|
deploymentInfo,
|
||||||
isLoading,
|
isLoading,
|
||||||
needsRedeployment,
|
needsRedeployment,
|
||||||
apiDeployError,
|
|
||||||
getInputFormatExample,
|
getInputFormatExample,
|
||||||
selectedStreamingOutputs,
|
selectedStreamingOutputs,
|
||||||
onSelectedStreamingOutputsChange,
|
onSelectedStreamingOutputsChange,
|
||||||
@@ -419,12 +417,6 @@ console.log(limits);`
|
|||||||
if (isLoading || !info) {
|
if (isLoading || !info) {
|
||||||
return (
|
return (
|
||||||
<div className='space-y-[16px]'>
|
<div className='space-y-[16px]'>
|
||||||
{apiDeployError && (
|
|
||||||
<div className='rounded-[4px] border border-destructive/30 bg-destructive/10 p-3 text-destructive text-sm'>
|
|
||||||
<div className='font-semibold'>API Deployment Error</div>
|
|
||||||
<div>{apiDeployError}</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
<div>
|
<div>
|
||||||
<Skeleton className='mb-[6.5px] h-[16px] w-[62px]' />
|
<Skeleton className='mb-[6.5px] h-[16px] w-[62px]' />
|
||||||
<Skeleton className='h-[28px] w-[260px] rounded-[4px]' />
|
<Skeleton className='h-[28px] w-[260px] rounded-[4px]' />
|
||||||
@@ -443,13 +435,6 @@ console.log(limits);`
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className='space-y-[16px]'>
|
<div className='space-y-[16px]'>
|
||||||
{apiDeployError && (
|
|
||||||
<div className='rounded-[4px] border border-destructive/30 bg-destructive/10 p-3 text-destructive text-sm'>
|
|
||||||
<div className='font-semibold'>API Deployment Error</div>
|
|
||||||
<div>{apiDeployError}</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
<div className='mb-[6.5px] flex items-center justify-between'>
|
<div className='mb-[6.5px] flex items-center justify-between'>
|
||||||
<Label className='block pl-[2px] font-medium text-[13px] text-[var(--text-primary)]'>
|
<Label className='block pl-[2px] font-medium text-[13px] text-[var(--text-primary)]'>
|
||||||
|
|||||||
@@ -94,8 +94,8 @@ export function DeployModal({
|
|||||||
const workflowWorkspaceId = workflowMetadata?.workspaceId ?? null
|
const workflowWorkspaceId = workflowMetadata?.workspaceId ?? null
|
||||||
const [activeTab, setActiveTab] = useState<TabView>('general')
|
const [activeTab, setActiveTab] = useState<TabView>('general')
|
||||||
const [chatSubmitting, setChatSubmitting] = useState(false)
|
const [chatSubmitting, setChatSubmitting] = useState(false)
|
||||||
const [apiDeployError, setApiDeployError] = useState<string | null>(null)
|
const [deployError, setDeployError] = useState<string | null>(null)
|
||||||
const [apiDeployWarnings, setApiDeployWarnings] = useState<string[]>([])
|
const [deployWarnings, setDeployWarnings] = useState<string[]>([])
|
||||||
const [isChatFormValid, setIsChatFormValid] = useState(false)
|
const [isChatFormValid, setIsChatFormValid] = useState(false)
|
||||||
const [selectedStreamingOutputs, setSelectedStreamingOutputs] = useState<string[]>([])
|
const [selectedStreamingOutputs, setSelectedStreamingOutputs] = useState<string[]>([])
|
||||||
|
|
||||||
@@ -225,8 +225,8 @@ export function DeployModal({
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (open && workflowId) {
|
if (open && workflowId) {
|
||||||
setActiveTab('general')
|
setActiveTab('general')
|
||||||
setApiDeployError(null)
|
setDeployError(null)
|
||||||
setApiDeployWarnings([])
|
setDeployWarnings([])
|
||||||
}
|
}
|
||||||
}, [open, workflowId])
|
}, [open, workflowId])
|
||||||
|
|
||||||
@@ -281,19 +281,19 @@ export function DeployModal({
|
|||||||
const onDeploy = useCallback(async () => {
|
const onDeploy = useCallback(async () => {
|
||||||
if (!workflowId) return
|
if (!workflowId) return
|
||||||
|
|
||||||
setApiDeployError(null)
|
setDeployError(null)
|
||||||
setApiDeployWarnings([])
|
setDeployWarnings([])
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await deployMutation.mutateAsync({ workflowId, deployChatEnabled: false })
|
const result = await deployMutation.mutateAsync({ workflowId, deployChatEnabled: false })
|
||||||
if (result.warnings && result.warnings.length > 0) {
|
if (result.warnings && result.warnings.length > 0) {
|
||||||
setApiDeployWarnings(result.warnings)
|
setDeployWarnings(result.warnings)
|
||||||
}
|
}
|
||||||
await refetchDeployedState()
|
await refetchDeployedState()
|
||||||
} catch (error: unknown) {
|
} catch (error: unknown) {
|
||||||
logger.error('Error deploying workflow:', { error })
|
logger.error('Error deploying workflow:', { error })
|
||||||
const errorMessage = error instanceof Error ? error.message : 'Failed to deploy workflow'
|
const errorMessage = error instanceof Error ? error.message : 'Failed to deploy workflow'
|
||||||
setApiDeployError(errorMessage)
|
setDeployError(errorMessage)
|
||||||
}
|
}
|
||||||
}, [workflowId, deployMutation, refetchDeployedState])
|
}, [workflowId, deployMutation, refetchDeployedState])
|
||||||
|
|
||||||
@@ -301,12 +301,12 @@ export function DeployModal({
|
|||||||
async (version: number) => {
|
async (version: number) => {
|
||||||
if (!workflowId) return
|
if (!workflowId) return
|
||||||
|
|
||||||
setApiDeployWarnings([])
|
setDeployWarnings([])
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await activateVersionMutation.mutateAsync({ workflowId, version })
|
const result = await activateVersionMutation.mutateAsync({ workflowId, version })
|
||||||
if (result.warnings && result.warnings.length > 0) {
|
if (result.warnings && result.warnings.length > 0) {
|
||||||
setApiDeployWarnings(result.warnings)
|
setDeployWarnings(result.warnings)
|
||||||
}
|
}
|
||||||
await refetchDeployedState()
|
await refetchDeployedState()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -332,26 +332,26 @@ export function DeployModal({
|
|||||||
const handleRedeploy = useCallback(async () => {
|
const handleRedeploy = useCallback(async () => {
|
||||||
if (!workflowId) return
|
if (!workflowId) return
|
||||||
|
|
||||||
setApiDeployError(null)
|
setDeployError(null)
|
||||||
setApiDeployWarnings([])
|
setDeployWarnings([])
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await deployMutation.mutateAsync({ workflowId, deployChatEnabled: false })
|
const result = await deployMutation.mutateAsync({ workflowId, deployChatEnabled: false })
|
||||||
if (result.warnings && result.warnings.length > 0) {
|
if (result.warnings && result.warnings.length > 0) {
|
||||||
setApiDeployWarnings(result.warnings)
|
setDeployWarnings(result.warnings)
|
||||||
}
|
}
|
||||||
await refetchDeployedState()
|
await refetchDeployedState()
|
||||||
} catch (error: unknown) {
|
} catch (error: unknown) {
|
||||||
logger.error('Error redeploying workflow:', { error })
|
logger.error('Error redeploying workflow:', { error })
|
||||||
const errorMessage = error instanceof Error ? error.message : 'Failed to redeploy workflow'
|
const errorMessage = error instanceof Error ? error.message : 'Failed to redeploy workflow'
|
||||||
setApiDeployError(errorMessage)
|
setDeployError(errorMessage)
|
||||||
}
|
}
|
||||||
}, [workflowId, deployMutation, refetchDeployedState])
|
}, [workflowId, deployMutation, refetchDeployedState])
|
||||||
|
|
||||||
const handleCloseModal = useCallback(() => {
|
const handleCloseModal = useCallback(() => {
|
||||||
setChatSubmitting(false)
|
setChatSubmitting(false)
|
||||||
setApiDeployError(null)
|
setDeployError(null)
|
||||||
setApiDeployWarnings([])
|
setDeployWarnings([])
|
||||||
onOpenChange(false)
|
onOpenChange(false)
|
||||||
}, [onOpenChange])
|
}, [onOpenChange])
|
||||||
|
|
||||||
@@ -483,17 +483,23 @@ export function DeployModal({
|
|||||||
</ModalTabsList>
|
</ModalTabsList>
|
||||||
|
|
||||||
<ModalBody className='min-h-0 flex-1'>
|
<ModalBody className='min-h-0 flex-1'>
|
||||||
{apiDeployError && (
|
{(deployError || deployWarnings.length > 0) && (
|
||||||
<div className='mb-3 rounded-[4px] border border-destructive/30 bg-destructive/10 p-3 text-destructive text-sm'>
|
<div className='mb-3 flex flex-col gap-2'>
|
||||||
<div className='font-semibold'>Deployment Error</div>
|
{deployError && (
|
||||||
<div>{apiDeployError}</div>
|
<Badge variant='red' size='lg' dot className='max-w-full truncate'>
|
||||||
</div>
|
{deployError}
|
||||||
)}
|
</Badge>
|
||||||
{apiDeployWarnings.length > 0 && (
|
)}
|
||||||
<div className='mb-3 rounded-[4px] border border-amber-500/30 bg-amber-500/10 p-3 text-amber-700 text-sm dark:text-amber-400'>
|
{deployWarnings.map((warning, index) => (
|
||||||
<div className='font-semibold'>Deployment Warning</div>
|
<Badge
|
||||||
{apiDeployWarnings.map((warning, index) => (
|
key={index}
|
||||||
<div key={index}>{warning}</div>
|
variant='amber'
|
||||||
|
size='lg'
|
||||||
|
dot
|
||||||
|
className='max-w-full truncate'
|
||||||
|
>
|
||||||
|
{warning}
|
||||||
|
</Badge>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
@@ -515,7 +521,6 @@ export function DeployModal({
|
|||||||
deploymentInfo={deploymentInfo}
|
deploymentInfo={deploymentInfo}
|
||||||
isLoading={isLoadingDeploymentInfo}
|
isLoading={isLoadingDeploymentInfo}
|
||||||
needsRedeployment={needsRedeployment}
|
needsRedeployment={needsRedeployment}
|
||||||
apiDeployError={apiDeployError}
|
|
||||||
getInputFormatExample={getInputFormatExample}
|
getInputFormatExample={getInputFormatExample}
|
||||||
selectedStreamingOutputs={selectedStreamingOutputs}
|
selectedStreamingOutputs={selectedStreamingOutputs}
|
||||||
onSelectedStreamingOutputsChange={setSelectedStreamingOutputs}
|
onSelectedStreamingOutputsChange={setSelectedStreamingOutputs}
|
||||||
|
|||||||
@@ -62,7 +62,12 @@ import {
|
|||||||
type CustomTool as CustomToolDefinition,
|
type CustomTool as CustomToolDefinition,
|
||||||
useCustomTools,
|
useCustomTools,
|
||||||
} from '@/hooks/queries/custom-tools'
|
} from '@/hooks/queries/custom-tools'
|
||||||
import { useForceRefreshMcpTools, useMcpServers, useStoredMcpTools } from '@/hooks/queries/mcp'
|
import {
|
||||||
|
useForceRefreshMcpTools,
|
||||||
|
useMcpServers,
|
||||||
|
useMcpToolsEvents,
|
||||||
|
useStoredMcpTools,
|
||||||
|
} from '@/hooks/queries/mcp'
|
||||||
import {
|
import {
|
||||||
useChildDeploymentStatus,
|
useChildDeploymentStatus,
|
||||||
useDeployChildWorkflow,
|
useDeployChildWorkflow,
|
||||||
@@ -1035,6 +1040,7 @@ export const ToolInput = memo(function ToolInput({
|
|||||||
const { data: mcpServers = [], isLoading: mcpServersLoading } = useMcpServers(workspaceId)
|
const { data: mcpServers = [], isLoading: mcpServersLoading } = useMcpServers(workspaceId)
|
||||||
const { data: storedMcpTools = [] } = useStoredMcpTools(workspaceId)
|
const { data: storedMcpTools = [] } = useStoredMcpTools(workspaceId)
|
||||||
const forceRefreshMcpTools = useForceRefreshMcpTools()
|
const forceRefreshMcpTools = useForceRefreshMcpTools()
|
||||||
|
useMcpToolsEvents(workspaceId)
|
||||||
const openSettingsModal = useSettingsModalStore((state) => state.openModal)
|
const openSettingsModal = useSettingsModalStore((state) => state.openModal)
|
||||||
const mcpDataLoading = mcpLoading || mcpServersLoading
|
const mcpDataLoading = mcpLoading || mcpServersLoading
|
||||||
|
|
||||||
|
|||||||
@@ -1151,7 +1151,7 @@ export const Terminal = memo(function Terminal() {
|
|||||||
<aside
|
<aside
|
||||||
ref={terminalRef}
|
ref={terminalRef}
|
||||||
className={clsx(
|
className={clsx(
|
||||||
'terminal-container fixed right-[var(--panel-width)] bottom-0 left-[var(--sidebar-width)] z-10 overflow-hidden bg-[var(--surface-1)]',
|
'terminal-container fixed right-[var(--panel-width)] bottom-0 left-[var(--sidebar-width)] z-10 overflow-hidden border-[var(--border)] border-t bg-[var(--surface-1)]',
|
||||||
isToggling && 'transition-[height] duration-100 ease-out'
|
isToggling && 'transition-[height] duration-100 ease-out'
|
||||||
)}
|
)}
|
||||||
onTransitionEnd={handleTransitionEnd}
|
onTransitionEnd={handleTransitionEnd}
|
||||||
@@ -1160,7 +1160,7 @@ export const Terminal = memo(function Terminal() {
|
|||||||
tabIndex={-1}
|
tabIndex={-1}
|
||||||
aria-label='Terminal'
|
aria-label='Terminal'
|
||||||
>
|
>
|
||||||
<div className='relative flex h-full border-[var(--border)] border-t'>
|
<div className='relative flex h-full'>
|
||||||
{/* Left Section - Logs */}
|
{/* Left Section - Logs */}
|
||||||
<div
|
<div
|
||||||
className={clsx('flex flex-col', !selectedEntry && 'flex-1')}
|
className={clsx('flex flex-col', !selectedEntry && 'flex-1')}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import { useEffect } from 'react'
|
||||||
import { createLogger } from '@sim/logger'
|
import { createLogger } from '@sim/logger'
|
||||||
import { keepPreviousData, useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
|
import { keepPreviousData, useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
|
||||||
import { sanitizeForHttp, sanitizeHeaders } from '@/lib/mcp/shared'
|
import { sanitizeForHttp, sanitizeHeaders } from '@/lib/mcp/shared'
|
||||||
@@ -359,3 +360,65 @@ export function useStoredMcpTools(workspaceId: string) {
|
|||||||
staleTime: 60 * 1000,
|
staleTime: 60 * 1000,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shared EventSource connections keyed by workspaceId.
|
||||||
|
* Reference-counted so the connection is closed when the last consumer unmounts.
|
||||||
|
* Attached to `globalThis` so connections survive HMR in development.
|
||||||
|
*/
|
||||||
|
const SSE_KEY = '__mcp_sse_connections' as const
|
||||||
|
|
||||||
|
type SseEntry = { source: EventSource; refs: number }
|
||||||
|
|
||||||
|
const sseConnections: Map<string, SseEntry> =
|
||||||
|
((globalThis as Record<string, unknown>)[SSE_KEY] as Map<string, SseEntry>) ??
|
||||||
|
((globalThis as Record<string, unknown>)[SSE_KEY] = new Map<string, SseEntry>())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subscribe to MCP tool-change SSE events for a workspace.
|
||||||
|
* On each `tools_changed` event, invalidates the relevant React Query caches
|
||||||
|
* so the UI refreshes automatically.
|
||||||
|
*/
|
||||||
|
export function useMcpToolsEvents(workspaceId: string) {
|
||||||
|
const queryClient = useQueryClient()
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!workspaceId) return
|
||||||
|
|
||||||
|
const invalidate = () => {
|
||||||
|
queryClient.invalidateQueries({ queryKey: mcpKeys.tools(workspaceId) })
|
||||||
|
queryClient.invalidateQueries({ queryKey: mcpKeys.servers(workspaceId) })
|
||||||
|
queryClient.invalidateQueries({ queryKey: mcpKeys.storedTools(workspaceId) })
|
||||||
|
}
|
||||||
|
|
||||||
|
let entry = sseConnections.get(workspaceId)
|
||||||
|
|
||||||
|
if (!entry) {
|
||||||
|
const source = new EventSource(`/api/mcp/events?workspaceId=${workspaceId}`)
|
||||||
|
|
||||||
|
source.addEventListener('tools_changed', () => {
|
||||||
|
invalidate()
|
||||||
|
})
|
||||||
|
|
||||||
|
source.onerror = () => {
|
||||||
|
logger.warn(`SSE connection error for workspace ${workspaceId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry = { source, refs: 0 }
|
||||||
|
sseConnections.set(workspaceId, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.refs++
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
const current = sseConnections.get(workspaceId)
|
||||||
|
if (!current) return
|
||||||
|
|
||||||
|
current.refs--
|
||||||
|
if (current.refs <= 0) {
|
||||||
|
current.source.close()
|
||||||
|
sseConnections.delete(workspaceId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [workspaceId, queryClient])
|
||||||
|
}
|
||||||
|
|||||||
109
apps/sim/lib/mcp/client.test.ts
Normal file
109
apps/sim/lib/mcp/client.test.ts
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
/**
|
||||||
|
* @vitest-environment node
|
||||||
|
*/
|
||||||
|
import { loggerMock } from '@sim/testing'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
vi.mock('@sim/logger', () => loggerMock)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Capture the notification handler registered via `client.setNotificationHandler()`.
|
||||||
|
* This lets us simulate the MCP SDK delivering a `tools/list_changed` notification.
|
||||||
|
*/
|
||||||
|
let capturedNotificationHandler: (() => Promise<void>) | null = null
|
||||||
|
|
||||||
|
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
|
||||||
|
Client: vi.fn().mockImplementation(() => ({
|
||||||
|
connect: vi.fn().mockResolvedValue(undefined),
|
||||||
|
close: vi.fn().mockResolvedValue(undefined),
|
||||||
|
getServerVersion: vi.fn().mockReturnValue('2025-06-18'),
|
||||||
|
getServerCapabilities: vi.fn().mockReturnValue({ tools: { listChanged: true } }),
|
||||||
|
setNotificationHandler: vi
|
||||||
|
.fn()
|
||||||
|
.mockImplementation((_schema: unknown, handler: () => Promise<void>) => {
|
||||||
|
capturedNotificationHandler = handler
|
||||||
|
}),
|
||||||
|
listTools: vi.fn().mockResolvedValue({ tools: [] }),
|
||||||
|
})),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({
|
||||||
|
StreamableHTTPClientTransport: vi.fn().mockImplementation(() => ({
|
||||||
|
onclose: null,
|
||||||
|
sessionId: 'test-session',
|
||||||
|
})),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@modelcontextprotocol/sdk/types.js', () => ({
|
||||||
|
ToolListChangedNotificationSchema: { method: 'notifications/tools/list_changed' },
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/lib/core/execution-limits', () => ({
|
||||||
|
getMaxExecutionTimeout: vi.fn().mockReturnValue(30000),
|
||||||
|
}))
|
||||||
|
|
||||||
|
import { McpClient } from './client'
|
||||||
|
import type { McpServerConfig } from './types'
|
||||||
|
|
||||||
|
function createConfig(): McpServerConfig {
|
||||||
|
return {
|
||||||
|
id: 'server-1',
|
||||||
|
name: 'Test Server',
|
||||||
|
transport: 'streamable-http',
|
||||||
|
url: 'https://test.example.com/mcp',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('McpClient notification handler', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
capturedNotificationHandler = null
|
||||||
|
})
|
||||||
|
|
||||||
|
it('fires onToolsChanged when a notification arrives while connected', async () => {
|
||||||
|
const onToolsChanged = vi.fn()
|
||||||
|
|
||||||
|
const client = new McpClient({
|
||||||
|
config: createConfig(),
|
||||||
|
securityPolicy: { requireConsent: false, auditLevel: 'basic' },
|
||||||
|
onToolsChanged,
|
||||||
|
})
|
||||||
|
|
||||||
|
await client.connect()
|
||||||
|
|
||||||
|
expect(capturedNotificationHandler).not.toBeNull()
|
||||||
|
|
||||||
|
await capturedNotificationHandler!()
|
||||||
|
|
||||||
|
expect(onToolsChanged).toHaveBeenCalledTimes(1)
|
||||||
|
expect(onToolsChanged).toHaveBeenCalledWith('server-1')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('suppresses notifications after disconnect', async () => {
|
||||||
|
const onToolsChanged = vi.fn()
|
||||||
|
|
||||||
|
const client = new McpClient({
|
||||||
|
config: createConfig(),
|
||||||
|
securityPolicy: { requireConsent: false, auditLevel: 'basic' },
|
||||||
|
onToolsChanged,
|
||||||
|
})
|
||||||
|
|
||||||
|
await client.connect()
|
||||||
|
expect(capturedNotificationHandler).not.toBeNull()
|
||||||
|
|
||||||
|
await client.disconnect()
|
||||||
|
await capturedNotificationHandler!()
|
||||||
|
|
||||||
|
expect(onToolsChanged).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('does not register a notification handler when onToolsChanged is not provided', async () => {
|
||||||
|
const client = new McpClient({
|
||||||
|
config: createConfig(),
|
||||||
|
securityPolicy: { requireConsent: false, auditLevel: 'basic' },
|
||||||
|
})
|
||||||
|
|
||||||
|
await client.connect()
|
||||||
|
|
||||||
|
expect(capturedNotificationHandler).toBeNull()
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -10,10 +10,15 @@
|
|||||||
|
|
||||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||||
import type { ListToolsResult, Tool } from '@modelcontextprotocol/sdk/types.js'
|
import {
|
||||||
|
type ListToolsResult,
|
||||||
|
type Tool,
|
||||||
|
ToolListChangedNotificationSchema,
|
||||||
|
} from '@modelcontextprotocol/sdk/types.js'
|
||||||
import { createLogger } from '@sim/logger'
|
import { createLogger } from '@sim/logger'
|
||||||
import { getMaxExecutionTimeout } from '@/lib/core/execution-limits'
|
import { getMaxExecutionTimeout } from '@/lib/core/execution-limits'
|
||||||
import {
|
import {
|
||||||
|
type McpClientOptions,
|
||||||
McpConnectionError,
|
McpConnectionError,
|
||||||
type McpConnectionStatus,
|
type McpConnectionStatus,
|
||||||
type McpConsentRequest,
|
type McpConsentRequest,
|
||||||
@@ -24,6 +29,7 @@ import {
|
|||||||
type McpTool,
|
type McpTool,
|
||||||
type McpToolCall,
|
type McpToolCall,
|
||||||
type McpToolResult,
|
type McpToolResult,
|
||||||
|
type McpToolsChangedCallback,
|
||||||
type McpVersionInfo,
|
type McpVersionInfo,
|
||||||
} from '@/lib/mcp/types'
|
} from '@/lib/mcp/types'
|
||||||
|
|
||||||
@@ -35,6 +41,7 @@ export class McpClient {
|
|||||||
private config: McpServerConfig
|
private config: McpServerConfig
|
||||||
private connectionStatus: McpConnectionStatus
|
private connectionStatus: McpConnectionStatus
|
||||||
private securityPolicy: McpSecurityPolicy
|
private securityPolicy: McpSecurityPolicy
|
||||||
|
private onToolsChanged?: McpToolsChangedCallback
|
||||||
private isConnected = false
|
private isConnected = false
|
||||||
|
|
||||||
private static readonly SUPPORTED_VERSIONS = [
|
private static readonly SUPPORTED_VERSIONS = [
|
||||||
@@ -44,23 +51,36 @@ export class McpClient {
|
|||||||
]
|
]
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new MCP client
|
* Creates a new MCP client.
|
||||||
*
|
*
|
||||||
* No session ID parameter (we disconnect after each operation).
|
* Accepts either the legacy (config, securityPolicy?) signature
|
||||||
* The SDK handles session management automatically via Mcp-Session-Id header.
|
* or a single McpClientOptions object with an optional onToolsChanged callback.
|
||||||
*
|
|
||||||
* @param config - Server configuration
|
|
||||||
* @param securityPolicy - Optional security policy
|
|
||||||
*/
|
*/
|
||||||
constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy) {
|
constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy)
|
||||||
this.config = config
|
constructor(options: McpClientOptions)
|
||||||
this.connectionStatus = { connected: false }
|
constructor(
|
||||||
this.securityPolicy = securityPolicy ?? {
|
configOrOptions: McpServerConfig | McpClientOptions,
|
||||||
requireConsent: true,
|
securityPolicy?: McpSecurityPolicy
|
||||||
auditLevel: 'basic',
|
) {
|
||||||
maxToolExecutionsPerHour: 1000,
|
if ('config' in configOrOptions) {
|
||||||
|
this.config = configOrOptions.config
|
||||||
|
this.securityPolicy = configOrOptions.securityPolicy ?? {
|
||||||
|
requireConsent: true,
|
||||||
|
auditLevel: 'basic',
|
||||||
|
maxToolExecutionsPerHour: 1000,
|
||||||
|
}
|
||||||
|
this.onToolsChanged = configOrOptions.onToolsChanged
|
||||||
|
} else {
|
||||||
|
this.config = configOrOptions
|
||||||
|
this.securityPolicy = securityPolicy ?? {
|
||||||
|
requireConsent: true,
|
||||||
|
auditLevel: 'basic',
|
||||||
|
maxToolExecutionsPerHour: 1000,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.connectionStatus = { connected: false }
|
||||||
|
|
||||||
if (!this.config.url) {
|
if (!this.config.url) {
|
||||||
throw new McpError('URL required for Streamable HTTP transport')
|
throw new McpError('URL required for Streamable HTTP transport')
|
||||||
}
|
}
|
||||||
@@ -79,16 +99,15 @@ export class McpClient {
|
|||||||
{
|
{
|
||||||
capabilities: {
|
capabilities: {
|
||||||
tools: {},
|
tools: {},
|
||||||
// Resources and prompts can be added later
|
|
||||||
// resources: {},
|
|
||||||
// prompts: {},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize connection to MCP server
|
* Initialize connection to MCP server.
|
||||||
|
* If an `onToolsChanged` callback was provided, registers a notification handler
|
||||||
|
* for `notifications/tools/list_changed` after connecting.
|
||||||
*/
|
*/
|
||||||
async connect(): Promise<void> {
|
async connect(): Promise<void> {
|
||||||
logger.info(`Connecting to MCP server: ${this.config.name} (${this.config.transport})`)
|
logger.info(`Connecting to MCP server: ${this.config.name} (${this.config.transport})`)
|
||||||
@@ -100,6 +119,15 @@ export class McpClient {
|
|||||||
this.connectionStatus.connected = true
|
this.connectionStatus.connected = true
|
||||||
this.connectionStatus.lastConnected = new Date()
|
this.connectionStatus.lastConnected = new Date()
|
||||||
|
|
||||||
|
if (this.onToolsChanged) {
|
||||||
|
this.client.setNotificationHandler(ToolListChangedNotificationSchema, async () => {
|
||||||
|
if (!this.isConnected) return
|
||||||
|
logger.info(`[${this.config.name}] Received tools/list_changed notification`)
|
||||||
|
this.onToolsChanged?.(this.config.id)
|
||||||
|
})
|
||||||
|
logger.info(`[${this.config.name}] Registered tools/list_changed notification handler`)
|
||||||
|
}
|
||||||
|
|
||||||
const serverVersion = this.client.getServerVersion()
|
const serverVersion = this.client.getServerVersion()
|
||||||
logger.info(`Successfully connected to MCP server: ${this.config.name}`, {
|
logger.info(`Successfully connected to MCP server: ${this.config.name}`, {
|
||||||
protocolVersion: serverVersion,
|
protocolVersion: serverVersion,
|
||||||
@@ -241,6 +269,23 @@ export class McpClient {
|
|||||||
return !!serverCapabilities?.[capability]
|
return !!serverCapabilities?.[capability]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if the server declared `capabilities.tools.listChanged: true` during initialization.
|
||||||
|
*/
|
||||||
|
hasListChangedCapability(): boolean {
|
||||||
|
const caps = this.client.getServerCapabilities()
|
||||||
|
const toolsCap = caps?.tools as Record<string, unknown> | undefined
|
||||||
|
return !!toolsCap?.listChanged
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register a callback to be invoked when the underlying transport closes.
|
||||||
|
* Used by the connection manager for reconnection logic.
|
||||||
|
*/
|
||||||
|
onClose(callback: () => void): void {
|
||||||
|
this.transport.onclose = callback
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get server configuration
|
* Get server configuration
|
||||||
*/
|
*/
|
||||||
|
|||||||
180
apps/sim/lib/mcp/connection-manager.test.ts
Normal file
180
apps/sim/lib/mcp/connection-manager.test.ts
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
/**
|
||||||
|
* @vitest-environment node
|
||||||
|
*/
|
||||||
|
import { loggerMock } from '@sim/testing'
|
||||||
|
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
interface MockMcpClient {
|
||||||
|
connect: ReturnType<typeof vi.fn>
|
||||||
|
disconnect: ReturnType<typeof vi.fn>
|
||||||
|
hasListChangedCapability: ReturnType<typeof vi.fn>
|
||||||
|
onClose: ReturnType<typeof vi.fn>
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Deferred promise to control when `client.connect()` resolves. */
|
||||||
|
function createDeferred<T = void>() {
|
||||||
|
let resolve!: (value: T) => void
|
||||||
|
const promise = new Promise<T>((res) => {
|
||||||
|
resolve = res
|
||||||
|
})
|
||||||
|
return { promise, resolve }
|
||||||
|
}
|
||||||
|
|
||||||
|
function serverConfig(id: string, name = `Server ${id}`) {
|
||||||
|
return {
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
transport: 'streamable-http' as const,
|
||||||
|
url: `https://${id}.example.com/mcp`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Shared setup: resets modules and applies base mocks. */
|
||||||
|
function setupBaseMocks() {
|
||||||
|
vi.resetModules()
|
||||||
|
vi.doMock('@sim/logger', () => loggerMock)
|
||||||
|
vi.doMock('@/lib/core/config/feature-flags', () => ({ isTest: false }))
|
||||||
|
vi.doMock('@/lib/mcp/pubsub', () => ({
|
||||||
|
mcpPubSub: { onToolsChanged: vi.fn(() => vi.fn()), publishToolsChanged: vi.fn() },
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('McpConnectionManager', () => {
|
||||||
|
let manager: {
|
||||||
|
connect: (...args: unknown[]) => Promise<{ supportsListChanged: boolean }>
|
||||||
|
dispose: () => void
|
||||||
|
} | null = null
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
manager?.dispose()
|
||||||
|
manager = null
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('concurrent connect() guard', () => {
|
||||||
|
it('creates only one client when two connect() calls race for the same serverId', async () => {
|
||||||
|
setupBaseMocks()
|
||||||
|
|
||||||
|
const deferred = createDeferred()
|
||||||
|
const instances: MockMcpClient[] = []
|
||||||
|
|
||||||
|
vi.doMock('./client', () => ({
|
||||||
|
McpClient: vi.fn().mockImplementation(() => {
|
||||||
|
const instance: MockMcpClient = {
|
||||||
|
connect: vi.fn().mockImplementation(() => deferred.promise),
|
||||||
|
disconnect: vi.fn().mockResolvedValue(undefined),
|
||||||
|
hasListChangedCapability: vi.fn().mockReturnValue(true),
|
||||||
|
onClose: vi.fn(),
|
||||||
|
}
|
||||||
|
instances.push(instance)
|
||||||
|
return instance
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const { mcpConnectionManager: mgr } = await import('./connection-manager')
|
||||||
|
manager = mgr
|
||||||
|
|
||||||
|
const config = serverConfig('server-1')
|
||||||
|
|
||||||
|
const p1 = mgr.connect(config, 'user-1', 'ws-1')
|
||||||
|
const p2 = mgr.connect(config, 'user-1', 'ws-1')
|
||||||
|
|
||||||
|
deferred.resolve()
|
||||||
|
const [r1, r2] = await Promise.all([p1, p2])
|
||||||
|
|
||||||
|
expect(instances).toHaveLength(1)
|
||||||
|
expect(r1.supportsListChanged).toBe(true)
|
||||||
|
expect(r2.supportsListChanged).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('allows a new connect() after a previous one completes', async () => {
|
||||||
|
setupBaseMocks()
|
||||||
|
|
||||||
|
const instances: MockMcpClient[] = []
|
||||||
|
|
||||||
|
vi.doMock('./client', () => ({
|
||||||
|
McpClient: vi.fn().mockImplementation(() => {
|
||||||
|
const instance: MockMcpClient = {
|
||||||
|
connect: vi.fn().mockResolvedValue(undefined),
|
||||||
|
disconnect: vi.fn().mockResolvedValue(undefined),
|
||||||
|
hasListChangedCapability: vi.fn().mockReturnValue(false),
|
||||||
|
onClose: vi.fn(),
|
||||||
|
}
|
||||||
|
instances.push(instance)
|
||||||
|
return instance
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const { mcpConnectionManager: mgr } = await import('./connection-manager')
|
||||||
|
manager = mgr
|
||||||
|
|
||||||
|
const config = serverConfig('server-2')
|
||||||
|
|
||||||
|
const r1 = await mgr.connect(config, 'user-1', 'ws-1')
|
||||||
|
expect(r1.supportsListChanged).toBe(false)
|
||||||
|
|
||||||
|
const r2 = await mgr.connect(config, 'user-1', 'ws-1')
|
||||||
|
expect(r2.supportsListChanged).toBe(false)
|
||||||
|
|
||||||
|
expect(instances).toHaveLength(2)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('cleans up connectingServers when connect() throws', async () => {
|
||||||
|
setupBaseMocks()
|
||||||
|
|
||||||
|
let callCount = 0
|
||||||
|
const instances: MockMcpClient[] = []
|
||||||
|
|
||||||
|
vi.doMock('./client', () => ({
|
||||||
|
McpClient: vi.fn().mockImplementation(() => {
|
||||||
|
callCount++
|
||||||
|
const instance: MockMcpClient = {
|
||||||
|
connect:
|
||||||
|
callCount === 1
|
||||||
|
? vi.fn().mockRejectedValue(new Error('Connection refused'))
|
||||||
|
: vi.fn().mockResolvedValue(undefined),
|
||||||
|
disconnect: vi.fn().mockResolvedValue(undefined),
|
||||||
|
hasListChangedCapability: vi.fn().mockReturnValue(true),
|
||||||
|
onClose: vi.fn(),
|
||||||
|
}
|
||||||
|
instances.push(instance)
|
||||||
|
return instance
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const { mcpConnectionManager: mgr } = await import('./connection-manager')
|
||||||
|
manager = mgr
|
||||||
|
|
||||||
|
const config = serverConfig('server-3')
|
||||||
|
|
||||||
|
const r1 = await mgr.connect(config, 'user-1', 'ws-1')
|
||||||
|
expect(r1.supportsListChanged).toBe(false)
|
||||||
|
|
||||||
|
const r2 = await mgr.connect(config, 'user-1', 'ws-1')
|
||||||
|
expect(r2.supportsListChanged).toBe(true)
|
||||||
|
expect(instances).toHaveLength(2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('dispose', () => {
|
||||||
|
it('rejects new connections after dispose', async () => {
|
||||||
|
setupBaseMocks()
|
||||||
|
|
||||||
|
vi.doMock('./client', () => ({
|
||||||
|
McpClient: vi.fn().mockImplementation(() => ({
|
||||||
|
connect: vi.fn().mockResolvedValue(undefined),
|
||||||
|
disconnect: vi.fn().mockResolvedValue(undefined),
|
||||||
|
hasListChangedCapability: vi.fn().mockReturnValue(true),
|
||||||
|
onClose: vi.fn(),
|
||||||
|
})),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const { mcpConnectionManager: mgr } = await import('./connection-manager')
|
||||||
|
manager = mgr
|
||||||
|
|
||||||
|
mgr.dispose()
|
||||||
|
|
||||||
|
const result = await mgr.connect(serverConfig('server-4'), 'user-1', 'ws-1')
|
||||||
|
expect(result.supportsListChanged).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
361
apps/sim/lib/mcp/connection-manager.ts
Normal file
361
apps/sim/lib/mcp/connection-manager.ts
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
/**
|
||||||
|
* MCP Connection Manager
|
||||||
|
*
|
||||||
|
* Maintains persistent connections to MCP servers that support
|
||||||
|
* `notifications/tools/list_changed`. When a notification arrives,
|
||||||
|
* the manager invalidates the tools cache and emits a ToolsChangedEvent
|
||||||
|
* so the frontend SSE endpoint can push updates to browsers.
|
||||||
|
*
|
||||||
|
* Servers that do not support `listChanged` fall back to the existing
|
||||||
|
* stale-time cache approach — no persistent connection is kept.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { createLogger } from '@sim/logger'
|
||||||
|
import { isTest } from '@/lib/core/config/feature-flags'
|
||||||
|
import { McpClient } from '@/lib/mcp/client'
|
||||||
|
import { mcpPubSub } from '@/lib/mcp/pubsub'
|
||||||
|
import type {
|
||||||
|
ManagedConnectionState,
|
||||||
|
McpServerConfig,
|
||||||
|
McpToolsChangedCallback,
|
||||||
|
ToolsChangedEvent,
|
||||||
|
} from '@/lib/mcp/types'
|
||||||
|
|
||||||
|
const logger = createLogger('McpConnectionManager')
|
||||||
|
|
||||||
|
const MAX_CONNECTIONS = 50
|
||||||
|
const MAX_RECONNECT_ATTEMPTS = 10
|
||||||
|
const BASE_RECONNECT_DELAY_MS = 1000
|
||||||
|
const IDLE_TIMEOUT_MS = 30 * 60 * 1000 // 30 minutes
|
||||||
|
const IDLE_CHECK_INTERVAL_MS = 5 * 60 * 1000 // 5 minutes
|
||||||
|
|
||||||
|
type ToolsChangedListener = (event: ToolsChangedEvent) => void
|
||||||
|
|
||||||
|
class McpConnectionManager {
|
||||||
|
private connections = new Map<string, McpClient>()
|
||||||
|
private states = new Map<string, ManagedConnectionState>()
|
||||||
|
private reconnectTimers = new Map<string, ReturnType<typeof setTimeout>>()
|
||||||
|
private listeners = new Set<ToolsChangedListener>()
|
||||||
|
private connectingServers = new Set<string>()
|
||||||
|
private idleCheckTimer: ReturnType<typeof setInterval> | null = null
|
||||||
|
private disposed = false
|
||||||
|
private unsubscribePubSub?: () => void
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
if (mcpPubSub) {
|
||||||
|
this.unsubscribePubSub = mcpPubSub.onToolsChanged((event) => {
|
||||||
|
this.notifyLocalListeners(event)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subscribe to tools-changed events from any managed connection.
|
||||||
|
* Returns an unsubscribe function.
|
||||||
|
*/
|
||||||
|
subscribe(listener: ToolsChangedListener): () => void {
|
||||||
|
this.listeners.add(listener)
|
||||||
|
return () => {
|
||||||
|
this.listeners.delete(listener)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Establish a persistent connection to an MCP server.
|
||||||
|
* If the server supports `listChanged`, the connection is kept alive
|
||||||
|
* and notifications are forwarded to subscribers.
|
||||||
|
*
|
||||||
|
* If the server does NOT support `listChanged`, the client is disconnected
|
||||||
|
* immediately — there's nothing to listen for.
|
||||||
|
*/
|
||||||
|
async connect(
|
||||||
|
config: McpServerConfig,
|
||||||
|
userId: string,
|
||||||
|
workspaceId: string
|
||||||
|
): Promise<{ supportsListChanged: boolean }> {
|
||||||
|
if (this.disposed) {
|
||||||
|
logger.warn('Connection manager is disposed, ignoring connect request')
|
||||||
|
return { supportsListChanged: false }
|
||||||
|
}
|
||||||
|
|
||||||
|
const serverId = config.id
|
||||||
|
|
||||||
|
if (this.connections.has(serverId) || this.connectingServers.has(serverId)) {
|
||||||
|
logger.info(`[${config.name}] Already has a managed connection or is connecting, skipping`)
|
||||||
|
const state = this.states.get(serverId)
|
||||||
|
return { supportsListChanged: state?.supportsListChanged ?? false }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.connections.size >= MAX_CONNECTIONS) {
|
||||||
|
logger.warn(`Max connections (${MAX_CONNECTIONS}) reached, cannot connect to ${config.name}`)
|
||||||
|
return { supportsListChanged: false }
|
||||||
|
}
|
||||||
|
|
||||||
|
this.connectingServers.add(serverId)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const onToolsChanged: McpToolsChangedCallback = (sid) => {
|
||||||
|
this.handleToolsChanged(sid)
|
||||||
|
}
|
||||||
|
|
||||||
|
const client = new McpClient({
|
||||||
|
config,
|
||||||
|
securityPolicy: {
|
||||||
|
requireConsent: false,
|
||||||
|
auditLevel: 'basic',
|
||||||
|
maxToolExecutionsPerHour: 1000,
|
||||||
|
},
|
||||||
|
onToolsChanged,
|
||||||
|
})
|
||||||
|
|
||||||
|
try {
|
||||||
|
await client.connect()
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`[${config.name}] Failed to connect for persistent monitoring:`, error)
|
||||||
|
return { supportsListChanged: false }
|
||||||
|
}
|
||||||
|
|
||||||
|
const supportsListChanged = client.hasListChangedCapability()
|
||||||
|
|
||||||
|
if (!supportsListChanged) {
|
||||||
|
logger.info(
|
||||||
|
`[${config.name}] Server does not support listChanged — disconnecting (fallback to cache)`
|
||||||
|
)
|
||||||
|
await client.disconnect()
|
||||||
|
return { supportsListChanged: false }
|
||||||
|
}
|
||||||
|
|
||||||
|
this.connections.set(serverId, client)
|
||||||
|
this.states.set(serverId, {
|
||||||
|
serverId,
|
||||||
|
serverName: config.name,
|
||||||
|
workspaceId,
|
||||||
|
userId,
|
||||||
|
connected: true,
|
||||||
|
supportsListChanged: true,
|
||||||
|
reconnectAttempts: 0,
|
||||||
|
lastActivity: Date.now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
client.onClose(() => {
|
||||||
|
this.handleDisconnect(config, userId, workspaceId)
|
||||||
|
})
|
||||||
|
|
||||||
|
this.ensureIdleCheck()
|
||||||
|
|
||||||
|
logger.info(`[${config.name}] Persistent connection established (listChanged supported)`)
|
||||||
|
return { supportsListChanged: true }
|
||||||
|
} finally {
|
||||||
|
this.connectingServers.delete(serverId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Disconnect a managed connection.
|
||||||
|
*/
|
||||||
|
async disconnect(serverId: string): Promise<void> {
|
||||||
|
this.clearReconnectTimer(serverId)
|
||||||
|
|
||||||
|
const client = this.connections.get(serverId)
|
||||||
|
if (client) {
|
||||||
|
try {
|
||||||
|
await client.disconnect()
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn(`Error disconnecting managed client ${serverId}:`, error)
|
||||||
|
}
|
||||||
|
this.connections.delete(serverId)
|
||||||
|
}
|
||||||
|
|
||||||
|
this.states.delete(serverId)
|
||||||
|
logger.info(`Managed connection removed: ${serverId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check whether a managed connection exists for the given server.
|
||||||
|
*/
|
||||||
|
hasConnection(serverId: string): boolean {
|
||||||
|
return this.connections.has(serverId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get connection state for a server.
|
||||||
|
*/
|
||||||
|
getState(serverId: string): ManagedConnectionState | undefined {
|
||||||
|
return this.states.get(serverId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all managed connection states (for diagnostics).
|
||||||
|
*/
|
||||||
|
getAllStates(): ManagedConnectionState[] {
|
||||||
|
return [...this.states.values()]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dispose all connections and timers.
|
||||||
|
*/
|
||||||
|
dispose(): void {
|
||||||
|
this.disposed = true
|
||||||
|
|
||||||
|
this.unsubscribePubSub?.()
|
||||||
|
|
||||||
|
for (const timer of this.reconnectTimers.values()) {
|
||||||
|
clearTimeout(timer)
|
||||||
|
}
|
||||||
|
this.reconnectTimers.clear()
|
||||||
|
|
||||||
|
if (this.idleCheckTimer) {
|
||||||
|
clearInterval(this.idleCheckTimer)
|
||||||
|
this.idleCheckTimer = null
|
||||||
|
}
|
||||||
|
|
||||||
|
const disconnects = [...this.connections.entries()].map(async ([id, client]) => {
|
||||||
|
try {
|
||||||
|
await client.disconnect()
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn(`Error disconnecting ${id} during dispose:`, error)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
Promise.allSettled(disconnects).then(() => {
|
||||||
|
logger.info('Connection manager disposed')
|
||||||
|
})
|
||||||
|
|
||||||
|
this.connections.clear()
|
||||||
|
this.states.clear()
|
||||||
|
this.listeners.clear()
|
||||||
|
this.connectingServers.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Notify only process-local listeners.
|
||||||
|
* Called by the pub/sub subscription (receives events from all processes).
|
||||||
|
*/
|
||||||
|
private notifyLocalListeners(event: ToolsChangedEvent): void {
|
||||||
|
for (const listener of this.listeners) {
|
||||||
|
try {
|
||||||
|
listener(event)
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error in tools-changed listener:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle a tools/list_changed notification from an external MCP server.
|
||||||
|
* Publishes to pub/sub so all processes are notified.
|
||||||
|
*/
|
||||||
|
private handleToolsChanged(serverId: string): void {
|
||||||
|
const state = this.states.get(serverId)
|
||||||
|
if (!state) return
|
||||||
|
|
||||||
|
state.lastActivity = Date.now()
|
||||||
|
|
||||||
|
const event: ToolsChangedEvent = {
|
||||||
|
serverId,
|
||||||
|
serverName: state.serverName,
|
||||||
|
workspaceId: state.workspaceId,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(`[${state.serverName}] Tools changed — publishing to pub/sub`)
|
||||||
|
|
||||||
|
mcpPubSub?.publishToolsChanged(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
private handleDisconnect(config: McpServerConfig, userId: string, workspaceId: string): void {
|
||||||
|
const serverId = config.id
|
||||||
|
const state = this.states.get(serverId)
|
||||||
|
|
||||||
|
if (!state || this.disposed) return
|
||||||
|
|
||||||
|
state.connected = false
|
||||||
|
this.connections.delete(serverId)
|
||||||
|
|
||||||
|
logger.warn(`[${config.name}] Persistent connection lost, scheduling reconnect`)
|
||||||
|
|
||||||
|
this.scheduleReconnect(config, userId, workspaceId)
|
||||||
|
}
|
||||||
|
|
||||||
|
private scheduleReconnect(config: McpServerConfig, userId: string, workspaceId: string): void {
|
||||||
|
const serverId = config.id
|
||||||
|
const state = this.states.get(serverId)
|
||||||
|
|
||||||
|
if (!state || this.disposed) return
|
||||||
|
|
||||||
|
if (state.reconnectAttempts >= MAX_RECONNECT_ATTEMPTS) {
|
||||||
|
logger.error(
|
||||||
|
`[${config.name}] Max reconnect attempts (${MAX_RECONNECT_ATTEMPTS}) reached — giving up`
|
||||||
|
)
|
||||||
|
this.states.delete(serverId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const delay = Math.min(BASE_RECONNECT_DELAY_MS * 2 ** state.reconnectAttempts, 60_000)
|
||||||
|
state.reconnectAttempts++
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
`[${config.name}] Reconnecting in ${delay}ms (attempt ${state.reconnectAttempts}/${MAX_RECONNECT_ATTEMPTS})`
|
||||||
|
)
|
||||||
|
|
||||||
|
this.clearReconnectTimer(serverId)
|
||||||
|
|
||||||
|
const timer = setTimeout(async () => {
|
||||||
|
this.reconnectTimers.delete(serverId)
|
||||||
|
|
||||||
|
if (this.disposed) return
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.connections.delete(serverId)
|
||||||
|
this.states.delete(serverId)
|
||||||
|
|
||||||
|
const result = await this.connect(config, userId, workspaceId)
|
||||||
|
if (result.supportsListChanged) {
|
||||||
|
const newState = this.states.get(serverId)
|
||||||
|
if (newState) {
|
||||||
|
newState.reconnectAttempts = 0
|
||||||
|
}
|
||||||
|
logger.info(`[${config.name}] Reconnected successfully`)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`[${config.name}] Reconnect failed:`, error)
|
||||||
|
this.scheduleReconnect(config, userId, workspaceId)
|
||||||
|
}
|
||||||
|
}, delay)
|
||||||
|
|
||||||
|
this.reconnectTimers.set(serverId, timer)
|
||||||
|
}
|
||||||
|
|
||||||
|
private clearReconnectTimer(serverId: string): void {
|
||||||
|
const timer = this.reconnectTimers.get(serverId)
|
||||||
|
if (timer) {
|
||||||
|
clearTimeout(timer)
|
||||||
|
this.reconnectTimers.delete(serverId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private ensureIdleCheck(): void {
|
||||||
|
if (this.idleCheckTimer) return
|
||||||
|
|
||||||
|
this.idleCheckTimer = setInterval(() => {
|
||||||
|
const now = Date.now()
|
||||||
|
for (const [serverId, state] of this.states) {
|
||||||
|
if (now - state.lastActivity > IDLE_TIMEOUT_MS) {
|
||||||
|
logger.info(
|
||||||
|
`[${state.serverName}] Idle timeout reached, disconnecting managed connection`
|
||||||
|
)
|
||||||
|
this.disconnect(serverId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.states.size === 0 && this.idleCheckTimer) {
|
||||||
|
clearInterval(this.idleCheckTimer)
|
||||||
|
this.idleCheckTimer = null
|
||||||
|
}
|
||||||
|
}, IDLE_CHECK_INTERVAL_MS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const mcpConnectionManager = isTest
|
||||||
|
? (null as unknown as McpConnectionManager)
|
||||||
|
: new McpConnectionManager()
|
||||||
93
apps/sim/lib/mcp/pubsub.test.ts
Normal file
93
apps/sim/lib/mcp/pubsub.test.ts
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
/**
|
||||||
|
* @vitest-environment node
|
||||||
|
*/
|
||||||
|
import { createMockRedis, loggerMock, type MockRedis } from '@sim/testing'
|
||||||
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
/** Extend the @sim/testing Redis mock with the methods RedisMcpPubSub uses. */
|
||||||
|
function createPubSubRedis(): MockRedis & { removeAllListeners: ReturnType<typeof vi.fn> } {
|
||||||
|
const mock = createMockRedis()
|
||||||
|
// ioredis subscribe invokes a callback as the last argument
|
||||||
|
mock.subscribe.mockImplementation((...args: unknown[]) => {
|
||||||
|
const cb = args[args.length - 1]
|
||||||
|
if (typeof cb === 'function') (cb as (err: null) => void)(null)
|
||||||
|
})
|
||||||
|
// on() returns `this` for chaining in ioredis
|
||||||
|
mock.on.mockReturnThis()
|
||||||
|
return { ...mock, removeAllListeners: vi.fn().mockReturnThis() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Shared setup: resets modules and applies base mocks. Returns the two Redis instances. */
|
||||||
|
async function setupPubSub() {
|
||||||
|
const instances: ReturnType<typeof createPubSubRedis>[] = []
|
||||||
|
|
||||||
|
vi.resetModules()
|
||||||
|
vi.doMock('@sim/logger', () => loggerMock)
|
||||||
|
vi.doMock('@/lib/core/config/env', () => ({ env: { REDIS_URL: 'redis://localhost:6379' } }))
|
||||||
|
vi.doMock('ioredis', () => ({
|
||||||
|
default: vi.fn().mockImplementation(() => {
|
||||||
|
const instance = createPubSubRedis()
|
||||||
|
instances.push(instance)
|
||||||
|
return instance
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const { mcpPubSub } = await import('./pubsub')
|
||||||
|
const [pub, sub] = instances
|
||||||
|
|
||||||
|
return { mcpPubSub, pub, sub, instances }
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('RedisMcpPubSub', () => {
|
||||||
|
it('creates two Redis clients (pub and sub)', async () => {
|
||||||
|
const { mcpPubSub, instances } = await setupPubSub()
|
||||||
|
|
||||||
|
expect(instances).toHaveLength(2)
|
||||||
|
mcpPubSub.dispose()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('registers error, connect, and message listeners', async () => {
|
||||||
|
const { mcpPubSub, pub, sub } = await setupPubSub()
|
||||||
|
|
||||||
|
const pubEvents = pub.on.mock.calls.map((c: unknown[]) => c[0])
|
||||||
|
const subEvents = sub.on.mock.calls.map((c: unknown[]) => c[0])
|
||||||
|
|
||||||
|
expect(pubEvents).toContain('error')
|
||||||
|
expect(pubEvents).toContain('connect')
|
||||||
|
expect(subEvents).toContain('error')
|
||||||
|
expect(subEvents).toContain('connect')
|
||||||
|
expect(subEvents).toContain('message')
|
||||||
|
|
||||||
|
mcpPubSub.dispose()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('dispose', () => {
|
||||||
|
it('calls removeAllListeners on both pub and sub before quit', async () => {
|
||||||
|
const { mcpPubSub, pub, sub } = await setupPubSub()
|
||||||
|
|
||||||
|
mcpPubSub.dispose()
|
||||||
|
|
||||||
|
expect(pub.removeAllListeners).toHaveBeenCalledTimes(1)
|
||||||
|
expect(sub.removeAllListeners).toHaveBeenCalledTimes(1)
|
||||||
|
expect(sub.unsubscribe).toHaveBeenCalledTimes(1)
|
||||||
|
expect(pub.quit).toHaveBeenCalledTimes(1)
|
||||||
|
expect(sub.quit).toHaveBeenCalledTimes(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('drops publish calls after dispose', async () => {
|
||||||
|
const { mcpPubSub, pub } = await setupPubSub()
|
||||||
|
|
||||||
|
mcpPubSub.dispose()
|
||||||
|
pub.publish.mockClear()
|
||||||
|
|
||||||
|
mcpPubSub.publishToolsChanged({
|
||||||
|
serverId: 'srv-1',
|
||||||
|
serverName: 'Test',
|
||||||
|
workspaceId: 'ws-1',
|
||||||
|
timestamp: Date.now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(pub.publish).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
204
apps/sim/lib/mcp/pubsub.ts
Normal file
204
apps/sim/lib/mcp/pubsub.ts
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
/**
|
||||||
|
* MCP Pub/Sub Adapter
|
||||||
|
*
|
||||||
|
* Broadcasts MCP notification events across processes using Redis Pub/Sub.
|
||||||
|
* Gracefully falls back to process-local EventEmitter when Redis is unavailable.
|
||||||
|
*
|
||||||
|
* Two channels:
|
||||||
|
* - `mcp:tools_changed` — external MCP server sent a listChanged notification
|
||||||
|
* (published by connection manager, consumed by events SSE endpoint)
|
||||||
|
* - `mcp:workflow_tools_changed` — workflow CRUD modified a workflow MCP server's tools
|
||||||
|
* (published by serve route, consumed by serve route on other processes to push to local SSE clients)
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { EventEmitter } from 'events'
|
||||||
|
import { createLogger } from '@sim/logger'
|
||||||
|
import Redis from 'ioredis'
|
||||||
|
import { env } from '@/lib/core/config/env'
|
||||||
|
import type { ToolsChangedEvent, WorkflowToolsChangedEvent } from '@/lib/mcp/types'
|
||||||
|
|
||||||
|
const logger = createLogger('McpPubSub')
|
||||||
|
|
||||||
|
const CHANNEL_TOOLS_CHANGED = 'mcp:tools_changed'
|
||||||
|
const CHANNEL_WORKFLOW_TOOLS_CHANGED = 'mcp:workflow_tools_changed'
|
||||||
|
|
||||||
|
type ToolsChangedHandler = (event: ToolsChangedEvent) => void
|
||||||
|
type WorkflowToolsChangedHandler = (event: WorkflowToolsChangedEvent) => void
|
||||||
|
|
||||||
|
interface McpPubSubAdapter {
|
||||||
|
publishToolsChanged(event: ToolsChangedEvent): void
|
||||||
|
publishWorkflowToolsChanged(event: WorkflowToolsChangedEvent): void
|
||||||
|
onToolsChanged(handler: ToolsChangedHandler): () => void
|
||||||
|
onWorkflowToolsChanged(handler: WorkflowToolsChangedHandler): () => void
|
||||||
|
dispose(): void
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Redis-backed pub/sub adapter.
|
||||||
|
* Uses dedicated pub and sub clients (ioredis requires separate connections for subscribers).
|
||||||
|
*/
|
||||||
|
class RedisMcpPubSub implements McpPubSubAdapter {
|
||||||
|
private pub: Redis
|
||||||
|
private sub: Redis
|
||||||
|
private toolsChangedHandlers = new Set<ToolsChangedHandler>()
|
||||||
|
private workflowToolsChangedHandlers = new Set<WorkflowToolsChangedHandler>()
|
||||||
|
private disposed = false
|
||||||
|
|
||||||
|
constructor(redisUrl: string) {
|
||||||
|
const commonOpts = {
|
||||||
|
keepAlive: 1000,
|
||||||
|
connectTimeout: 10000,
|
||||||
|
maxRetriesPerRequest: null as unknown as number,
|
||||||
|
enableOfflineQueue: true,
|
||||||
|
retryStrategy: (times: number) => {
|
||||||
|
if (times > 10) return 30000
|
||||||
|
return Math.min(times * 500, 5000)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
this.pub = new Redis(redisUrl, { ...commonOpts, connectionName: 'mcp-pubsub-pub' })
|
||||||
|
this.sub = new Redis(redisUrl, { ...commonOpts, connectionName: 'mcp-pubsub-sub' })
|
||||||
|
|
||||||
|
this.pub.on('error', (err) => logger.error('MCP pub/sub publish client error:', err.message))
|
||||||
|
this.sub.on('error', (err) => logger.error('MCP pub/sub subscribe client error:', err.message))
|
||||||
|
this.pub.on('connect', () => logger.info('MCP pub/sub publish client connected'))
|
||||||
|
this.sub.on('connect', () => logger.info('MCP pub/sub subscribe client connected'))
|
||||||
|
|
||||||
|
this.sub.subscribe(CHANNEL_TOOLS_CHANGED, CHANNEL_WORKFLOW_TOOLS_CHANGED, (err) => {
|
||||||
|
if (err) {
|
||||||
|
logger.error('Failed to subscribe to MCP pub/sub channels:', err)
|
||||||
|
} else {
|
||||||
|
logger.info('Subscribed to MCP pub/sub channels')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
this.sub.on('message', (channel: string, message: string) => {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(message)
|
||||||
|
if (channel === CHANNEL_TOOLS_CHANGED) {
|
||||||
|
for (const handler of this.toolsChangedHandlers) {
|
||||||
|
try {
|
||||||
|
handler(parsed as ToolsChangedEvent)
|
||||||
|
} catch (err) {
|
||||||
|
logger.error('Error in tools_changed handler:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (channel === CHANNEL_WORKFLOW_TOOLS_CHANGED) {
|
||||||
|
for (const handler of this.workflowToolsChangedHandlers) {
|
||||||
|
try {
|
||||||
|
handler(parsed as WorkflowToolsChangedEvent)
|
||||||
|
} catch (err) {
|
||||||
|
logger.error('Error in workflow_tools_changed handler:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
logger.error('Failed to parse pub/sub message:', err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
publishToolsChanged(event: ToolsChangedEvent): void {
|
||||||
|
if (this.disposed) return
|
||||||
|
this.pub.publish(CHANNEL_TOOLS_CHANGED, JSON.stringify(event)).catch((err) => {
|
||||||
|
logger.error('Failed to publish tools_changed:', err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
publishWorkflowToolsChanged(event: WorkflowToolsChangedEvent): void {
|
||||||
|
if (this.disposed) return
|
||||||
|
this.pub.publish(CHANNEL_WORKFLOW_TOOLS_CHANGED, JSON.stringify(event)).catch((err) => {
|
||||||
|
logger.error('Failed to publish workflow_tools_changed:', err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
onToolsChanged(handler: ToolsChangedHandler): () => void {
|
||||||
|
this.toolsChangedHandlers.add(handler)
|
||||||
|
return () => {
|
||||||
|
this.toolsChangedHandlers.delete(handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onWorkflowToolsChanged(handler: WorkflowToolsChangedHandler): () => void {
|
||||||
|
this.workflowToolsChangedHandlers.add(handler)
|
||||||
|
return () => {
|
||||||
|
this.workflowToolsChangedHandlers.delete(handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dispose(): void {
|
||||||
|
this.disposed = true
|
||||||
|
this.toolsChangedHandlers.clear()
|
||||||
|
this.workflowToolsChangedHandlers.clear()
|
||||||
|
|
||||||
|
this.pub.removeAllListeners()
|
||||||
|
this.sub.removeAllListeners()
|
||||||
|
|
||||||
|
this.sub.unsubscribe().catch(() => {})
|
||||||
|
this.pub.quit().catch(() => {})
|
||||||
|
this.sub.quit().catch(() => {})
|
||||||
|
logger.info('Redis MCP pub/sub disposed')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Process-local fallback using EventEmitter.
|
||||||
|
* Used when Redis is not configured — notifications only reach listeners in the same process.
|
||||||
|
*/
|
||||||
|
class LocalMcpPubSub implements McpPubSubAdapter {
|
||||||
|
private emitter = new EventEmitter()
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
this.emitter.setMaxListeners(100)
|
||||||
|
logger.info('MCP pub/sub: Using process-local EventEmitter (Redis not configured)')
|
||||||
|
}
|
||||||
|
|
||||||
|
publishToolsChanged(event: ToolsChangedEvent): void {
|
||||||
|
this.emitter.emit(CHANNEL_TOOLS_CHANGED, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
publishWorkflowToolsChanged(event: WorkflowToolsChangedEvent): void {
|
||||||
|
this.emitter.emit(CHANNEL_WORKFLOW_TOOLS_CHANGED, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
onToolsChanged(handler: ToolsChangedHandler): () => void {
|
||||||
|
this.emitter.on(CHANNEL_TOOLS_CHANGED, handler)
|
||||||
|
return () => {
|
||||||
|
this.emitter.off(CHANNEL_TOOLS_CHANGED, handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onWorkflowToolsChanged(handler: WorkflowToolsChangedHandler): () => void {
|
||||||
|
this.emitter.on(CHANNEL_WORKFLOW_TOOLS_CHANGED, handler)
|
||||||
|
return () => {
|
||||||
|
this.emitter.off(CHANNEL_WORKFLOW_TOOLS_CHANGED, handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dispose(): void {
|
||||||
|
this.emitter.removeAllListeners()
|
||||||
|
logger.info('Local MCP pub/sub disposed')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create the appropriate pub/sub adapter based on Redis availability.
|
||||||
|
*/
|
||||||
|
function createMcpPubSub(): McpPubSubAdapter {
|
||||||
|
const redisUrl = env.REDIS_URL
|
||||||
|
|
||||||
|
if (redisUrl) {
|
||||||
|
try {
|
||||||
|
logger.info('MCP pub/sub: Using Redis')
|
||||||
|
return new RedisMcpPubSub(redisUrl)
|
||||||
|
} catch (err) {
|
||||||
|
logger.error('Failed to create Redis pub/sub, falling back to local:', err)
|
||||||
|
return new LocalMcpPubSub()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new LocalMcpPubSub()
|
||||||
|
}
|
||||||
|
|
||||||
|
export const mcpPubSub: McpPubSubAdapter =
|
||||||
|
typeof window !== 'undefined' ? (null as unknown as McpPubSubAdapter) : createMcpPubSub()
|
||||||
@@ -9,6 +9,7 @@ import { and, eq, isNull } from 'drizzle-orm'
|
|||||||
import { isTest } from '@/lib/core/config/feature-flags'
|
import { isTest } from '@/lib/core/config/feature-flags'
|
||||||
import { generateRequestId } from '@/lib/core/utils/request'
|
import { generateRequestId } from '@/lib/core/utils/request'
|
||||||
import { McpClient } from '@/lib/mcp/client'
|
import { McpClient } from '@/lib/mcp/client'
|
||||||
|
import { mcpConnectionManager } from '@/lib/mcp/connection-manager'
|
||||||
import { resolveMcpConfigEnvVars } from '@/lib/mcp/resolve-config'
|
import { resolveMcpConfigEnvVars } from '@/lib/mcp/resolve-config'
|
||||||
import {
|
import {
|
||||||
createMcpCacheAdapter,
|
createMcpCacheAdapter,
|
||||||
@@ -31,16 +32,24 @@ const logger = createLogger('McpService')
|
|||||||
class McpService {
|
class McpService {
|
||||||
private cacheAdapter: McpCacheStorageAdapter
|
private cacheAdapter: McpCacheStorageAdapter
|
||||||
private readonly cacheTimeout = MCP_CONSTANTS.CACHE_TIMEOUT
|
private readonly cacheTimeout = MCP_CONSTANTS.CACHE_TIMEOUT
|
||||||
|
private unsubscribeConnectionManager?: () => void
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.cacheAdapter = createMcpCacheAdapter()
|
this.cacheAdapter = createMcpCacheAdapter()
|
||||||
logger.info(`MCP Service initialized with ${getMcpCacheType()} cache`)
|
logger.info(`MCP Service initialized with ${getMcpCacheType()} cache`)
|
||||||
|
|
||||||
|
if (mcpConnectionManager) {
|
||||||
|
this.unsubscribeConnectionManager = mcpConnectionManager.subscribe((event) => {
|
||||||
|
this.clearCache(event.workspaceId)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dispose of the service and cleanup resources
|
* Dispose of the service and cleanup resources
|
||||||
*/
|
*/
|
||||||
dispose(): void {
|
dispose(): void {
|
||||||
|
this.unsubscribeConnectionManager?.()
|
||||||
this.cacheAdapter.dispose()
|
this.cacheAdapter.dispose()
|
||||||
logger.info('MCP Service disposed')
|
logger.info('MCP Service disposed')
|
||||||
}
|
}
|
||||||
@@ -328,7 +337,7 @@ class McpService {
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
`[${requestId}] Discovered ${tools.length} tools from server ${config.name}`
|
`[${requestId}] Discovered ${tools.length} tools from server ${config.name}`
|
||||||
)
|
)
|
||||||
return { serverId: config.id, tools }
|
return { serverId: config.id, tools, resolvedConfig }
|
||||||
} finally {
|
} finally {
|
||||||
await client.disconnect()
|
await client.disconnect()
|
||||||
}
|
}
|
||||||
@@ -364,6 +373,21 @@ class McpService {
|
|||||||
logger.error(`[${requestId}] Error updating server statuses:`, err)
|
logger.error(`[${requestId}] Error updating server statuses:`, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Fire-and-forget persistent connections for servers that support listChanged
|
||||||
|
if (mcpConnectionManager) {
|
||||||
|
for (const [index, result] of results.entries()) {
|
||||||
|
if (result.status === 'fulfilled') {
|
||||||
|
const { resolvedConfig } = result.value
|
||||||
|
mcpConnectionManager.connect(resolvedConfig, userId, workspaceId).catch((err) => {
|
||||||
|
logger.warn(
|
||||||
|
`[${requestId}] Persistent connection failed for ${servers[index].name}:`,
|
||||||
|
err
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (failedCount === 0) {
|
if (failedCount === 0) {
|
||||||
try {
|
try {
|
||||||
await this.cacheAdapter.set(cacheKey, allTools, this.cacheTimeout)
|
await this.cacheAdapter.set(cacheKey, allTools, this.cacheTimeout)
|
||||||
|
|||||||
@@ -147,6 +147,52 @@ export interface McpServerSummary {
|
|||||||
error?: string
|
error?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Callback invoked when an MCP server sends a `notifications/tools/list_changed` notification.
|
||||||
|
*/
|
||||||
|
export type McpToolsChangedCallback = (serverId: string) => void
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Options for creating an McpClient with notification support.
|
||||||
|
*/
|
||||||
|
export interface McpClientOptions {
|
||||||
|
config: McpServerConfig
|
||||||
|
securityPolicy?: McpSecurityPolicy
|
||||||
|
onToolsChanged?: McpToolsChangedCallback
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Event emitted by the connection manager when a server's tools change.
|
||||||
|
*/
|
||||||
|
export interface ToolsChangedEvent {
|
||||||
|
serverId: string
|
||||||
|
serverName: string
|
||||||
|
workspaceId: string
|
||||||
|
timestamp: number
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* State of a managed persistent connection.
|
||||||
|
*/
|
||||||
|
export interface ManagedConnectionState {
|
||||||
|
serverId: string
|
||||||
|
serverName: string
|
||||||
|
workspaceId: string
|
||||||
|
userId: string
|
||||||
|
connected: boolean
|
||||||
|
supportsListChanged: boolean
|
||||||
|
reconnectAttempts: number
|
||||||
|
lastActivity: number
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Event emitted when workflow CRUD modifies a workflow MCP server's tools.
|
||||||
|
*/
|
||||||
|
export interface WorkflowToolsChangedEvent {
|
||||||
|
serverId: string
|
||||||
|
workspaceId: string
|
||||||
|
}
|
||||||
|
|
||||||
export interface McpApiResponse<T = unknown> {
|
export interface McpApiResponse<T = unknown> {
|
||||||
success: boolean
|
success: boolean
|
||||||
data?: T
|
data?: T
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ import {
|
|||||||
calculateCost,
|
calculateCost,
|
||||||
generateStructuredOutputInstructions,
|
generateStructuredOutputInstructions,
|
||||||
shouldBillModelUsage,
|
shouldBillModelUsage,
|
||||||
|
supportsReasoningEffort,
|
||||||
supportsTemperature,
|
supportsTemperature,
|
||||||
|
supportsThinking,
|
||||||
|
supportsVerbosity,
|
||||||
} from '@/providers/utils'
|
} from '@/providers/utils'
|
||||||
|
|
||||||
const logger = createLogger('Providers')
|
const logger = createLogger('Providers')
|
||||||
@@ -21,11 +24,24 @@ export const MAX_TOOL_ITERATIONS = 20
|
|||||||
|
|
||||||
function sanitizeRequest(request: ProviderRequest): ProviderRequest {
|
function sanitizeRequest(request: ProviderRequest): ProviderRequest {
|
||||||
const sanitizedRequest = { ...request }
|
const sanitizedRequest = { ...request }
|
||||||
|
const model = sanitizedRequest.model
|
||||||
|
|
||||||
if (sanitizedRequest.model && !supportsTemperature(sanitizedRequest.model)) {
|
if (model && !supportsTemperature(model)) {
|
||||||
sanitizedRequest.temperature = undefined
|
sanitizedRequest.temperature = undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (model && !supportsReasoningEffort(model)) {
|
||||||
|
sanitizedRequest.reasoningEffort = undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
if (model && !supportsVerbosity(model)) {
|
||||||
|
sanitizedRequest.verbosity = undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
if (model && !supportsThinking(model)) {
|
||||||
|
sanitizedRequest.thinkingLevel = undefined
|
||||||
|
}
|
||||||
|
|
||||||
return sanitizedRequest
|
return sanitizedRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -33,8 +33,11 @@ import {
|
|||||||
prepareToolExecution,
|
prepareToolExecution,
|
||||||
prepareToolsWithUsageControl,
|
prepareToolsWithUsageControl,
|
||||||
shouldBillModelUsage,
|
shouldBillModelUsage,
|
||||||
|
supportsReasoningEffort,
|
||||||
supportsTemperature,
|
supportsTemperature,
|
||||||
|
supportsThinking,
|
||||||
supportsToolUsageControl,
|
supportsToolUsageControl,
|
||||||
|
supportsVerbosity,
|
||||||
updateOllamaProviderModels,
|
updateOllamaProviderModels,
|
||||||
} from '@/providers/utils'
|
} from '@/providers/utils'
|
||||||
|
|
||||||
@@ -333,6 +336,82 @@ describe('Model Capabilities', () => {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('supportsReasoningEffort', () => {
|
||||||
|
it.concurrent('should return true for models with reasoning effort capability', () => {
|
||||||
|
expect(supportsReasoningEffort('gpt-5')).toBe(true)
|
||||||
|
expect(supportsReasoningEffort('gpt-5-mini')).toBe(true)
|
||||||
|
expect(supportsReasoningEffort('gpt-5.1')).toBe(true)
|
||||||
|
expect(supportsReasoningEffort('gpt-5.2')).toBe(true)
|
||||||
|
expect(supportsReasoningEffort('o3')).toBe(true)
|
||||||
|
expect(supportsReasoningEffort('o4-mini')).toBe(true)
|
||||||
|
expect(supportsReasoningEffort('azure/gpt-5')).toBe(true)
|
||||||
|
expect(supportsReasoningEffort('azure/o3')).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it.concurrent('should return false for models without reasoning effort capability', () => {
|
||||||
|
expect(supportsReasoningEffort('gpt-4o')).toBe(false)
|
||||||
|
expect(supportsReasoningEffort('gpt-4.1')).toBe(false)
|
||||||
|
expect(supportsReasoningEffort('claude-sonnet-4-5')).toBe(false)
|
||||||
|
expect(supportsReasoningEffort('claude-opus-4-6')).toBe(false)
|
||||||
|
expect(supportsReasoningEffort('gemini-2.5-flash')).toBe(false)
|
||||||
|
expect(supportsReasoningEffort('unknown-model')).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it.concurrent('should be case-insensitive', () => {
|
||||||
|
expect(supportsReasoningEffort('GPT-5')).toBe(true)
|
||||||
|
expect(supportsReasoningEffort('O3')).toBe(true)
|
||||||
|
expect(supportsReasoningEffort('GPT-4O')).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('supportsVerbosity', () => {
|
||||||
|
it.concurrent('should return true for models with verbosity capability', () => {
|
||||||
|
expect(supportsVerbosity('gpt-5')).toBe(true)
|
||||||
|
expect(supportsVerbosity('gpt-5-mini')).toBe(true)
|
||||||
|
expect(supportsVerbosity('gpt-5.1')).toBe(true)
|
||||||
|
expect(supportsVerbosity('gpt-5.2')).toBe(true)
|
||||||
|
expect(supportsVerbosity('azure/gpt-5')).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it.concurrent('should return false for models without verbosity capability', () => {
|
||||||
|
expect(supportsVerbosity('gpt-4o')).toBe(false)
|
||||||
|
expect(supportsVerbosity('o3')).toBe(false)
|
||||||
|
expect(supportsVerbosity('o4-mini')).toBe(false)
|
||||||
|
expect(supportsVerbosity('claude-sonnet-4-5')).toBe(false)
|
||||||
|
expect(supportsVerbosity('unknown-model')).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it.concurrent('should be case-insensitive', () => {
|
||||||
|
expect(supportsVerbosity('GPT-5')).toBe(true)
|
||||||
|
expect(supportsVerbosity('GPT-4O')).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('supportsThinking', () => {
|
||||||
|
it.concurrent('should return true for models with thinking capability', () => {
|
||||||
|
expect(supportsThinking('claude-opus-4-6')).toBe(true)
|
||||||
|
expect(supportsThinking('claude-opus-4-5')).toBe(true)
|
||||||
|
expect(supportsThinking('claude-sonnet-4-5')).toBe(true)
|
||||||
|
expect(supportsThinking('claude-sonnet-4-0')).toBe(true)
|
||||||
|
expect(supportsThinking('claude-haiku-4-5')).toBe(true)
|
||||||
|
expect(supportsThinking('gemini-3-pro-preview')).toBe(true)
|
||||||
|
expect(supportsThinking('gemini-3-flash-preview')).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it.concurrent('should return false for models without thinking capability', () => {
|
||||||
|
expect(supportsThinking('gpt-4o')).toBe(false)
|
||||||
|
expect(supportsThinking('gpt-5')).toBe(false)
|
||||||
|
expect(supportsThinking('o3')).toBe(false)
|
||||||
|
expect(supportsThinking('deepseek-v3')).toBe(false)
|
||||||
|
expect(supportsThinking('unknown-model')).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it.concurrent('should be case-insensitive', () => {
|
||||||
|
expect(supportsThinking('CLAUDE-OPUS-4-6')).toBe(true)
|
||||||
|
expect(supportsThinking('GPT-4O')).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
describe('Model Constants', () => {
|
describe('Model Constants', () => {
|
||||||
it.concurrent('should have correct models in MODELS_TEMP_RANGE_0_2', () => {
|
it.concurrent('should have correct models in MODELS_TEMP_RANGE_0_2', () => {
|
||||||
expect(MODELS_TEMP_RANGE_0_2).toContain('gpt-4o')
|
expect(MODELS_TEMP_RANGE_0_2).toContain('gpt-4o')
|
||||||
|
|||||||
@@ -959,6 +959,18 @@ export function supportsTemperature(model: string): boolean {
|
|||||||
return supportsTemperatureFromDefinitions(model)
|
return supportsTemperatureFromDefinitions(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function supportsReasoningEffort(model: string): boolean {
|
||||||
|
return MODELS_WITH_REASONING_EFFORT.includes(model.toLowerCase())
|
||||||
|
}
|
||||||
|
|
||||||
|
export function supportsVerbosity(model: string): boolean {
|
||||||
|
return MODELS_WITH_VERBOSITY.includes(model.toLowerCase())
|
||||||
|
}
|
||||||
|
|
||||||
|
export function supportsThinking(model: string): boolean {
|
||||||
|
return MODELS_WITH_THINKING.includes(model.toLowerCase())
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the maximum temperature value for a model
|
* Get the maximum temperature value for a model
|
||||||
* @returns Maximum temperature value (1 or 2) or undefined if temperature not supported
|
* @returns Maximum temperature value (1 or 2) or undefined if temperature not supported
|
||||||
|
|||||||
Reference in New Issue
Block a user