feat(mcp): add ALLOWED_MCP_DOMAINS env var for domain allowlist (#3240)

* feat(mcp): add ALLOWED_MCP_DOMAINS env var for domain allowlist

* ack PR comments

* cleanup
This commit is contained in:
Waleed
2026-02-17 18:01:52 -08:00
committed by GitHub
parent 61a5c98717
commit 6421b1a0ca
12 changed files with 412 additions and 43 deletions

View File

@@ -3,6 +3,7 @@ import { mcpServers } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq, isNull } from 'drizzle-orm'
import type { NextRequest } from 'next/server'
import { McpDomainNotAllowedError, validateMcpDomain } from '@/lib/mcp/domain-check'
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
import { mcpService } from '@/lib/mcp/service'
import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils'
@@ -29,6 +30,17 @@ export const PATCH = withMcpAuth<{ id: string }>('write')(
// Remove workspaceId from body to prevent it from being updated
const { workspaceId: _, ...updateData } = body
if (updateData.url) {
try {
validateMcpDomain(updateData.url)
} catch (e) {
if (e instanceof McpDomainNotAllowedError) {
return createMcpErrorResponse(e, e.message, 403)
}
throw e
}
}
// Get the current server to check if URL is changing
const [currentServer] = await db
.select({ url: mcpServers.url })

View File

@@ -3,6 +3,7 @@ import { mcpServers } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq, isNull } from 'drizzle-orm'
import type { NextRequest } from 'next/server'
import { McpDomainNotAllowedError, validateMcpDomain } from '@/lib/mcp/domain-check'
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
import { mcpService } from '@/lib/mcp/service'
import {
@@ -72,6 +73,15 @@ export const POST = withMcpAuth('write')(
)
}
try {
validateMcpDomain(body.url)
} catch (e) {
if (e instanceof McpDomainNotAllowedError) {
return createMcpErrorResponse(e, e.message, 403)
}
throw e
}
const serverId = body.url ? generateMcpServerId(workspaceId, body.url) : crypto.randomUUID()
const [existingServer] = await db

View File

@@ -1,6 +1,7 @@
import { createLogger } from '@sim/logger'
import type { NextRequest } from 'next/server'
import { McpClient } from '@/lib/mcp/client'
import { McpDomainNotAllowedError, validateMcpDomain } from '@/lib/mcp/domain-check'
import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware'
import { resolveMcpConfigEnvVars } from '@/lib/mcp/resolve-config'
import type { McpTransport } from '@/lib/mcp/types'
@@ -71,6 +72,15 @@ export const POST = withMcpAuth('write')(
)
}
try {
validateMcpDomain(body.url)
} catch (e) {
if (e instanceof McpDomainNotAllowedError) {
return createMcpErrorResponse(e, e.message, 403)
}
throw e
}
// Build initial config for resolution
const initialConfig = {
id: `test-${requestId}`,

View File

@@ -0,0 +1,27 @@
import { NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { getAllowedMcpDomainsFromEnv } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls'
export async function GET() {
const session = await getSession()
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const configuredDomains = getAllowedMcpDomainsFromEnv()
if (configuredDomains === null) {
return NextResponse.json({ allowedMcpDomains: null })
}
try {
const platformHostname = new URL(getBaseUrl()).hostname.toLowerCase()
if (!configuredDomains.includes(platformHostname)) {
return NextResponse.json({
allowedMcpDomains: [...configuredDomains, platformHostname],
})
}
} catch {}
return NextResponse.json({ allowedMcpDomains: configuredDomains })
}

View File

@@ -106,6 +106,21 @@ interface McpServer {
const logger = createLogger('McpSettings')
/**
* Checks if a URL's hostname is in the allowed domains list.
* Returns true if no allowlist is configured (null) or the domain matches.
*/
function isDomainAllowed(url: string | undefined, allowedDomains: string[] | null): boolean {
if (allowedDomains === null) return true
if (!url) return true
try {
const hostname = new URL(url).hostname.toLowerCase()
return allowedDomains.includes(hostname)
} catch {
return false
}
}
const DEFAULT_FORM_DATA: McpServerFormData = {
name: '',
transport: 'streamable-http',
@@ -390,6 +405,15 @@ export function MCP({ initialServerId }: MCPProps) {
} = useMcpServerTest()
const availableEnvVars = useAvailableEnvVarKeys(workspaceId)
const [allowedMcpDomains, setAllowedMcpDomains] = useState<string[] | null>(null)
useEffect(() => {
fetch('/api/settings/allowed-mcp-domains')
.then((res) => res.json())
.then((data) => setAllowedMcpDomains(data.allowedMcpDomains ?? null))
.catch(() => setAllowedMcpDomains(null))
}, [])
const urlInputRef = useRef<HTMLInputElement>(null)
const [showAddForm, setShowAddForm] = useState(false)
@@ -1006,10 +1030,12 @@ export function MCP({ initialServerId }: MCPProps) {
const showNoResults = searchTerm.trim() && filteredServers.length === 0 && servers.length > 0
const isFormValid = formData.name.trim() && formData.url?.trim()
const isSubmitDisabled = serversLoading || isAddingServer || !isFormValid
const isAddDomainBlocked = !isDomainAllowed(formData.url, allowedMcpDomains)
const isSubmitDisabled = serversLoading || isAddingServer || !isFormValid || isAddDomainBlocked
const testButtonLabel = getTestButtonLabel(testResult, isTestingConnection)
const isEditFormValid = editFormData.name.trim() && editFormData.url?.trim()
const isEditDomainBlocked = !isDomainAllowed(editFormData.url, allowedMcpDomains)
const editTestButtonLabel = getTestButtonLabel(editTestResult, isEditTestingConnection)
const hasEditChanges = useMemo(() => {
if (editFormData.name !== editOriginalData.name) return true
@@ -1299,6 +1325,11 @@ export function MCP({ initialServerId }: MCPProps) {
onChange={(e) => handleEditInputChange('url', e.target.value)}
onScroll={setEditUrlScrollLeft}
/>
{isEditDomainBlocked && (
<p className='mt-[4px] text-[12px] text-[var(--text-error)]'>
Domain not permitted by server policy
</p>
)}
</FormField>
<div className='flex flex-col gap-[8px]'>
@@ -1351,7 +1382,7 @@ export function MCP({ initialServerId }: MCPProps) {
<Button
variant='default'
onClick={handleEditTestConnection}
disabled={isEditTestingConnection || !isEditFormValid}
disabled={isEditTestingConnection || !isEditFormValid || isEditDomainBlocked}
>
{editTestButtonLabel}
</Button>
@@ -1361,7 +1392,9 @@ export function MCP({ initialServerId }: MCPProps) {
</Button>
<Button
onClick={handleSaveEdit}
disabled={!hasEditChanges || isUpdatingServer || !isEditFormValid}
disabled={
!hasEditChanges || isUpdatingServer || !isEditFormValid || isEditDomainBlocked
}
variant='tertiary'
>
{isUpdatingServer ? 'Saving...' : 'Save'}
@@ -1434,6 +1467,11 @@ export function MCP({ initialServerId }: MCPProps) {
onChange={(e) => handleInputChange('url', e.target.value)}
onScroll={(scrollLeft) => handleUrlScroll(scrollLeft)}
/>
{isAddDomainBlocked && (
<p className='mt-[4px] text-[12px] text-[var(--text-error)]'>
Domain not permitted by server policy
</p>
)}
</FormField>
<div className='flex flex-col gap-[8px]'>
@@ -1479,7 +1517,7 @@ export function MCP({ initialServerId }: MCPProps) {
<Button
variant='default'
onClick={handleTestConnection}
disabled={isTestingConnection || !isFormValid}
disabled={isTestingConnection || !isFormValid || isAddDomainBlocked}
>
{testButtonLabel}
</Button>
@@ -1489,7 +1527,9 @@ export function MCP({ initialServerId }: MCPProps) {
Cancel
</Button>
<Button onClick={handleAddServer} disabled={isSubmitDisabled} variant='tertiary'>
{isSubmitDisabled && isFormValid ? 'Adding...' : 'Add Server'}
{isSubmitDisabled && isFormValid && !isAddDomainBlocked
? 'Adding...'
: 'Add Server'}
</Button>
</div>
</div>

View File

@@ -2,7 +2,7 @@
import { useState } from 'react'
import { createLogger } from '@sim/logger'
import { Check, ChevronDown, Copy, Eye, EyeOff } from 'lucide-react'
import { Check, ChevronDown, Clipboard, Eye, EyeOff } from 'lucide-react'
import { Button, Combobox, Input, Switch, Textarea } from '@/components/emcn'
import { Skeleton } from '@/components/ui'
import { useSession } from '@/lib/auth/auth-client'
@@ -418,29 +418,29 @@ export function SSO() {
{/* Callback URL */}
<div className='flex flex-col gap-[8px]'>
<span className='font-medium text-[13px] text-[var(--text-secondary)]'>
Callback URL
</span>
<div className='relative'>
<div className='flex h-9 items-center rounded-[6px] border bg-[var(--surface-1)] px-[10px] pr-[40px]'>
<code className='flex-1 truncate font-mono text-[13px] text-[var(--text-primary)]'>
{providerCallbackUrl}
</code>
</div>
<div className='flex items-center justify-between'>
<span className='font-medium text-[13px] text-[var(--text-secondary)]'>
Callback URL
</span>
<Button
type='button'
variant='ghost'
onClick={() => copyToClipboard(providerCallbackUrl)}
className='-translate-y-1/2 absolute top-1/2 right-[4px] h-[28px] w-[28px] rounded-[4px] text-[var(--text-muted)] hover:text-[var(--text-primary)]'
className='h-[22px] w-[22px] rounded-[4px] p-0 text-[var(--text-muted)] hover:text-[var(--text-primary)]'
>
{copied ? (
<Check className='h-[14px] w-[14px]' />
<Check className='h-[13px] w-[13px]' />
) : (
<Copy className='h-[14px] w-[14px]' />
<Clipboard className='h-[13px] w-[13px]' />
)}
<span className='sr-only'>Copy callback URL</span>
</Button>
</div>
<div className='flex h-9 items-center rounded-[6px] border bg-[var(--surface-1)] px-[10px]'>
<code className='flex-1 truncate font-mono text-[13px] text-[var(--text-primary)]'>
{providerCallbackUrl}
</code>
</div>
<p className='text-[13px] text-[var(--text-muted)]'>
Configure this in your identity provider
</p>
@@ -852,29 +852,29 @@ export function SSO() {
{/* Callback URL display */}
<div className='flex flex-col gap-[8px]'>
<span className='font-medium text-[13px] text-[var(--text-secondary)]'>
Callback URL
</span>
<div className='relative'>
<div className='flex h-9 items-center rounded-[6px] border bg-[var(--surface-1)] px-[10px] pr-[40px]'>
<code className='flex-1 truncate font-mono text-[13px] text-[var(--text-primary)]'>
{callbackUrl}
</code>
</div>
<div className='flex items-center justify-between'>
<span className='font-medium text-[13px] text-[var(--text-secondary)]'>
Callback URL
</span>
<Button
type='button'
variant='ghost'
onClick={() => copyToClipboard(callbackUrl)}
className='-translate-y-1/2 absolute top-1/2 right-[4px] h-[28px] w-[28px] rounded-[4px] text-[var(--text-muted)] hover:text-[var(--text-primary)]'
className='h-[22px] w-[22px] rounded-[4px] p-0 text-[var(--text-muted)] hover:text-[var(--text-primary)]'
>
{copied ? (
<Check className='h-[14px] w-[14px]' />
<Check className='h-[13px] w-[13px]' />
) : (
<Copy className='h-[14px] w-[14px]' />
<Clipboard className='h-[13px] w-[13px]' />
)}
<span className='sr-only'>Copy callback URL</span>
</Button>
</div>
<div className='flex h-9 items-center rounded-[6px] border bg-[var(--surface-1)] px-[10px]'>
<code className='flex-1 truncate font-mono text-[13px] text-[var(--text-primary)]'>
{callbackUrl}
</code>
</div>
<p className='text-[13px] text-[var(--text-muted)]'>
Configure this in your identity provider
</p>

View File

@@ -93,6 +93,7 @@ export const env = createEnv({
EXA_API_KEY: z.string().min(1).optional(), // Exa AI API key for enhanced online search
BLACKLISTED_PROVIDERS: z.string().optional(), // Comma-separated provider IDs to hide (e.g., "openai,anthropic")
BLACKLISTED_MODELS: z.string().optional(), // Comma-separated model names/prefixes to hide (e.g., "gpt-4,claude-*")
ALLOWED_MCP_DOMAINS: z.string().optional(), // Comma-separated domains for MCP servers (e.g., "internal.company.com,mcp.example.org"). Empty = all allowed.
// Azure Configuration - Shared credentials with feature-specific models
AZURE_OPENAI_ENDPOINT: z.string().url().optional(), // Shared Azure OpenAI service endpoint

View File

@@ -123,6 +123,35 @@ export const isReactGrabEnabled = isDev && isTruthy(env.REACT_GRAB_ENABLED)
*/
export const isReactScanEnabled = isDev && isTruthy(env.REACT_SCAN_ENABLED)
/**
* Normalizes a domain entry from the ALLOWED_MCP_DOMAINS env var.
* Accepts bare hostnames (e.g., "mcp.company.com") or full URLs (e.g., "https://mcp.company.com").
* Extracts the hostname in either case.
*/
function normalizeDomainEntry(entry: string): string {
const trimmed = entry.trim().toLowerCase()
if (!trimmed) return ''
if (trimmed.includes('://')) {
try {
return new URL(trimmed).hostname
} catch {
return trimmed
}
}
return trimmed
}
/**
* Get allowed MCP server domains from the ALLOWED_MCP_DOMAINS env var.
* Returns null if not set (all domains allowed), or parsed array of lowercase hostnames.
* Accepts both bare hostnames and full URLs in the env var value.
*/
export function getAllowedMcpDomainsFromEnv(): string[] | null {
if (!env.ALLOWED_MCP_DOMAINS) return null
const parsed = env.ALLOWED_MCP_DOMAINS.split(',').map(normalizeDomainEntry).filter(Boolean)
return parsed.length > 0 ? parsed : null
}
/**
* Get cost multiplier based on environment
*/

View File

@@ -0,0 +1,163 @@
/**
* @vitest-environment node
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'
const mockGetAllowedMcpDomainsFromEnv = vi.fn<() => string[] | null>()
const mockGetBaseUrl = vi.fn<() => string>()
vi.doMock('@/lib/core/config/feature-flags', () => ({
getAllowedMcpDomainsFromEnv: mockGetAllowedMcpDomainsFromEnv,
}))
vi.doMock('@/lib/core/utils/urls', () => ({
getBaseUrl: mockGetBaseUrl,
}))
const { McpDomainNotAllowedError, isMcpDomainAllowed, validateMcpDomain } = await import(
'./domain-check'
)
describe('McpDomainNotAllowedError', () => {
it.concurrent('creates error with correct name and message', () => {
const error = new McpDomainNotAllowedError('evil.com')
expect(error).toBeInstanceOf(Error)
expect(error).toBeInstanceOf(McpDomainNotAllowedError)
expect(error.name).toBe('McpDomainNotAllowedError')
expect(error.message).toContain('evil.com')
})
})
describe('isMcpDomainAllowed', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('when no allowlist is configured', () => {
beforeEach(() => {
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(null)
})
it('allows any URL', () => {
expect(isMcpDomainAllowed('https://any-server.com/mcp')).toBe(true)
})
it('allows undefined URL', () => {
expect(isMcpDomainAllowed(undefined)).toBe(true)
})
it('allows empty string URL', () => {
expect(isMcpDomainAllowed('')).toBe(true)
})
})
describe('when allowlist is configured', () => {
beforeEach(() => {
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['allowed.com', 'internal.company.com'])
mockGetBaseUrl.mockReturnValue('https://platform.example.com')
})
it('allows URLs on the allowlist', () => {
expect(isMcpDomainAllowed('https://allowed.com/mcp')).toBe(true)
expect(isMcpDomainAllowed('https://internal.company.com/tools')).toBe(true)
})
it('rejects URLs not on the allowlist', () => {
expect(isMcpDomainAllowed('https://evil.com/mcp')).toBe(false)
})
it('rejects undefined URL (fail-closed)', () => {
expect(isMcpDomainAllowed(undefined)).toBe(false)
})
it('rejects empty string URL (fail-closed)', () => {
expect(isMcpDomainAllowed('')).toBe(false)
})
it('rejects malformed URLs', () => {
expect(isMcpDomainAllowed('not-a-url')).toBe(false)
})
it('matches case-insensitively', () => {
expect(isMcpDomainAllowed('https://ALLOWED.COM/mcp')).toBe(true)
})
it('always allows the platform hostname', () => {
expect(isMcpDomainAllowed('https://platform.example.com/mcp')).toBe(true)
})
it('allows platform hostname even when not in the allowlist', () => {
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['other.com'])
expect(isMcpDomainAllowed('https://platform.example.com/mcp')).toBe(true)
})
})
describe('when getBaseUrl is not configured', () => {
beforeEach(() => {
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['allowed.com'])
mockGetBaseUrl.mockImplementation(() => {
throw new Error('Not configured')
})
})
it('still allows URLs on the allowlist', () => {
expect(isMcpDomainAllowed('https://allowed.com/mcp')).toBe(true)
})
it('still rejects URLs not on the allowlist', () => {
expect(isMcpDomainAllowed('https://evil.com/mcp')).toBe(false)
})
})
})
describe('validateMcpDomain', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('when no allowlist is configured', () => {
beforeEach(() => {
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(null)
})
it('does not throw for any URL', () => {
expect(() => validateMcpDomain('https://any-server.com/mcp')).not.toThrow()
})
it('does not throw for undefined URL', () => {
expect(() => validateMcpDomain(undefined)).not.toThrow()
})
})
describe('when allowlist is configured', () => {
beforeEach(() => {
mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['allowed.com'])
mockGetBaseUrl.mockReturnValue('https://platform.example.com')
})
it('does not throw for allowed URLs', () => {
expect(() => validateMcpDomain('https://allowed.com/mcp')).not.toThrow()
})
it('throws McpDomainNotAllowedError for disallowed URLs', () => {
expect(() => validateMcpDomain('https://evil.com/mcp')).toThrow(McpDomainNotAllowedError)
})
it('throws for undefined URL (fail-closed)', () => {
expect(() => validateMcpDomain(undefined)).toThrow(McpDomainNotAllowedError)
})
it('throws for malformed URLs', () => {
expect(() => validateMcpDomain('not-a-url')).toThrow(McpDomainNotAllowedError)
})
it('includes the rejected domain in the error message', () => {
expect(() => validateMcpDomain('https://evil.com/mcp')).toThrow(/evil\.com/)
})
it('does not throw for platform hostname', () => {
expect(() => validateMcpDomain('https://platform.example.com/mcp')).not.toThrow()
})
})
})

View File

@@ -0,0 +1,69 @@
import { getAllowedMcpDomainsFromEnv } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls'
export class McpDomainNotAllowedError extends Error {
constructor(domain: string) {
super(`MCP server domain "${domain}" is not allowed by the server's ALLOWED_MCP_DOMAINS policy`)
this.name = 'McpDomainNotAllowedError'
}
}
let cachedPlatformHostname: string | null = null
/**
* Returns the platform's own hostname (from getBaseUrl), lazy-cached.
* Always lowercase. Returns null if the base URL is not configured or invalid.
*/
function getPlatformHostname(): string | null {
if (cachedPlatformHostname !== null) return cachedPlatformHostname
try {
cachedPlatformHostname = new URL(getBaseUrl()).hostname.toLowerCase()
} catch {
return null
}
return cachedPlatformHostname
}
/**
* Core domain check. Returns null if the URL is allowed, or the hostname/url
* string to use in the rejection error.
*/
function checkMcpDomain(url: string): string | null {
const allowedDomains = getAllowedMcpDomainsFromEnv()
if (allowedDomains === null) return null
try {
const hostname = new URL(url).hostname.toLowerCase()
if (hostname === getPlatformHostname()) return null
return allowedDomains.includes(hostname) ? null : hostname
} catch {
return url
}
}
/**
* Returns true if the URL's domain is allowed (or no restriction is configured).
* The platform's own hostname (from getBaseUrl) is always allowed.
*/
export function isMcpDomainAllowed(url: string | undefined): boolean {
if (!url) {
return getAllowedMcpDomainsFromEnv() === null
}
return checkMcpDomain(url) === null
}
/**
* Throws McpDomainNotAllowedError if the URL's domain is not in the allowlist.
* The platform's own hostname (from getBaseUrl) is always allowed.
*/
export function validateMcpDomain(url: string | undefined): void {
if (!url) {
if (getAllowedMcpDomainsFromEnv() !== null) {
throw new McpDomainNotAllowedError('(empty)')
}
return
}
const rejected = checkMcpDomain(url)
if (rejected !== null) {
throw new McpDomainNotAllowedError(rejected)
}
}

View File

@@ -10,6 +10,7 @@ import { isTest } from '@/lib/core/config/feature-flags'
import { generateRequestId } from '@/lib/core/utils/request'
import { McpClient } from '@/lib/mcp/client'
import { mcpConnectionManager } from '@/lib/mcp/connection-manager'
import { isMcpDomainAllowed } from '@/lib/mcp/domain-check'
import { resolveMcpConfigEnvVars } from '@/lib/mcp/resolve-config'
import {
createMcpCacheAdapter,
@@ -93,6 +94,10 @@ class McpService {
return null
}
if (!isMcpDomainAllowed(server.url || undefined)) {
return null
}
return {
id: server.id,
name: server.name,
@@ -123,19 +128,21 @@ class McpService {
.from(mcpServers)
.where(and(...whereConditions))
return servers.map((server) => ({
id: server.id,
name: server.name,
description: server.description || undefined,
transport: server.transport as McpTransport,
url: server.url || undefined,
headers: (server.headers as Record<string, string>) || {},
timeout: server.timeout || 30000,
retries: server.retries || 3,
enabled: server.enabled,
createdAt: server.createdAt.toISOString(),
updatedAt: server.updatedAt.toISOString(),
}))
return servers
.map((server) => ({
id: server.id,
name: server.name,
description: server.description || undefined,
transport: server.transport as McpTransport,
url: server.url || undefined,
headers: (server.headers as Record<string, string>) || {},
timeout: server.timeout || 30000,
retries: server.retries || 3,
enabled: server.enabled,
createdAt: server.createdAt.toISOString(),
updatedAt: server.updatedAt.toISOString(),
}))
.filter((config) => isMcpDomainAllowed(config.url))
}
/**