mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
feat(blacklist): added ability to blacklist models & providers (#2709)
* feat(blacklist): added ability to blacklist models & providers * ack PR comments
This commit is contained in:
@@ -2,6 +2,7 @@ import { createLogger } from '@sim/logger'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import type { ModelsObject } from '@/providers/ollama/types'
|
||||
import { filterBlacklistedModels, isProviderBlacklisted } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('OllamaModelsAPI')
|
||||
const OLLAMA_HOST = env.OLLAMA_URL || 'http://localhost:11434'
|
||||
@@ -9,7 +10,12 @@ const OLLAMA_HOST = env.OLLAMA_URL || 'http://localhost:11434'
|
||||
/**
|
||||
* Get available Ollama models
|
||||
*/
|
||||
export async function GET(request: NextRequest) {
|
||||
export async function GET(_request: NextRequest) {
|
||||
if (isProviderBlacklisted('ollama')) {
|
||||
logger.info('Ollama provider is blacklisted, returning empty models')
|
||||
return NextResponse.json({ models: [] })
|
||||
}
|
||||
|
||||
try {
|
||||
logger.info('Fetching Ollama models', {
|
||||
host: OLLAMA_HOST,
|
||||
@@ -31,10 +37,12 @@ export async function GET(request: NextRequest) {
|
||||
}
|
||||
|
||||
const data = (await response.json()) as ModelsObject
|
||||
const models = data.models.map((model) => model.name)
|
||||
const allModels = data.models.map((model) => model.name)
|
||||
const models = filterBlacklistedModels(allModels)
|
||||
|
||||
logger.info('Successfully fetched Ollama models', {
|
||||
count: models.length,
|
||||
filtered: allModels.length - models.length,
|
||||
models,
|
||||
})
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { filterBlacklistedModels } from '@/providers/utils'
|
||||
import { filterBlacklistedModels, isProviderBlacklisted } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('OpenRouterModelsAPI')
|
||||
|
||||
@@ -30,6 +30,11 @@ export interface OpenRouterModelInfo {
|
||||
}
|
||||
|
||||
export async function GET(_request: NextRequest) {
|
||||
if (isProviderBlacklisted('openrouter')) {
|
||||
logger.info('OpenRouter provider is blacklisted, returning empty models')
|
||||
return NextResponse.json({ models: [], modelInfo: {} })
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch('https://openrouter.ai/api/v1/models', {
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { filterBlacklistedModels, isProviderBlacklisted } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('VLLMModelsAPI')
|
||||
|
||||
/**
|
||||
* Get available vLLM models
|
||||
*/
|
||||
export async function GET(request: NextRequest) {
|
||||
export async function GET(_request: NextRequest) {
|
||||
if (isProviderBlacklisted('vllm')) {
|
||||
logger.info('vLLM provider is blacklisted, returning empty models')
|
||||
return NextResponse.json({ models: [] })
|
||||
}
|
||||
|
||||
const baseUrl = (env.VLLM_BASE_URL || '').replace(/\/$/, '')
|
||||
|
||||
if (!baseUrl) {
|
||||
@@ -42,10 +48,12 @@ export async function GET(request: NextRequest) {
|
||||
}
|
||||
|
||||
const data = (await response.json()) as { data: Array<{ id: string }> }
|
||||
const models = data.data.map((model) => `vllm/${model.id}`)
|
||||
const allModels = data.data.map((model) => `vllm/${model.id}`)
|
||||
const models = filterBlacklistedModels(allModels)
|
||||
|
||||
logger.info('Successfully fetched vLLM models', {
|
||||
count: models.length,
|
||||
filtered: allModels.length - models.length,
|
||||
models,
|
||||
})
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import type { BlockConfig } from '@/blocks/types'
|
||||
import { AuthMode } from '@/blocks/types'
|
||||
import {
|
||||
getAllModelProviders,
|
||||
getBaseModelProviders,
|
||||
getHostedModels,
|
||||
getMaxTemperature,
|
||||
getProviderIcon,
|
||||
@@ -417,7 +417,7 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
|
||||
condition: () => ({
|
||||
field: 'model',
|
||||
value: (() => {
|
||||
const allModels = Object.keys(getAllModelProviders())
|
||||
const allModels = Object.keys(getBaseModelProviders())
|
||||
return allModels.filter(
|
||||
(model) => supportsTemperature(model) && getMaxTemperature(model) === 1
|
||||
)
|
||||
@@ -434,7 +434,7 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
|
||||
condition: () => ({
|
||||
field: 'model',
|
||||
value: (() => {
|
||||
const allModels = Object.keys(getAllModelProviders())
|
||||
const allModels = Object.keys(getBaseModelProviders())
|
||||
return allModels.filter(
|
||||
(model) => supportsTemperature(model) && getMaxTemperature(model) === 2
|
||||
)
|
||||
@@ -555,7 +555,7 @@ Example 3 (Array Input):
|
||||
if (!model) {
|
||||
throw new Error('No model selected')
|
||||
}
|
||||
const tool = getAllModelProviders()[model]
|
||||
const tool = getBaseModelProviders()[model]
|
||||
if (!tool) {
|
||||
throw new Error(`Invalid model selected: ${model}`)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import type { BlockConfig, ParamType } from '@/blocks/types'
|
||||
import type { ProviderId } from '@/providers/types'
|
||||
import {
|
||||
getAllModelProviders,
|
||||
getBaseModelProviders,
|
||||
getHostedModels,
|
||||
getProviderIcon,
|
||||
providers,
|
||||
@@ -357,7 +357,7 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
|
||||
if (!model) {
|
||||
throw new Error('No model selected')
|
||||
}
|
||||
const tool = getAllModelProviders()[model as ProviderId]
|
||||
const tool = getBaseModelProviders()[model as ProviderId]
|
||||
if (!tool) {
|
||||
throw new Error(`Invalid model selected: ${model}`)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import { AuthMode, type BlockConfig } from '@/blocks/types'
|
||||
import type { ProviderId } from '@/providers/types'
|
||||
import {
|
||||
getAllModelProviders,
|
||||
getBaseModelProviders,
|
||||
getHostedModels,
|
||||
getProviderIcon,
|
||||
providers,
|
||||
@@ -324,7 +324,7 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
|
||||
if (!model) {
|
||||
throw new Error('No model selected')
|
||||
}
|
||||
const tool = getAllModelProviders()[model as ProviderId]
|
||||
const tool = getBaseModelProviders()[model as ProviderId]
|
||||
if (!tool) {
|
||||
throw new Error(`Invalid model selected: ${model}`)
|
||||
}
|
||||
@@ -508,7 +508,7 @@ export const RouterV2Block: BlockConfig<RouterV2Response> = {
|
||||
if (!model) {
|
||||
throw new Error('No model selected')
|
||||
}
|
||||
const tool = getAllModelProviders()[model as ProviderId]
|
||||
const tool = getBaseModelProviders()[model as ProviderId]
|
||||
if (!tool) {
|
||||
throw new Error(`Invalid model selected: ${model}`)
|
||||
}
|
||||
|
||||
@@ -87,7 +87,8 @@ export const env = createEnv({
|
||||
ELEVENLABS_API_KEY: z.string().min(1).optional(), // ElevenLabs API key for text-to-speech in deployed chat
|
||||
SERPER_API_KEY: z.string().min(1).optional(), // Serper API key for online search
|
||||
EXA_API_KEY: z.string().min(1).optional(), // Exa AI API key for enhanced online search
|
||||
DEEPSEEK_MODELS_ENABLED: z.boolean().optional().default(false), // Enable Deepseek models in UI (defaults to false for compliance)
|
||||
BLACKLISTED_PROVIDERS: z.string().optional(), // Comma-separated provider IDs to hide (e.g., "openai,anthropic")
|
||||
BLACKLISTED_MODELS: z.string().optional(), // Comma-separated model names/prefixes to hide (e.g., "gpt-4,claude-*")
|
||||
|
||||
// Azure Configuration - Shared credentials with feature-specific models
|
||||
AZURE_OPENAI_ENDPOINT: z.string().url().optional(), // Shared Azure OpenAI service endpoint
|
||||
|
||||
@@ -3,6 +3,7 @@ import * as environmentModule from '@/lib/core/config/feature-flags'
|
||||
import {
|
||||
calculateCost,
|
||||
extractAndParseJSON,
|
||||
filterBlacklistedModels,
|
||||
formatCost,
|
||||
generateStructuredOutputInstructions,
|
||||
getAllModelProviders,
|
||||
@@ -17,6 +18,7 @@ import {
|
||||
getProviderConfigFromModel,
|
||||
getProviderFromModel,
|
||||
getProviderModels,
|
||||
isProviderBlacklisted,
|
||||
MODELS_TEMP_RANGE_0_1,
|
||||
MODELS_TEMP_RANGE_0_2,
|
||||
MODELS_WITH_REASONING_EFFORT,
|
||||
@@ -976,3 +978,46 @@ describe('Tool Management', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider/Model Blacklist', () => {
|
||||
describe('isProviderBlacklisted', () => {
|
||||
it.concurrent('should return false when no providers are blacklisted', () => {
|
||||
expect(isProviderBlacklisted('openai')).toBe(false)
|
||||
expect(isProviderBlacklisted('anthropic')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('filterBlacklistedModels', () => {
|
||||
it.concurrent('should return all models when no blacklist is set', () => {
|
||||
const models = ['gpt-4o', 'claude-sonnet-4-5', 'gemini-2.5-pro']
|
||||
const result = filterBlacklistedModels(models)
|
||||
expect(result).toEqual(models)
|
||||
})
|
||||
|
||||
it.concurrent('should return empty array for empty input', () => {
|
||||
const result = filterBlacklistedModels([])
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getBaseModelProviders blacklist filtering', () => {
|
||||
it.concurrent('should return providers when no blacklist is set', () => {
|
||||
const providers = getBaseModelProviders()
|
||||
expect(Object.keys(providers).length).toBeGreaterThan(0)
|
||||
expect(providers['gpt-4o']).toBe('openai')
|
||||
expect(providers['claude-sonnet-4-5']).toBe('anthropic')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getProviderFromModel execution-time enforcement', () => {
|
||||
it.concurrent('should return provider for non-blacklisted models', () => {
|
||||
expect(getProviderFromModel('gpt-4o')).toBe('openai')
|
||||
expect(getProviderFromModel('claude-sonnet-4-5')).toBe('anthropic')
|
||||
})
|
||||
|
||||
it.concurrent('should be case insensitive', () => {
|
||||
expect(getProviderFromModel('GPT-4O')).toBe('openai')
|
||||
expect(getProviderFromModel('CLAUDE-SONNET-4-5')).toBe('anthropic')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { createLogger, type Logger } from '@sim/logger'
|
||||
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
|
||||
import type { CompletionUsage } from 'openai/resources/completions'
|
||||
import { getEnv, isTruthy } from '@/lib/core/config/env'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import { isCustomTool } from '@/executor/constants'
|
||||
import {
|
||||
@@ -131,6 +131,9 @@ function filterBlacklistedModelsFromProviderMap(
|
||||
): Record<string, ProviderId> {
|
||||
const filtered: Record<string, ProviderId> = {}
|
||||
for (const [model, providerId] of Object.entries(providerMap)) {
|
||||
if (isProviderBlacklisted(providerId)) {
|
||||
continue
|
||||
}
|
||||
if (!isModelBlacklisted(model)) {
|
||||
filtered[model] = providerId
|
||||
}
|
||||
@@ -152,22 +155,39 @@ export function getAllModelProviders(): Record<string, ProviderId> {
|
||||
|
||||
export function getProviderFromModel(model: string): ProviderId {
|
||||
const normalizedModel = model.toLowerCase()
|
||||
if (normalizedModel in getAllModelProviders()) {
|
||||
return getAllModelProviders()[normalizedModel]
|
||||
}
|
||||
|
||||
for (const [providerId, config] of Object.entries(providers)) {
|
||||
if (config.modelPatterns) {
|
||||
for (const pattern of config.modelPatterns) {
|
||||
if (pattern.test(normalizedModel)) {
|
||||
return providerId as ProviderId
|
||||
let providerId: ProviderId | null = null
|
||||
|
||||
if (normalizedModel in getAllModelProviders()) {
|
||||
providerId = getAllModelProviders()[normalizedModel]
|
||||
} else {
|
||||
for (const [id, config] of Object.entries(providers)) {
|
||||
if (config.modelPatterns) {
|
||||
for (const pattern of config.modelPatterns) {
|
||||
if (pattern.test(normalizedModel)) {
|
||||
providerId = id as ProviderId
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if (providerId) break
|
||||
}
|
||||
}
|
||||
|
||||
logger.warn(`No provider found for model: ${model}, defaulting to ollama`)
|
||||
return 'ollama'
|
||||
if (!providerId) {
|
||||
logger.warn(`No provider found for model: ${model}, defaulting to ollama`)
|
||||
providerId = 'ollama'
|
||||
}
|
||||
|
||||
if (isProviderBlacklisted(providerId)) {
|
||||
throw new Error(`Provider "${providerId}" is not available`)
|
||||
}
|
||||
|
||||
if (isModelBlacklisted(normalizedModel)) {
|
||||
throw new Error(`Model "${model}" is not available`)
|
||||
}
|
||||
|
||||
return providerId
|
||||
}
|
||||
|
||||
export function getProvider(id: string): ProviderMetadata | undefined {
|
||||
@@ -192,35 +212,42 @@ export function getProviderModels(providerId: ProviderId): string[] {
|
||||
return getProviderModelsFromDefinitions(providerId)
|
||||
}
|
||||
|
||||
interface ModelBlacklist {
|
||||
models: string[]
|
||||
prefixes: string[]
|
||||
envOverride?: string
|
||||
function getBlacklistedProviders(): string[] {
|
||||
if (!env.BLACKLISTED_PROVIDERS) return []
|
||||
return env.BLACKLISTED_PROVIDERS.split(',').map((p) => p.trim().toLowerCase())
|
||||
}
|
||||
|
||||
const MODEL_BLACKLISTS: ModelBlacklist[] = [
|
||||
{
|
||||
models: ['deepseek-chat', 'deepseek-v3', 'deepseek-r1'],
|
||||
prefixes: ['openrouter/deepseek', 'openrouter/tngtech'],
|
||||
envOverride: 'DEEPSEEK_MODELS_ENABLED',
|
||||
},
|
||||
]
|
||||
export function isProviderBlacklisted(providerId: string): boolean {
|
||||
const blacklist = getBlacklistedProviders()
|
||||
return blacklist.includes(providerId.toLowerCase())
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the list of blacklisted models from env var.
|
||||
* BLACKLISTED_MODELS supports:
|
||||
* - Exact model names: "gpt-4,claude-3-opus"
|
||||
* - Prefix patterns with *: "claude-*,gpt-4-*" (matches models starting with that prefix)
|
||||
*/
|
||||
function getBlacklistedModels(): { models: string[]; prefixes: string[] } {
|
||||
if (!env.BLACKLISTED_MODELS) return { models: [], prefixes: [] }
|
||||
|
||||
const entries = env.BLACKLISTED_MODELS.split(',').map((m) => m.trim().toLowerCase())
|
||||
const models = entries.filter((e) => !e.endsWith('*'))
|
||||
const prefixes = entries.filter((e) => e.endsWith('*')).map((e) => e.slice(0, -1))
|
||||
|
||||
return { models, prefixes }
|
||||
}
|
||||
|
||||
function isModelBlacklisted(model: string): boolean {
|
||||
const lowerModel = model.toLowerCase()
|
||||
const blacklist = getBlacklistedModels()
|
||||
|
||||
for (const blacklist of MODEL_BLACKLISTS) {
|
||||
if (blacklist.envOverride && isTruthy(getEnv(blacklist.envOverride))) {
|
||||
continue
|
||||
}
|
||||
if (blacklist.models.includes(lowerModel)) {
|
||||
return true
|
||||
}
|
||||
|
||||
if (blacklist.models.includes(lowerModel)) {
|
||||
return true
|
||||
}
|
||||
|
||||
if (blacklist.prefixes.some((prefix) => lowerModel.startsWith(prefix))) {
|
||||
return true
|
||||
}
|
||||
if (blacklist.prefixes.some((prefix) => lowerModel.startsWith(prefix))) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -117,6 +117,10 @@ app:
|
||||
ALLOWED_LOGIN_EMAILS: "" # Comma-separated list of allowed email addresses for login
|
||||
ALLOWED_LOGIN_DOMAINS: "" # Comma-separated list of allowed email domains for login
|
||||
|
||||
# LLM Provider/Model Restrictions (leave empty if not restricting)
|
||||
BLACKLISTED_PROVIDERS: "" # Comma-separated provider IDs to hide from UI (e.g., "openai,anthropic,google")
|
||||
BLACKLISTED_MODELS: "" # Comma-separated model names/prefixes to hide (e.g., "gpt-4,claude-*")
|
||||
|
||||
# SSO Configuration (Enterprise Single Sign-On)
|
||||
# Set to "true" AFTER running the SSO registration script
|
||||
SSO_ENABLED: "" # Enable SSO authentication ("true" to enable)
|
||||
|
||||
Reference in New Issue
Block a user