mirror of
https://github.com/simstudioai/sim.git
synced 2026-02-18 18:25:14 -05:00
feat(mcp): add ALLOWED_MCP_DOMAINS env var for domain allowlist (#3240)
* feat(mcp): add ALLOWED_MCP_DOMAINS env var for domain allowlist * ack PR comments * cleanup
This commit is contained in:
@@ -3,6 +3,7 @@ import { mcpServers } from '@sim/db/schema'
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { and, eq, isNull } from 'drizzle-orm'
|
||||
import type { NextRequest } from 'next/server'
|
||||
import { McpDomainNotAllowedError, validateMcpDomain } from '@/lib/mcp/domain-check'
|
||||
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
||||
import { mcpService } from '@/lib/mcp/service'
|
||||
import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
|
||||
@@ -29,6 +30,17 @@ export const PATCH = withMcpAuth<{ id: string }>('write')(
|
||||
// Remove workspaceId from body to prevent it from being updated
|
||||
const { workspaceId: _, ...updateData } = body
|
||||
|
||||
if (updateData.url) {
|
||||
try {
|
||||
validateMcpDomain(updateData.url)
|
||||
} catch (e) {
|
||||
if (e instanceof McpDomainNotAllowedError) {
|
||||
return createMcpErrorResponse(e, e.message, 403)
|
||||
}
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
// Get the current server to check if URL is changing
|
||||
const [currentServer] = await db
|
||||
.select({ url: mcpServers.url })
|
||||
|
||||
@@ -3,6 +3,7 @@ import { mcpServers } from '@sim/db/schema'
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { and, eq, isNull } from 'drizzle-orm'
|
||||
import type { NextRequest } from 'next/server'
|
||||
import { McpDomainNotAllowedError, validateMcpDomain } from '@/lib/mcp/domain-check'
|
||||
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
||||
import { mcpService } from '@/lib/mcp/service'
|
||||
import {
|
||||
@@ -72,6 +73,15 @@ export const POST = withMcpAuth('write')(
|
||||
)
|
||||
}
|
||||
|
||||
try {
|
||||
validateMcpDomain(body.url)
|
||||
} catch (e) {
|
||||
if (e instanceof McpDomainNotAllowedError) {
|
||||
return createMcpErrorResponse(e, e.message, 403)
|
||||
}
|
||||
throw e
|
||||
}
|
||||
|
||||
const serverId = body.url ? generateMcpServerId(workspaceId, body.url) : crypto.randomUUID()
|
||||
|
||||
const [existingServer] = await db
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import type { NextRequest } from 'next/server'
|
||||
import { McpClient } from '@/lib/mcp/client'
|
||||
import { McpDomainNotAllowedError, validateMcpDomain } from '@/lib/mcp/domain-check'
|
||||
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
|
||||
import { resolveMcpConfigEnvVars } from '@/lib/mcp/resolve-config'
|
||||
import type { McpTransport } from '@/lib/mcp/types'
|
||||
@@ -71,6 +72,15 @@ export const POST = withMcpAuth('write')(
|
||||
)
|
||||
}
|
||||
|
||||
try {
|
||||
validateMcpDomain(body.url)
|
||||
} catch (e) {
|
||||
if (e instanceof McpDomainNotAllowedError) {
|
||||
return createMcpErrorResponse(e, e.message, 403)
|
||||
}
|
||||
throw e
|
||||
}
|
||||
|
||||
// Build initial config for resolution
|
||||
const initialConfig = {
|
||||
id: `test-${requestId}`,
|
||||
|
||||
27
apps/sim/app/api/settings/allowed-mcp-domains/route.ts
Normal file
27
apps/sim/app/api/settings/allowed-mcp-domains/route.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import { NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getAllowedMcpDomainsFromEnv } from '@/lib/core/config/feature-flags'
|
||||
import { getBaseUrl } from '@/lib/core/utils/urls'
|
||||
|
||||
export async function GET() {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const configuredDomains = getAllowedMcpDomainsFromEnv()
|
||||
if (configuredDomains === null) {
|
||||
return NextResponse.json({ allowedMcpDomains: null })
|
||||
}
|
||||
|
||||
try {
|
||||
const platformHostname = new URL(getBaseUrl()).hostname.toLowerCase()
|
||||
if (!configuredDomains.includes(platformHostname)) {
|
||||
return NextResponse.json({
|
||||
allowedMcpDomains: [...configuredDomains, platformHostname],
|
||||
})
|
||||
}
|
||||
} catch {}
|
||||
|
||||
return NextResponse.json({ allowedMcpDomains: configuredDomains })
|
||||
}
|
||||
@@ -106,6 +106,21 @@ interface McpServer {
|
||||
|
||||
const logger = createLogger('McpSettings')
|
||||
|
||||
/**
|
||||
* Checks if a URL's hostname is in the allowed domains list.
|
||||
* Returns true if no allowlist is configured (null) or the domain matches.
|
||||
*/
|
||||
function isDomainAllowed(url: string | undefined, allowedDomains: string[] | null): boolean {
|
||||
if (allowedDomains === null) return true
|
||||
if (!url) return true
|
||||
try {
|
||||
const hostname = new URL(url).hostname.toLowerCase()
|
||||
return allowedDomains.includes(hostname)
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const DEFAULT_FORM_DATA: McpServerFormData = {
|
||||
name: '',
|
||||
transport: 'streamable-http',
|
||||
@@ -390,6 +405,15 @@ export function MCP({ initialServerId }: MCPProps) {
|
||||
} = useMcpServerTest()
|
||||
const availableEnvVars = useAvailableEnvVarKeys(workspaceId)
|
||||
|
||||
const [allowedMcpDomains, setAllowedMcpDomains] = useState<string[] | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
fetch('/api/settings/allowed-mcp-domains')
|
||||
.then((res) => res.json())
|
||||
.then((data) => setAllowedMcpDomains(data.allowedMcpDomains ?? null))
|
||||
.catch(() => setAllowedMcpDomains(null))
|
||||
}, [])
|
||||
|
||||
const urlInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
const [showAddForm, setShowAddForm] = useState(false)
|
||||
@@ -1006,10 +1030,12 @@ export function MCP({ initialServerId }: MCPProps) {
|
||||
const showNoResults = searchTerm.trim() && filteredServers.length === 0 && servers.length > 0
|
||||
|
||||
const isFormValid = formData.name.trim() && formData.url?.trim()
|
||||
const isSubmitDisabled = serversLoading || isAddingServer || !isFormValid
|
||||
const isAddDomainBlocked = !isDomainAllowed(formData.url, allowedMcpDomains)
|
||||
const isSubmitDisabled = serversLoading || isAddingServer || !isFormValid || isAddDomainBlocked
|
||||
const testButtonLabel = getTestButtonLabel(testResult, isTestingConnection)
|
||||
|
||||
const isEditFormValid = editFormData.name.trim() && editFormData.url?.trim()
|
||||
const isEditDomainBlocked = !isDomainAllowed(editFormData.url, allowedMcpDomains)
|
||||
const editTestButtonLabel = getTestButtonLabel(editTestResult, isEditTestingConnection)
|
||||
const hasEditChanges = useMemo(() => {
|
||||
if (editFormData.name !== editOriginalData.name) return true
|
||||
@@ -1299,6 +1325,11 @@ export function MCP({ initialServerId }: MCPProps) {
|
||||
onChange={(e) => handleEditInputChange('url', e.target.value)}
|
||||
onScroll={setEditUrlScrollLeft}
|
||||
/>
|
||||
{isEditDomainBlocked && (
|
||||
<p className='mt-[4px] text-[12px] text-[var(--text-error)]'>
|
||||
Domain not permitted by server policy
|
||||
</p>
|
||||
)}
|
||||
</FormField>
|
||||
|
||||
<div className='flex flex-col gap-[8px]'>
|
||||
@@ -1351,7 +1382,7 @@ export function MCP({ initialServerId }: MCPProps) {
|
||||
<Button
|
||||
variant='default'
|
||||
onClick={handleEditTestConnection}
|
||||
disabled={isEditTestingConnection || !isEditFormValid}
|
||||
disabled={isEditTestingConnection || !isEditFormValid || isEditDomainBlocked}
|
||||
>
|
||||
{editTestButtonLabel}
|
||||
</Button>
|
||||
@@ -1361,7 +1392,9 @@ export function MCP({ initialServerId }: MCPProps) {
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleSaveEdit}
|
||||
disabled={!hasEditChanges || isUpdatingServer || !isEditFormValid}
|
||||
disabled={
|
||||
!hasEditChanges || isUpdatingServer || !isEditFormValid || isEditDomainBlocked
|
||||
}
|
||||
variant='tertiary'
|
||||
>
|
||||
{isUpdatingServer ? 'Saving...' : 'Save'}
|
||||
@@ -1434,6 +1467,11 @@ export function MCP({ initialServerId }: MCPProps) {
|
||||
onChange={(e) => handleInputChange('url', e.target.value)}
|
||||
onScroll={(scrollLeft) => handleUrlScroll(scrollLeft)}
|
||||
/>
|
||||
{isAddDomainBlocked && (
|
||||
<p className='mt-[4px] text-[12px] text-[var(--text-error)]'>
|
||||
Domain not permitted by server policy
|
||||
</p>
|
||||
)}
|
||||
</FormField>
|
||||
|
||||
<div className='flex flex-col gap-[8px]'>
|
||||
@@ -1479,7 +1517,7 @@ export function MCP({ initialServerId }: MCPProps) {
|
||||
<Button
|
||||
variant='default'
|
||||
onClick={handleTestConnection}
|
||||
disabled={isTestingConnection || !isFormValid}
|
||||
disabled={isTestingConnection || !isFormValid || isAddDomainBlocked}
|
||||
>
|
||||
{testButtonLabel}
|
||||
</Button>
|
||||
@@ -1489,7 +1527,9 @@ export function MCP({ initialServerId }: MCPProps) {
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleAddServer} disabled={isSubmitDisabled} variant='tertiary'>
|
||||
{isSubmitDisabled && isFormValid ? 'Adding...' : 'Add Server'}
|
||||
{isSubmitDisabled && isFormValid && !isAddDomainBlocked
|
||||
? 'Adding...'
|
||||
: 'Add Server'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { useState } from 'react'
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { Check, ChevronDown, Copy, Eye, EyeOff } from 'lucide-react'
|
||||
import { Check, ChevronDown, Clipboard, Eye, EyeOff } from 'lucide-react'
|
||||
import { Button, Combobox, Input, Switch, Textarea } from '@/components/emcn'
|
||||
import { Skeleton } from '@/components/ui'
|
||||
import { useSession } from '@/lib/auth/auth-client'
|
||||
@@ -418,29 +418,29 @@ export function SSO() {
|
||||
|
||||
{/* Callback URL */}
|
||||
<div className='flex flex-col gap-[8px]'>
|
||||
<span className='font-medium text-[13px] text-[var(--text-secondary)]'>
|
||||
Callback URL
|
||||
</span>
|
||||
<div className='relative'>
|
||||
<div className='flex h-9 items-center rounded-[6px] border bg-[var(--surface-1)] px-[10px] pr-[40px]'>
|
||||
<code className='flex-1 truncate font-mono text-[13px] text-[var(--text-primary)]'>
|
||||
{providerCallbackUrl}
|
||||
</code>
|
||||
</div>
|
||||
<div className='flex items-center justify-between'>
|
||||
<span className='font-medium text-[13px] text-[var(--text-secondary)]'>
|
||||
Callback URL
|
||||
</span>
|
||||
<Button
|
||||
type='button'
|
||||
variant='ghost'
|
||||
onClick={() => copyToClipboard(providerCallbackUrl)}
|
||||
className='-translate-y-1/2 absolute top-1/2 right-[4px] h-[28px] w-[28px] rounded-[4px] text-[var(--text-muted)] hover:text-[var(--text-primary)]'
|
||||
className='h-[22px] w-[22px] rounded-[4px] p-0 text-[var(--text-muted)] hover:text-[var(--text-primary)]'
|
||||
>
|
||||
{copied ? (
|
||||
<Check className='h-[14px] w-[14px]' />
|
||||
<Check className='h-[13px] w-[13px]' />
|
||||
) : (
|
||||
<Copy className='h-[14px] w-[14px]' />
|
||||
<Clipboard className='h-[13px] w-[13px]' />
|
||||
)}
|
||||
<span className='sr-only'>Copy callback URL</span>
|
||||
</Button>
|
||||
</div>
|
||||
<div className='flex h-9 items-center rounded-[6px] border bg-[var(--surface-1)] px-[10px]'>
|
||||
<code className='flex-1 truncate font-mono text-[13px] text-[var(--text-primary)]'>
|
||||
{providerCallbackUrl}
|
||||
</code>
|
||||
</div>
|
||||
<p className='text-[13px] text-[var(--text-muted)]'>
|
||||
Configure this in your identity provider
|
||||
</p>
|
||||
@@ -852,29 +852,29 @@ export function SSO() {
|
||||
|
||||
{/* Callback URL display */}
|
||||
<div className='flex flex-col gap-[8px]'>
|
||||
<span className='font-medium text-[13px] text-[var(--text-secondary)]'>
|
||||
Callback URL
|
||||
</span>
|
||||
<div className='relative'>
|
||||
<div className='flex h-9 items-center rounded-[6px] border bg-[var(--surface-1)] px-[10px] pr-[40px]'>
|
||||
<code className='flex-1 truncate font-mono text-[13px] text-[var(--text-primary)]'>
|
||||
{callbackUrl}
|
||||
</code>
|
||||
</div>
|
||||
<div className='flex items-center justify-between'>
|
||||
<span className='font-medium text-[13px] text-[var(--text-secondary)]'>
|
||||
Callback URL
|
||||
</span>
|
||||
<Button
|
||||
type='button'
|
||||
variant='ghost'
|
||||
onClick={() => copyToClipboard(callbackUrl)}
|
||||
className='-translate-y-1/2 absolute top-1/2 right-[4px] h-[28px] w-[28px] rounded-[4px] text-[var(--text-muted)] hover:text-[var(--text-primary)]'
|
||||
className='h-[22px] w-[22px] rounded-[4px] p-0 text-[var(--text-muted)] hover:text-[var(--text-primary)]'
|
||||
>
|
||||
{copied ? (
|
||||
<Check className='h-[14px] w-[14px]' />
|
||||
<Check className='h-[13px] w-[13px]' />
|
||||
) : (
|
||||
<Copy className='h-[14px] w-[14px]' />
|
||||
<Clipboard className='h-[13px] w-[13px]' />
|
||||
)}
|
||||
<span className='sr-only'>Copy callback URL</span>
|
||||
</Button>
|
||||
</div>
|
||||
<div className='flex h-9 items-center rounded-[6px] border bg-[var(--surface-1)] px-[10px]'>
|
||||
<code className='flex-1 truncate font-mono text-[13px] text-[var(--text-primary)]'>
|
||||
{callbackUrl}
|
||||
</code>
|
||||
</div>
|
||||
<p className='text-[13px] text-[var(--text-muted)]'>
|
||||
Configure this in your identity provider
|
||||
</p>
|
||||
|
||||
@@ -93,6 +93,7 @@ export const env = createEnv({
|
||||
EXA_API_KEY: z.string().min(1).optional(), // Exa AI API key for enhanced online search
|
||||
BLACKLISTED_PROVIDERS: z.string().optional(), // Comma-separated provider IDs to hide (e.g., "openai,anthropic")
|
||||
BLACKLISTED_MODELS: z.string().optional(), // Comma-separated model names/prefixes to hide (e.g., "gpt-4,claude-*")
|
||||
ALLOWED_MCP_DOMAINS: z.string().optional(), // Comma-separated domains for MCP servers (e.g., "internal.company.com,mcp.example.org"). Empty = all allowed.
|
||||
|
||||
// Azure Configuration - Shared credentials with feature-specific models
|
||||
AZURE_OPENAI_ENDPOINT: z.string().url().optional(), // Shared Azure OpenAI service endpoint
|
||||
|
||||
@@ -123,6 +123,35 @@ export const isReactGrabEnabled = isDev && isTruthy(env.REACT_GRAB_ENABLED)
|
||||
*/
|
||||
export const isReactScanEnabled = isDev && isTruthy(env.REACT_SCAN_ENABLED)
|
||||
|
||||
/**
|
||||
* Normalizes a domain entry from the ALLOWED_MCP_DOMAINS env var.
|
||||
* Accepts bare hostnames (e.g., "mcp.company.com") or full URLs (e.g., "https://mcp.company.com").
|
||||
* Extracts the hostname in either case.
|
||||
*/
|
||||
function normalizeDomainEntry(entry: string): string {
|
||||
const trimmed = entry.trim().toLowerCase()
|
||||
if (!trimmed) return ''
|
||||
if (trimmed.includes('://')) {
|
||||
try {
|
||||
return new URL(trimmed).hostname
|
||||
} catch {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
/**
|
||||
* Get allowed MCP server domains from the ALLOWED_MCP_DOMAINS env var.
|
||||
* Returns null if not set (all domains allowed), or parsed array of lowercase hostnames.
|
||||
* Accepts both bare hostnames and full URLs in the env var value.
|
||||
*/
|
||||
export function getAllowedMcpDomainsFromEnv(): string[] | null {
|
||||
if (!env.ALLOWED_MCP_DOMAINS) return null
|
||||
const parsed = env.ALLOWED_MCP_DOMAINS.split(',').map(normalizeDomainEntry).filter(Boolean)
|
||||
return parsed.length > 0 ? parsed : null
|
||||
}
|
||||
|
||||
/**
|
||||
* Get cost multiplier based on environment
|
||||
*/
|
||||
|
||||
163
apps/sim/lib/mcp/domain-check.test.ts
Normal file
163
apps/sim/lib/mcp/domain-check.test.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
/**
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const mockGetAllowedMcpDomainsFromEnv = vi.fn<() => string[] | null>()
|
||||
const mockGetBaseUrl = vi.fn<() => string>()
|
||||
|
||||
vi.doMock('@/lib/core/config/feature-flags', () => ({
|
||||
getAllowedMcpDomainsFromEnv: mockGetAllowedMcpDomainsFromEnv,
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/core/utils/urls', () => ({
|
||||
getBaseUrl: mockGetBaseUrl,
|
||||
}))
|
||||
|
||||
const { McpDomainNotAllowedError, isMcpDomainAllowed, validateMcpDomain } = await import(
|
||||
'./domain-check'
|
||||
)
|
||||
|
||||
describe('McpDomainNotAllowedError', () => {
|
||||
it.concurrent('creates error with correct name and message', () => {
|
||||
const error = new McpDomainNotAllowedError('evil.com')
|
||||
|
||||
expect(error).toBeInstanceOf(Error)
|
||||
expect(error).toBeInstanceOf(McpDomainNotAllowedError)
|
||||
expect(error.name).toBe('McpDomainNotAllowedError')
|
||||
expect(error.message).toContain('evil.com')
|
||||
})
|
||||
})
|
||||
|
||||
describe('isMcpDomainAllowed', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('when no allowlist is configured', () => {
|
||||
beforeEach(() => {
|
||||
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(null)
|
||||
})
|
||||
|
||||
it('allows any URL', () => {
|
||||
expect(isMcpDomainAllowed('https://any-server.com/mcp')).toBe(true)
|
||||
})
|
||||
|
||||
it('allows undefined URL', () => {
|
||||
expect(isMcpDomainAllowed(undefined)).toBe(true)
|
||||
})
|
||||
|
||||
it('allows empty string URL', () => {
|
||||
expect(isMcpDomainAllowed('')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when allowlist is configured', () => {
|
||||
beforeEach(() => {
|
||||
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['allowed.com', 'internal.company.com'])
|
||||
mockGetBaseUrl.mockReturnValue('https://platform.example.com')
|
||||
})
|
||||
|
||||
it('allows URLs on the allowlist', () => {
|
||||
expect(isMcpDomainAllowed('https://allowed.com/mcp')).toBe(true)
|
||||
expect(isMcpDomainAllowed('https://internal.company.com/tools')).toBe(true)
|
||||
})
|
||||
|
||||
it('rejects URLs not on the allowlist', () => {
|
||||
expect(isMcpDomainAllowed('https://evil.com/mcp')).toBe(false)
|
||||
})
|
||||
|
||||
it('rejects undefined URL (fail-closed)', () => {
|
||||
expect(isMcpDomainAllowed(undefined)).toBe(false)
|
||||
})
|
||||
|
||||
it('rejects empty string URL (fail-closed)', () => {
|
||||
expect(isMcpDomainAllowed('')).toBe(false)
|
||||
})
|
||||
|
||||
it('rejects malformed URLs', () => {
|
||||
expect(isMcpDomainAllowed('not-a-url')).toBe(false)
|
||||
})
|
||||
|
||||
it('matches case-insensitively', () => {
|
||||
expect(isMcpDomainAllowed('https://ALLOWED.COM/mcp')).toBe(true)
|
||||
})
|
||||
|
||||
it('always allows the platform hostname', () => {
|
||||
expect(isMcpDomainAllowed('https://platform.example.com/mcp')).toBe(true)
|
||||
})
|
||||
|
||||
it('allows platform hostname even when not in the allowlist', () => {
|
||||
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['other.com'])
|
||||
expect(isMcpDomainAllowed('https://platform.example.com/mcp')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when getBaseUrl is not configured', () => {
|
||||
beforeEach(() => {
|
||||
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['allowed.com'])
|
||||
mockGetBaseUrl.mockImplementation(() => {
|
||||
throw new Error('Not configured')
|
||||
})
|
||||
})
|
||||
|
||||
it('still allows URLs on the allowlist', () => {
|
||||
expect(isMcpDomainAllowed('https://allowed.com/mcp')).toBe(true)
|
||||
})
|
||||
|
||||
it('still rejects URLs not on the allowlist', () => {
|
||||
expect(isMcpDomainAllowed('https://evil.com/mcp')).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('validateMcpDomain', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('when no allowlist is configured', () => {
|
||||
beforeEach(() => {
|
||||
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(null)
|
||||
})
|
||||
|
||||
it('does not throw for any URL', () => {
|
||||
expect(() => validateMcpDomain('https://any-server.com/mcp')).not.toThrow()
|
||||
})
|
||||
|
||||
it('does not throw for undefined URL', () => {
|
||||
expect(() => validateMcpDomain(undefined)).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('when allowlist is configured', () => {
|
||||
beforeEach(() => {
|
||||
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['allowed.com'])
|
||||
mockGetBaseUrl.mockReturnValue('https://platform.example.com')
|
||||
})
|
||||
|
||||
it('does not throw for allowed URLs', () => {
|
||||
expect(() => validateMcpDomain('https://allowed.com/mcp')).not.toThrow()
|
||||
})
|
||||
|
||||
it('throws McpDomainNotAllowedError for disallowed URLs', () => {
|
||||
expect(() => validateMcpDomain('https://evil.com/mcp')).toThrow(McpDomainNotAllowedError)
|
||||
})
|
||||
|
||||
it('throws for undefined URL (fail-closed)', () => {
|
||||
expect(() => validateMcpDomain(undefined)).toThrow(McpDomainNotAllowedError)
|
||||
})
|
||||
|
||||
it('throws for malformed URLs', () => {
|
||||
expect(() => validateMcpDomain('not-a-url')).toThrow(McpDomainNotAllowedError)
|
||||
})
|
||||
|
||||
it('includes the rejected domain in the error message', () => {
|
||||
expect(() => validateMcpDomain('https://evil.com/mcp')).toThrow(/evil\.com/)
|
||||
})
|
||||
|
||||
it('does not throw for platform hostname', () => {
|
||||
expect(() => validateMcpDomain('https://platform.example.com/mcp')).not.toThrow()
|
||||
})
|
||||
})
|
||||
})
|
||||
69
apps/sim/lib/mcp/domain-check.ts
Normal file
69
apps/sim/lib/mcp/domain-check.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import { getAllowedMcpDomainsFromEnv } from '@/lib/core/config/feature-flags'
|
||||
import { getBaseUrl } from '@/lib/core/utils/urls'
|
||||
|
||||
export class McpDomainNotAllowedError extends Error {
|
||||
constructor(domain: string) {
|
||||
super(`MCP server domain "${domain}" is not allowed by the server's ALLOWED_MCP_DOMAINS policy`)
|
||||
this.name = 'McpDomainNotAllowedError'
|
||||
}
|
||||
}
|
||||
|
||||
let cachedPlatformHostname: string | null = null
|
||||
|
||||
/**
|
||||
* Returns the platform's own hostname (from getBaseUrl), lazy-cached.
|
||||
* Always lowercase. Returns null if the base URL is not configured or invalid.
|
||||
*/
|
||||
function getPlatformHostname(): string | null {
|
||||
if (cachedPlatformHostname !== null) return cachedPlatformHostname
|
||||
try {
|
||||
cachedPlatformHostname = new URL(getBaseUrl()).hostname.toLowerCase()
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
return cachedPlatformHostname
|
||||
}
|
||||
|
||||
/**
|
||||
* Core domain check. Returns null if the URL is allowed, or the hostname/url
|
||||
* string to use in the rejection error.
|
||||
*/
|
||||
function checkMcpDomain(url: string): string | null {
|
||||
const allowedDomains = getAllowedMcpDomainsFromEnv()
|
||||
if (allowedDomains === null) return null
|
||||
try {
|
||||
const hostname = new URL(url).hostname.toLowerCase()
|
||||
if (hostname === getPlatformHostname()) return null
|
||||
return allowedDomains.includes(hostname) ? null : hostname
|
||||
} catch {
|
||||
return url
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the URL's domain is allowed (or no restriction is configured).
|
||||
* The platform's own hostname (from getBaseUrl) is always allowed.
|
||||
*/
|
||||
export function isMcpDomainAllowed(url: string | undefined): boolean {
|
||||
if (!url) {
|
||||
return getAllowedMcpDomainsFromEnv() === null
|
||||
}
|
||||
return checkMcpDomain(url) === null
|
||||
}
|
||||
|
||||
/**
|
||||
* Throws McpDomainNotAllowedError if the URL's domain is not in the allowlist.
|
||||
* The platform's own hostname (from getBaseUrl) is always allowed.
|
||||
*/
|
||||
export function validateMcpDomain(url: string | undefined): void {
|
||||
if (!url) {
|
||||
if (getAllowedMcpDomainsFromEnv() !== null) {
|
||||
throw new McpDomainNotAllowedError('(empty)')
|
||||
}
|
||||
return
|
||||
}
|
||||
const rejected = checkMcpDomain(url)
|
||||
if (rejected !== null) {
|
||||
throw new McpDomainNotAllowedError(rejected)
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import { isTest } from '@/lib/core/config/feature-flags'
|
||||
import { generateRequestId } from '@/lib/core/utils/request'
|
||||
import { McpClient } from '@/lib/mcp/client'
|
||||
import { mcpConnectionManager } from '@/lib/mcp/connection-manager'
|
||||
import { isMcpDomainAllowed } from '@/lib/mcp/domain-check'
|
||||
import { resolveMcpConfigEnvVars } from '@/lib/mcp/resolve-config'
|
||||
import {
|
||||
createMcpCacheAdapter,
|
||||
@@ -93,6 +94,10 @@ class McpService {
|
||||
return null
|
||||
}
|
||||
|
||||
if (!isMcpDomainAllowed(server.url || undefined)) {
|
||||
return null
|
||||
}
|
||||
|
||||
return {
|
||||
id: server.id,
|
||||
name: server.name,
|
||||
@@ -123,19 +128,21 @@ class McpService {
|
||||
.from(mcpServers)
|
||||
.where(and(...whereConditions))
|
||||
|
||||
return servers.map((server) => ({
|
||||
id: server.id,
|
||||
name: server.name,
|
||||
description: server.description || undefined,
|
||||
transport: server.transport as McpTransport,
|
||||
url: server.url || undefined,
|
||||
headers: (server.headers as Record<string, string>) || {},
|
||||
timeout: server.timeout || 30000,
|
||||
retries: server.retries || 3,
|
||||
enabled: server.enabled,
|
||||
createdAt: server.createdAt.toISOString(),
|
||||
updatedAt: server.updatedAt.toISOString(),
|
||||
}))
|
||||
return servers
|
||||
.map((server) => ({
|
||||
id: server.id,
|
||||
name: server.name,
|
||||
description: server.description || undefined,
|
||||
transport: server.transport as McpTransport,
|
||||
url: server.url || undefined,
|
||||
headers: (server.headers as Record<string, string>) || {},
|
||||
timeout: server.timeout || 30000,
|
||||
retries: server.retries || 3,
|
||||
enabled: server.enabled,
|
||||
createdAt: server.createdAt.toISOString(),
|
||||
updatedAt: server.updatedAt.toISOString(),
|
||||
}))
|
||||
.filter((config) => isMcpDomainAllowed(config.url))
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user