feat(blacklist): added ability to blacklist models & providers (#2709)

* feat(blacklist): added ability to blacklist models & providers

* ack PR comments
This commit is contained in:
Waleed
2026-01-07 10:41:57 -08:00
committed by GitHub
parent 3ecf7a15eb
commit 261becd129
10 changed files with 146 additions and 48 deletions

View File

@@ -2,6 +2,7 @@ import { createLogger } from '@sim/logger'
import { type NextRequest, NextResponse } from 'next/server'
import { env } from '@/lib/core/config/env'
import type { ModelsObject } from '@/providers/ollama/types'
import { filterBlacklistedModels, isProviderBlacklisted } from '@/providers/utils'
const logger = createLogger('OllamaModelsAPI')
const OLLAMA_HOST = env.OLLAMA_URL || 'http://localhost:11434'
@@ -9,7 +10,12 @@ const OLLAMA_HOST = env.OLLAMA_URL || 'http://localhost:11434'
/**
* Get available Ollama models
*/
export async function GET(request: NextRequest) {
export async function GET(_request: NextRequest) {
if (isProviderBlacklisted('ollama')) {
logger.info('Ollama provider is blacklisted, returning empty models')
return NextResponse.json({ models: [] })
}
try {
logger.info('Fetching Ollama models', {
host: OLLAMA_HOST,
@@ -31,10 +37,12 @@ export async function GET(request: NextRequest) {
}
const data = (await response.json()) as ModelsObject
const models = data.models.map((model) => model.name)
const allModels = data.models.map((model) => model.name)
const models = filterBlacklistedModels(allModels)
logger.info('Successfully fetched Ollama models', {
count: models.length,
filtered: allModels.length - models.length,
models,
})

View File

@@ -1,6 +1,6 @@
import { createLogger } from '@sim/logger'
import { type NextRequest, NextResponse } from 'next/server'
import { filterBlacklistedModels } from '@/providers/utils'
import { filterBlacklistedModels, isProviderBlacklisted } from '@/providers/utils'
const logger = createLogger('OpenRouterModelsAPI')
@@ -30,6 +30,11 @@ export interface OpenRouterModelInfo {
}
export async function GET(_request: NextRequest) {
if (isProviderBlacklisted('openrouter')) {
logger.info('OpenRouter provider is blacklisted, returning empty models')
return NextResponse.json({ models: [], modelInfo: {} })
}
try {
const response = await fetch('https://openrouter.ai/api/v1/models', {
headers: { 'Content-Type': 'application/json' },

View File

@@ -1,13 +1,19 @@
import { createLogger } from '@sim/logger'
import { type NextRequest, NextResponse } from 'next/server'
import { env } from '@/lib/core/config/env'
import { filterBlacklistedModels, isProviderBlacklisted } from '@/providers/utils'
const logger = createLogger('VLLMModelsAPI')
/**
* Get available vLLM models
*/
export async function GET(request: NextRequest) {
export async function GET(_request: NextRequest) {
if (isProviderBlacklisted('vllm')) {
logger.info('vLLM provider is blacklisted, returning empty models')
return NextResponse.json({ models: [] })
}
const baseUrl = (env.VLLM_BASE_URL || '').replace(/\/$/, '')
if (!baseUrl) {
@@ -42,10 +48,12 @@ export async function GET(request: NextRequest) {
}
const data = (await response.json()) as { data: Array<{ id: string }> }
const models = data.data.map((model) => `vllm/${model.id}`)
const allModels = data.data.map((model) => `vllm/${model.id}`)
const models = filterBlacklistedModels(allModels)
logger.info('Successfully fetched vLLM models', {
count: models.length,
filtered: allModels.length - models.length,
models,
})

View File

@@ -4,7 +4,7 @@ import { isHosted } from '@/lib/core/config/feature-flags'
import type { BlockConfig } from '@/blocks/types'
import { AuthMode } from '@/blocks/types'
import {
getAllModelProviders,
getBaseModelProviders,
getHostedModels,
getMaxTemperature,
getProviderIcon,
@@ -417,7 +417,7 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
condition: () => ({
field: 'model',
value: (() => {
const allModels = Object.keys(getAllModelProviders())
const allModels = Object.keys(getBaseModelProviders())
return allModels.filter(
(model) => supportsTemperature(model) && getMaxTemperature(model) === 1
)
@@ -434,7 +434,7 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
condition: () => ({
field: 'model',
value: (() => {
const allModels = Object.keys(getAllModelProviders())
const allModels = Object.keys(getBaseModelProviders())
return allModels.filter(
(model) => supportsTemperature(model) && getMaxTemperature(model) === 2
)
@@ -555,7 +555,7 @@ Example 3 (Array Input):
if (!model) {
throw new Error('No model selected')
}
const tool = getAllModelProviders()[model]
const tool = getBaseModelProviders()[model]
if (!tool) {
throw new Error(`Invalid model selected: ${model}`)
}

View File

@@ -4,7 +4,7 @@ import { isHosted } from '@/lib/core/config/feature-flags'
import type { BlockConfig, ParamType } from '@/blocks/types'
import type { ProviderId } from '@/providers/types'
import {
getAllModelProviders,
getBaseModelProviders,
getHostedModels,
getProviderIcon,
providers,
@@ -357,7 +357,7 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
if (!model) {
throw new Error('No model selected')
}
const tool = getAllModelProviders()[model as ProviderId]
const tool = getBaseModelProviders()[model as ProviderId]
if (!tool) {
throw new Error(`Invalid model selected: ${model}`)
}

View File

@@ -3,7 +3,7 @@ import { isHosted } from '@/lib/core/config/feature-flags'
import { AuthMode, type BlockConfig } from '@/blocks/types'
import type { ProviderId } from '@/providers/types'
import {
getAllModelProviders,
getBaseModelProviders,
getHostedModels,
getProviderIcon,
providers,
@@ -324,7 +324,7 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
if (!model) {
throw new Error('No model selected')
}
const tool = getAllModelProviders()[model as ProviderId]
const tool = getBaseModelProviders()[model as ProviderId]
if (!tool) {
throw new Error(`Invalid model selected: ${model}`)
}
@@ -508,7 +508,7 @@ export const RouterV2Block: BlockConfig<RouterV2Response> = {
if (!model) {
throw new Error('No model selected')
}
const tool = getAllModelProviders()[model as ProviderId]
const tool = getBaseModelProviders()[model as ProviderId]
if (!tool) {
throw new Error(`Invalid model selected: ${model}`)
}

View File

@@ -87,7 +87,8 @@ export const env = createEnv({
ELEVENLABS_API_KEY: z.string().min(1).optional(), // ElevenLabs API key for text-to-speech in deployed chat
SERPER_API_KEY: z.string().min(1).optional(), // Serper API key for online search
EXA_API_KEY: z.string().min(1).optional(), // Exa AI API key for enhanced online search
DEEPSEEK_MODELS_ENABLED: z.boolean().optional().default(false), // Enable Deepseek models in UI (defaults to false for compliance)
BLACKLISTED_PROVIDERS: z.string().optional(), // Comma-separated provider IDs to hide (e.g., "openai,anthropic")
BLACKLISTED_MODELS: z.string().optional(), // Comma-separated model names/prefixes to hide (e.g., "gpt-4,claude-*")
// Azure Configuration - Shared credentials with feature-specific models
AZURE_OPENAI_ENDPOINT: z.string().url().optional(), // Shared Azure OpenAI service endpoint

View File

@@ -3,6 +3,7 @@ import * as environmentModule from '@/lib/core/config/feature-flags'
import {
calculateCost,
extractAndParseJSON,
filterBlacklistedModels,
formatCost,
generateStructuredOutputInstructions,
getAllModelProviders,
@@ -17,6 +18,7 @@ import {
getProviderConfigFromModel,
getProviderFromModel,
getProviderModels,
isProviderBlacklisted,
MODELS_TEMP_RANGE_0_1,
MODELS_TEMP_RANGE_0_2,
MODELS_WITH_REASONING_EFFORT,
@@ -976,3 +978,46 @@ describe('Tool Management', () => {
})
})
})
describe('Provider/Model Blacklist', () => {
describe('isProviderBlacklisted', () => {
it.concurrent('should return false when no providers are blacklisted', () => {
expect(isProviderBlacklisted('openai')).toBe(false)
expect(isProviderBlacklisted('anthropic')).toBe(false)
})
})
describe('filterBlacklistedModels', () => {
it.concurrent('should return all models when no blacklist is set', () => {
const models = ['gpt-4o', 'claude-sonnet-4-5', 'gemini-2.5-pro']
const result = filterBlacklistedModels(models)
expect(result).toEqual(models)
})
it.concurrent('should return empty array for empty input', () => {
const result = filterBlacklistedModels([])
expect(result).toEqual([])
})
})
describe('getBaseModelProviders blacklist filtering', () => {
it.concurrent('should return providers when no blacklist is set', () => {
const providers = getBaseModelProviders()
expect(Object.keys(providers).length).toBeGreaterThan(0)
expect(providers['gpt-4o']).toBe('openai')
expect(providers['claude-sonnet-4-5']).toBe('anthropic')
})
})
describe('getProviderFromModel execution-time enforcement', () => {
it.concurrent('should return provider for non-blacklisted models', () => {
expect(getProviderFromModel('gpt-4o')).toBe('openai')
expect(getProviderFromModel('claude-sonnet-4-5')).toBe('anthropic')
})
it.concurrent('should be case insensitive', () => {
expect(getProviderFromModel('GPT-4O')).toBe('openai')
expect(getProviderFromModel('CLAUDE-SONNET-4-5')).toBe('anthropic')
})
})
})

View File

@@ -1,7 +1,7 @@
import { createLogger, type Logger } from '@sim/logger'
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { getEnv, isTruthy } from '@/lib/core/config/env'
import { env } from '@/lib/core/config/env'
import { isHosted } from '@/lib/core/config/feature-flags'
import { isCustomTool } from '@/executor/constants'
import {
@@ -131,6 +131,9 @@ function filterBlacklistedModelsFromProviderMap(
): Record<string, ProviderId> {
const filtered: Record<string, ProviderId> = {}
for (const [model, providerId] of Object.entries(providerMap)) {
if (isProviderBlacklisted(providerId)) {
continue
}
if (!isModelBlacklisted(model)) {
filtered[model] = providerId
}
@@ -152,22 +155,39 @@ export function getAllModelProviders(): Record<string, ProviderId> {
export function getProviderFromModel(model: string): ProviderId {
const normalizedModel = model.toLowerCase()
if (normalizedModel in getAllModelProviders()) {
return getAllModelProviders()[normalizedModel]
}
for (const [providerId, config] of Object.entries(providers)) {
if (config.modelPatterns) {
for (const pattern of config.modelPatterns) {
if (pattern.test(normalizedModel)) {
return providerId as ProviderId
let providerId: ProviderId | null = null
if (normalizedModel in getAllModelProviders()) {
providerId = getAllModelProviders()[normalizedModel]
} else {
for (const [id, config] of Object.entries(providers)) {
if (config.modelPatterns) {
for (const pattern of config.modelPatterns) {
if (pattern.test(normalizedModel)) {
providerId = id as ProviderId
break
}
}
}
if (providerId) break
}
}
logger.warn(`No provider found for model: ${model}, defaulting to ollama`)
return 'ollama'
if (!providerId) {
logger.warn(`No provider found for model: ${model}, defaulting to ollama`)
providerId = 'ollama'
}
if (isProviderBlacklisted(providerId)) {
throw new Error(`Provider "${providerId}" is not available`)
}
if (isModelBlacklisted(normalizedModel)) {
throw new Error(`Model "${model}" is not available`)
}
return providerId
}
export function getProvider(id: string): ProviderMetadata | undefined {
@@ -192,35 +212,42 @@ export function getProviderModels(providerId: ProviderId): string[] {
return getProviderModelsFromDefinitions(providerId)
}
interface ModelBlacklist {
models: string[]
prefixes: string[]
envOverride?: string
/**
 * Parse the BLACKLISTED_PROVIDERS env var into a normalized list of provider IDs.
 *
 * The value is a comma-separated string (e.g. "openai,anthropic"); each entry is
 * trimmed and lowercased so later comparisons can be case-insensitive. Empty
 * entries produced by trailing or doubled commas are ignored.
 *
 * @returns lowercased provider IDs to hide, or an empty array when unset.
 */
function getBlacklistedProviders(): string[] {
  if (!env.BLACKLISTED_PROVIDERS) return []
  return env.BLACKLISTED_PROVIDERS.split(',')
    .map((p) => p.trim().toLowerCase())
    .filter((p) => p.length > 0)
}
const MODEL_BLACKLISTS: ModelBlacklist[] = [
{
models: ['deepseek-chat', 'deepseek-v3', 'deepseek-r1'],
prefixes: ['openrouter/deepseek', 'openrouter/tngtech'],
envOverride: 'DEEPSEEK_MODELS_ENABLED',
},
]
/**
 * Report whether a provider has been hidden via the BLACKLISTED_PROVIDERS
 * env var. The check is case-insensitive on the provider ID.
 *
 * @param providerId - provider identifier to test (e.g. "openai").
 * @returns true when the provider appears in the blacklist.
 */
export function isProviderBlacklisted(providerId: string): boolean {
  return getBlacklistedProviders().includes(providerId.toLowerCase())
}
/**
 * Get the list of blacklisted models from env var.
 * BLACKLISTED_MODELS supports:
 * - Exact model names: "gpt-4,claude-3-opus"
 * - Prefix patterns with *: "claude-*,gpt-4-*" (matches models starting with that prefix)
 *
 * Entries are trimmed and lowercased; empty entries produced by trailing or
 * doubled commas are ignored (consistent with getBlacklistedProviders).
 *
 * @returns exact model names and prefix patterns (with the trailing '*' stripped).
 */
function getBlacklistedModels(): { models: string[]; prefixes: string[] } {
  if (!env.BLACKLISTED_MODELS) return { models: [], prefixes: [] }
  const entries = env.BLACKLISTED_MODELS.split(',')
    .map((m) => m.trim().toLowerCase())
    .filter((m) => m.length > 0)
  // Entries ending in '*' are prefix patterns; everything else is an exact name.
  const models = entries.filter((e) => !e.endsWith('*'))
  const prefixes = entries.filter((e) => e.endsWith('*')).map((e) => e.slice(0, -1))
  return { models, prefixes }
}
function isModelBlacklisted(model: string): boolean {
const lowerModel = model.toLowerCase()
const blacklist = getBlacklistedModels()
for (const blacklist of MODEL_BLACKLISTS) {
if (blacklist.envOverride && isTruthy(getEnv(blacklist.envOverride))) {
continue
}
if (blacklist.models.includes(lowerModel)) {
return true
}
if (blacklist.models.includes(lowerModel)) {
return true
}
if (blacklist.prefixes.some((prefix) => lowerModel.startsWith(prefix))) {
return true
}
if (blacklist.prefixes.some((prefix) => lowerModel.startsWith(prefix))) {
return true
}
return false

View File

@@ -117,6 +117,10 @@ app:
ALLOWED_LOGIN_EMAILS: "" # Comma-separated list of allowed email addresses for login
ALLOWED_LOGIN_DOMAINS: "" # Comma-separated list of allowed email domains for login
# LLM Provider/Model Restrictions (leave empty if not restricting)
BLACKLISTED_PROVIDERS: "" # Comma-separated provider IDs to hide from UI (e.g., "openai,anthropic,google")
BLACKLISTED_MODELS: "" # Comma-separated model names/prefixes to hide (e.g., "gpt-4,claude-*")
# SSO Configuration (Enterprise Single Sign-On)
# Set to "true" AFTER running the SSO registration script
SSO_ENABLED: "" # Enable SSO authentication ("true" to enable)