// Mirror of https://github.com/simstudioai/sim.git (synced 2026-01-10 07:27:57 -05:00).
// Last change: fix(tool-input): allow multiple instances of workflow block or kb tools as agent tools; ack PR comments.
// File: 1064 lines, 33 KiB, TypeScript.
import { getEnv, isTruthy } from '@/lib/core/config/env'
import { isHosted } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger'
import { anthropicProvider } from '@/providers/anthropic'
import { azureOpenAIProvider } from '@/providers/azure-openai'
import { cerebrasProvider } from '@/providers/cerebras'
import { deepseekProvider } from '@/providers/deepseek'
import { googleProvider } from '@/providers/google'
import { groqProvider } from '@/providers/groq'
import { mistralProvider } from '@/providers/mistral'
import {
  getComputerUseModels,
  getEmbeddingModelPricing,
  getHostedModels as getHostedModelsFromDefinitions,
  getMaxTemperature as getMaxTempFromDefinitions,
  getModelPricing as getModelPricingFromDefinitions,
  getModelsWithReasoningEffort,
  getModelsWithTemperatureSupport,
  getModelsWithTempRange01,
  getModelsWithTempRange02,
  getModelsWithVerbosity,
  getProviderModels as getProviderModelsFromDefinitions,
  getProvidersWithToolUsageControl,
  getReasoningEffortValuesForModel as getReasoningEffortValuesForModelFromDefinitions,
  getVerbosityValuesForModel as getVerbosityValuesForModelFromDefinitions,
  PROVIDER_DEFINITIONS,
  supportsTemperature as supportsTemperatureFromDefinitions,
  supportsToolUsageControl as supportsToolUsageControlFromDefinitions,
  updateOllamaModels as updateOllamaModelsInDefinitions,
  updateVLLMModels,
} from '@/providers/models'
import { ollamaProvider } from '@/providers/ollama'
import { openaiProvider } from '@/providers/openai'
import { openRouterProvider } from '@/providers/openrouter'
import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types'
import { vertexProvider } from '@/providers/vertex'
import { vllmProvider } from '@/providers/vllm'
import { xAIProvider } from '@/providers/xai'
import { useCustomToolsStore } from '@/stores/custom-tools/store'
import { useProvidersStore } from '@/stores/providers/store'
|
|
|
|
// Module-scoped logger; every log line emitted by this file is tagged 'ProviderUtils'.
const logger = createLogger('ProviderUtils')
|
|
|
|
/**
 * Registry of provider configurations, keyed by provider id.
 *
 * Every entry spreads the provider implementation object and layers on top of it:
 * - `models`: the model ids returned by the definitions module for that provider
 * - `modelPatterns`: the patterns declared in `PROVIDER_DEFINITIONS`, used as a
 *   fallback when a model id is not listed explicitly
 * - `computerUseModels` (OpenAI and Anthropic only)
 */
export const providers: Record<
  ProviderId,
  ProviderConfig & {
    models: string[]
    computerUseModels?: string[]
    modelPatterns?: RegExp[]
  }
> = {
  openai: {
    ...openaiProvider,
    models: getProviderModelsFromDefinitions('openai'),
    computerUseModels: ['computer-use-preview'],
    modelPatterns: PROVIDER_DEFINITIONS.openai.modelPatterns,
  },

  anthropic: {
    ...anthropicProvider,
    models: getProviderModelsFromDefinitions('anthropic'),
    // Keep only the computer-use models that are also Anthropic models.
    computerUseModels: getComputerUseModels().filter((model) =>
      getProviderModelsFromDefinitions('anthropic').includes(model)
    ),
    modelPatterns: PROVIDER_DEFINITIONS.anthropic.modelPatterns,
  },

  google: {
    ...googleProvider,
    models: getProviderModelsFromDefinitions('google'),
    modelPatterns: PROVIDER_DEFINITIONS.google.modelPatterns,
  },

  vertex: {
    ...vertexProvider,
    models: getProviderModelsFromDefinitions('vertex'),
    modelPatterns: PROVIDER_DEFINITIONS.vertex.modelPatterns,
  },

  deepseek: {
    ...deepseekProvider,
    models: getProviderModelsFromDefinitions('deepseek'),
    modelPatterns: PROVIDER_DEFINITIONS.deepseek.modelPatterns,
  },

  xai: {
    ...xAIProvider,
    models: getProviderModelsFromDefinitions('xai'),
    modelPatterns: PROVIDER_DEFINITIONS.xai.modelPatterns,
  },

  cerebras: {
    ...cerebrasProvider,
    models: getProviderModelsFromDefinitions('cerebras'),
    modelPatterns: PROVIDER_DEFINITIONS.cerebras.modelPatterns,
  },

  groq: {
    ...groqProvider,
    models: getProviderModelsFromDefinitions('groq'),
    modelPatterns: PROVIDER_DEFINITIONS.groq.modelPatterns,
  },

  vllm: {
    ...vllmProvider,
    models: getProviderModelsFromDefinitions('vllm'),
    modelPatterns: PROVIDER_DEFINITIONS.vllm.modelPatterns,
  },

  mistral: {
    ...mistralProvider,
    models: getProviderModelsFromDefinitions('mistral'),
    modelPatterns: PROVIDER_DEFINITIONS.mistral.modelPatterns,
  },

  'azure-openai': {
    ...azureOpenAIProvider,
    models: getProviderModelsFromDefinitions('azure-openai'),
    modelPatterns: PROVIDER_DEFINITIONS['azure-openai'].modelPatterns,
  },

  openrouter: {
    ...openRouterProvider,
    models: getProviderModelsFromDefinitions('openrouter'),
    modelPatterns: PROVIDER_DEFINITIONS.openrouter.modelPatterns,
  },

  ollama: {
    ...ollamaProvider,
    models: getProviderModelsFromDefinitions('ollama'),
    modelPatterns: PROVIDER_DEFINITIONS.ollama.modelPatterns,
  },
}
|
|
|
|
// Kick off the optional async `initialize` hook of every provider at module load.
// The promises are intentionally not awaited; a rejection is logged and otherwise ignored.
for (const [id, provider] of Object.entries(providers)) {
  if (!provider.initialize) continue

  provider.initialize().catch((error) => {
    const message = error instanceof Error ? error.message : 'Unknown error'
    logger.error(`Failed to initialize ${id} provider`, {
      error: message,
    })
  })
}
|
|
|
|
/**
 * Replaces the Ollama model list in the definitions module and mirrors the
 * resulting list onto `providers.ollama.models`.
 *
 * @param models Model ids to register for Ollama
 */
export function updateOllamaProviderModels(models: string[]): void {
  updateOllamaModelsInDefinitions(models)

  // Read the list back from the definitions rather than trusting the raw input.
  const refreshedModels = getProviderModelsFromDefinitions('ollama')
  providers.ollama.models = refreshedModels
}
|
|
|
|
/**
 * Replaces the vLLM model list in the definitions module and mirrors the
 * resulting list onto `providers.vllm.models`.
 *
 * Uses the statically imported `updateVLLMModels` instead of a CommonJS
 * `require()`: the same module is already imported at the top of this file, so
 * the lazy load bought nothing and left the call untyped.
 *
 * @param models Model ids to register for vLLM
 */
export function updateVLLMProviderModels(models: string[]): void {
  updateVLLMModels(models)
  providers.vllm.models = getProviderModelsFromDefinitions('vllm')
}
|
|
|
|
/**
 * Replaces the OpenRouter model list in the definitions module and mirrors the
 * resulting list onto `providers.openrouter.models`.
 *
 * The definitions module is loaded through a dynamic import, which is why this
 * updater is async.
 *
 * @param models Model ids to register for OpenRouter
 */
export async function updateOpenRouterProviderModels(models: string[]): Promise<void> {
  const modelDefinitions = await import('@/providers/models')
  modelDefinitions.updateOpenRouterModels(models)

  providers.openrouter.models = getProviderModelsFromDefinitions('openrouter')
}
|
|
|
|
/**
 * Builds a lowercase-model-id → provider-id map for every provider except
 * ollama, vllm and openrouter, then removes blacklisted models from it.
 */
export function getBaseModelProviders(): Record<string, ProviderId> {
  const modelToProvider: Record<string, ProviderId> = {}

  for (const [providerId, config] of Object.entries(providers)) {
    const isExcluded =
      providerId === 'ollama' || providerId === 'vllm' || providerId === 'openrouter'
    if (isExcluded) continue

    for (const model of config.models) {
      modelToProvider[model.toLowerCase()] = providerId as ProviderId
    }
  }

  return filterBlacklistedModelsFromProviderMap(modelToProvider)
}
|
|
|
|
/**
 * Returns a copy of `providerMap` without the entries whose model id is blacklisted.
 *
 * @param providerMap Map of model id → provider id
 * @returns A new map; the input is not modified
 */
function filterBlacklistedModelsFromProviderMap(
  providerMap: Record<string, ProviderId>
): Record<string, ProviderId> {
  const allowed: Record<string, ProviderId> = {}

  Object.keys(providerMap).forEach((model) => {
    if (isModelBlacklisted(model)) return
    allowed[model] = providerMap[model]
  })

  return allowed
}
|
|
|
|
/**
 * Builds a lowercase-model-id → provider-id map across every registered provider.
 * When two providers list the same model id, the provider iterated later wins.
 */
export function getAllModelProviders(): Record<string, ProviderId> {
  const modelToProvider: Record<string, ProviderId> = {}

  for (const [providerId, config] of Object.entries(providers)) {
    for (const model of config.models) {
      modelToProvider[model.toLowerCase()] = providerId as ProviderId
    }
  }

  return modelToProvider
}
|
|
|
|
/**
 * Resolves the provider responsible for a model id.
 *
 * Resolution order:
 * 1. exact (case-insensitive) match against the models listed by each provider
 * 2. first provider whose `modelPatterns` contains a pattern matching the id
 * 3. fallback to `'ollama'`, with a warning
 *
 * @param model Model id as supplied by the caller (any casing)
 * @returns The matching provider id, or `'ollama'` when nothing matches
 */
export function getProviderFromModel(model: string): ProviderId {
  const normalizedModel = model.toLowerCase()

  // Build the lookup once per call, and use an own-property check so that ids
  // such as "constructor" do not resolve to members inherited from Object.prototype.
  const knownModels = getAllModelProviders()
  if (Object.prototype.hasOwnProperty.call(knownModels, normalizedModel)) {
    return knownModels[normalizedModel]
  }

  for (const [providerId, config] of Object.entries(providers)) {
    if (!config.modelPatterns) continue

    for (const pattern of config.modelPatterns) {
      if (pattern.test(normalizedModel)) {
        return providerId as ProviderId
      }
    }
  }

  logger.warn(`No provider found for model: ${model}, defaulting to ollama`)
  return 'ollama'
}
|
|
|
|
/**
 * Looks up a provider configuration by id.
 *
 * @param id Provider id, either bare (`'openai'`) or with a suffix after a slash
 *   (`'openai/chat'`); only the part before the first slash is used
 * @returns The provider configuration, or undefined for an unknown id
 */
export function getProvider(id: string): ProviderConfig | undefined {
  const [baseId] = id.split('/')
  return providers[baseId as ProviderId]
}
|
|
|
|
/**
 * Returns the configuration of the provider that `getProviderFromModel`
 * resolves for the given model id.
 */
export function getProviderConfigFromModel(model: string): ProviderConfig | undefined {
  return providers[getProviderFromModel(model)]
}
|
|
|
|
/**
 * Returns the model ids of every registered provider, concatenated in registry order.
 */
export function getAllModels(): string[] {
  const allModels: string[] = []

  for (const provider of Object.values(providers)) {
    allModels.push(...(provider.models || []))
  }

  return allModels
}
|
|
|
|
/**
 * Returns the ids of every provider present in the `providers` registry.
 */
export function getAllProviderIds(): ProviderId[] {
  const ids = Object.keys(providers)
  return ids as ProviderId[]
}
|
|
|
|
/**
 * Returns the model ids the definitions module lists for a provider.
 *
 * @param providerId Provider to look up
 */
export function getProviderModels(providerId: ProviderId): string[] {
  const models = getProviderModelsFromDefinitions(providerId)
  return models
}
|
|
|
|
/**
 * One blacklist rule: a set of model ids and id prefixes that are hidden unless
 * the optional environment override is truthy.
 */
interface ModelBlacklist {
  /** Lowercase model ids that are blocked by exact match. */
  models: string[]
  /** Lowercase prefixes; any model id starting with one of them is blocked. */
  prefixes: string[]
  /** Name of an env var that, when truthy, disables this rule. */
  envOverride?: string
}

/** Blacklist rules consulted by `isModelBlacklisted`. */
const MODEL_BLACKLISTS: ModelBlacklist[] = [
  {
    models: ['deepseek-chat', 'deepseek-v3', 'deepseek-r1'],
    prefixes: ['openrouter/deepseek', 'openrouter/tngtech'],
    envOverride: 'DEEPSEEK_MODELS_ENABLED',
  },
]
|
|
|
|
/**
 * Tells whether a model id is blocked by any active blacklist rule.
 *
 * A rule is skipped when its `envOverride` variable is truthy. Otherwise the
 * lowercased id is blocked when it equals one of the rule's `models` or starts
 * with one of its `prefixes`.
 */
function isModelBlacklisted(model: string): boolean {
  const candidate = model.toLowerCase()

  return MODEL_BLACKLISTS.some(({ models, prefixes, envOverride }) => {
    if (envOverride && isTruthy(getEnv(envOverride))) {
      return false
    }

    return models.includes(candidate) || prefixes.some((prefix) => candidate.startsWith(prefix))
  })
}
|
|
|
|
/**
 * Returns the given model ids minus the blacklisted ones, preserving order.
 */
export function filterBlacklistedModels(models: string[]): string[] {
  const allowed: string[] = []

  for (const model of models) {
    if (!isModelBlacklisted(model)) allowed.push(model)
  }

  return allowed
}
|
|
|
|
/**
 * Returns the icon component declared for the provider of a model.
 *
 * @param model Model id; its provider is resolved with `getProviderFromModel`
 * @returns The provider definition's `icon`, or null when none is declared
 */
export function getProviderIcon(model: string): React.ComponentType<{ className?: string }> | null {
  const definition = PROVIDER_DEFINITIONS[getProviderFromModel(model)]
  return definition?.icon || null
}
|
|
|
|
/**
 * Produces prompt text asking the model to answer in a JSON shape described by
 * a legacy `fields`-based response format.
 *
 * Returns an empty string when no format is given, when the format is a JSON
 * Schema (`schema`, or `type: 'object'` with `properties`), or when it has no
 * `fields`.
 *
 * NOTE(review): this file was recovered from a whitespace-stripped copy, so any
 * indentation inside the generated text follows that copy — confirm against the
 * canonical source if exact spacing matters.
 *
 * @param responseFormat Response format definition (legacy `fields` form)
 * @returns Instruction text, or '' when no instructions are needed
 */
export function generateStructuredOutputInstructions(responseFormat: any): string {
  if (!responseFormat) return ''

  const isJsonSchema =
    responseFormat.schema || (responseFormat.type === 'object' && responseFormat.properties)
  if (isJsonSchema) return ''

  if (!responseFormat.fields) return ''

  // Example value shown for a field in the sample JSON.
  const placeholderFor = (field: any): string => {
    if (field.type === 'object' && field.properties) {
      const members = Object.entries(field.properties)
        .map(([key, prop]: [string, any]) => `"${key}": ${prop.type === 'number' ? '0' : '"value"'}`)
        .join(',\n ')
      return `{\n${members}\n}`
    }

    switch (field.type) {
      case 'string':
        return '"value"'
      case 'number':
        return '0'
      case 'boolean':
        return 'true/false'
      default:
        return '[]'
    }
  }

  // Human-readable description of a field, with one extra line per nested property.
  const describeField = (field: any): string => {
    const lines: string[] = [
      `${field.name} (${field.type})${field.description ? `: ${field.description}` : ''}`,
    ]

    if (field.type === 'object' && field.properties) {
      lines.push('Properties:')
      for (const [key, prop] of Object.entries(field.properties) as [string, any][]) {
        lines.push(` - ${key} (${prop.type}): ${prop.description || ''}`)
      }
    }

    return lines.join('\n')
  }

  const exampleFormat = responseFormat.fields
    .map((field: any) => ` "${field.name}": ${placeholderFor(field)}`)
    .join(',\n')

  const fieldDescriptions = responseFormat.fields.map(describeField).join('\n')

  return [
    '',
    'Please provide your response in the following JSON format:',
    '{',
    exampleFormat,
    '}',
    '',
    'Field descriptions:',
    fieldDescriptions,
    '',
    'Your response MUST be valid JSON and include all the specified fields with their correct types.',
    "Each metric should be an object containing 'score' (number) and 'reasoning' (string).",
  ].join('\n')
}
|
|
|
|
/**
 * Extracts the outermost `{ ... }` span from free-form text and parses it as JSON.
 *
 * The span runs from the first `{` to the last `}`. If strict parsing fails, a
 * best-effort cleanup (collapse whitespace, drop trailing commas before `}` or
 * `]`) is applied and parsing is retried once.
 *
 * @param content Text that is expected to contain a JSON object
 * @returns The parsed value
 * @throws Error 'No JSON object found in content' when there is no `{ ... }`
 *   span, including the case where the only `}` comes before the first `{`
 * @throws Error 'Failed to parse JSON after cleanup: …' when both parse attempts fail
 */
export function extractAndParseJSON(content: string): any {
  const trimmed = content.trim()

  const firstBrace = trimmed.indexOf('{')
  const lastBrace = trimmed.lastIndexOf('}')

  // `lastBrace < firstBrace` also covers a missing '}' (-1). Without this guard,
  // input like "} {" yields an empty slice and a misleading "after cleanup" error.
  if (firstBrace === -1 || lastBrace < firstBrace) {
    throw new Error('No JSON object found in content')
  }

  const jsonStr = trimmed.slice(firstBrace, lastBrace + 1)

  try {
    return JSON.parse(jsonStr)
  } catch (_error) {
    // Best-effort repair: `\s+` already matches newlines, so one pass collapses
    // all whitespace runs before trailing commas are stripped.
    const cleaned = jsonStr.replace(/\s+/g, ' ').replace(/,\s*([}\]])/g, '$1')

    try {
      return JSON.parse(cleaned)
    } catch (innerError) {
      const reason = innerError instanceof Error ? innerError.message : 'Unknown error'

      logger.error('Failed to parse JSON response', {
        contentLength: content.length,
        extractedLength: jsonStr.length,
        cleanedLength: cleaned.length,
        error: reason,
      })

      throw new Error(`Failed to parse JSON after cleanup: ${reason}`)
    }
  }
}
|
|
|
|
/**
 * Transforms a custom tool schema into a provider tool config.
 *
 * The tool id is prefixed with `custom_`. A function schema that declares no
 * `parameters` is treated as taking an empty object, rather than crashing on
 * the missing property.
 *
 * @param customTool Custom tool record with an `id` and a `schema.function` definition
 * @returns Provider tool config with empty `params` and the tool's JSON-schema `parameters`
 * @throws Error 'Invalid custom tool schema' when the tool, its schema, or
 *   `schema.function` is missing
 */
export function transformCustomTool(customTool: any): ProviderToolConfig {
  const schema = customTool?.schema

  if (!schema || !schema.function) {
    throw new Error('Invalid custom tool schema')
  }

  const { name, description, parameters } = schema.function

  return {
    id: `custom_${customTool.id}`,
    name,
    description: description || '',
    params: {},
    parameters: {
      type: parameters?.type ?? 'object',
      properties: parameters?.properties ?? {},
      required: parameters?.required || [],
    },
  }
}
|
|
|
|
/**
 * Returns every tool in the custom-tools store, converted to a provider tool
 * config with `transformCustomTool`.
 *
 * @throws Error when any stored tool has an invalid schema (propagated from
 *   `transformCustomTool`)
 */
export function getCustomTools(): ProviderToolConfig[] {
  const store = useCustomToolsStore.getState()
  return store.getAllTools().map(transformCustomTool)
}
|
|
|
|
/**
 * Converts a workflow block into a provider tool config.
 *
 * Steps:
 * 1. find the block definition matching `block.type`
 * 2. pick the tool id: for definitions exposing more than one tool and given a
 *    `selectedOperation`, ask the definition's `tools.config.tool` selector;
 *    otherwise take the first entry of `tools.access`
 * 3. load the tool config (`getToolAsync` for `custom_` ids when provided, else `getTool`)
 * 4. build the LLM-facing parameter schema from the tool config and the block's params
 *
 * @param block The block to transform (`type` and optional `params` are read)
 * @param options Lookup callbacks and the selected operation
 * @returns The provider tool config, or null when the definition, tool id or
 *   tool config cannot be resolved, or when the selector throws
 */
export async function transformBlockTool(
  block: any,
  options: {
    selectedOperation?: string
    getAllBlocks: () => any[]
    getTool: (toolId: string) => any
    getToolAsync?: (toolId: string) => Promise<any>
  }
): Promise<ProviderToolConfig | null> {
  const { selectedOperation, getAllBlocks, getTool, getToolAsync } = options

  const blockDef = getAllBlocks().find((candidate: any) => candidate.type === block.type)
  if (!blockDef) {
    logger.warn(`Block definition not found for type: ${block.type}`)
    return null
  }

  const hasMultipleTools = (blockDef.tools?.access?.length || 0) > 1

  let toolId: string | null
  if (hasMultipleTools && selectedOperation && blockDef.tools?.config?.tool) {
    try {
      // Called as a method so the selector keeps its `this` binding.
      toolId = blockDef.tools.config.tool({
        ...block.params,
        operation: selectedOperation,
      })
    } catch (error) {
      logger.error('Error selecting tool for block', {
        blockType: block.type,
        operation: selectedOperation,
        error,
      })
      return null
    }
  } else {
    // Single-tool blocks, or multi-tool blocks without a usable selector.
    toolId = blockDef.tools?.access?.[0] || null
  }

  if (!toolId) {
    logger.warn(`No tool ID found for block: ${block.type}`)
    return null
  }

  const useAsyncLookup = toolId.startsWith('custom_') && getToolAsync
  const toolConfig: any = useAsyncLookup ? await getToolAsync(toolId) : getTool(toolId)

  if (!toolConfig) {
    logger.warn(`Tool config not found for ID: ${toolId}`)
    return null
  }

  const { createLLMToolSchema } = await import('@/tools/params')

  const userProvidedParams = block.params || {}
  const llmSchema = await createLLMToolSchema(toolConfig, userProvidedParams)

  // Multi-instance tools get the bound resource id appended, so two instances
  // of the same tool do not share an id.
  let resourceId: unknown
  if (toolId === 'workflow_executor') {
    resourceId = userProvidedParams.workflowId
  } else if (toolId.startsWith('knowledge_')) {
    resourceId = userProvidedParams.knowledgeBaseId
  }
  const uniqueToolId = resourceId ? `${toolConfig.id}_${resourceId}` : toolConfig.id

  return {
    id: uniqueToolId,
    name: toolConfig.name,
    description: toolConfig.description,
    params: userProvidedParams,
    parameters: llmSchema,
  }
}
|
|
|
|
/**
 * Calculate cost for token usage based on model pricing.
 *
 * Pricing is looked up first among embedding models, then among regular model
 * definitions; rates are expressed per one million tokens. When no pricing is
 * found, all costs are 0 and a default pricing object is returned.
 *
 * @param model The model name
 * @param promptTokens Number of prompt tokens used
 * @param completionTokens Number of completion tokens used
 * @param useCachedInput Whether to use cached input pricing when the model defines it (default: false)
 * @param inputMultiplier Optional multiplier applied to the input cost (default: 1)
 * @param outputMultiplier Optional multiplier applied to the output cost (default: 1)
 * @returns Input, output and total costs rounded to 8 decimals, plus the pricing used
 */
export function calculateCost(
  model: string,
  promptTokens = 0,
  completionTokens = 0,
  useCachedInput = false,
  inputMultiplier?: number,
  outputMultiplier?: number
) {
  let pricing = getEmbeddingModelPricing(model)
  if (!pricing) {
    pricing = getModelPricingFromDefinitions(model)
  }

  if (!pricing) {
    // Unknown model: report zero cost alongside a default pricing object.
    return {
      input: 0,
      output: 0,
      total: 0,
      pricing: {
        input: 1.0,
        cachedInput: 0.5,
        output: 5.0,
        updatedAt: '2025-03-21',
      },
    }
  }

  const round8 = (value: number): number => Number.parseFloat(value.toFixed(8))

  // Cached pricing applies only when requested AND defined for the model.
  const inputRate = useCachedInput && pricing.cachedInput ? pricing.cachedInput : pricing.input

  const finalInputCost = promptTokens * (inputRate / 1_000_000) * (inputMultiplier ?? 1)
  const finalOutputCost = completionTokens * (pricing.output / 1_000_000) * (outputMultiplier ?? 1)

  return {
    input: round8(finalInputCost),
    output: round8(finalOutputCost),
    total: round8(finalInputCost + finalOutputCost),
    pricing,
  }
}
|
|
|
|
/**
 * Get pricing information for a specific model.
 *
 * Embedding-model pricing takes precedence; otherwise the regular model
 * definitions are consulted.
 *
 * @param modelId Model id to look up
 * @returns The pricing entry, or whatever the definitions lookup returns when none exists
 */
export function getModelPricing(modelId: string): any {
  const embeddingPricing = getEmbeddingModelPricing(modelId)
  return embeddingPricing ? embeddingPricing : getModelPricingFromDefinitions(modelId)
}
|
|
|
|
/**
 * Format a cost as a currency string, using more decimals for smaller amounts.
 *
 * - missing value → '—'
 * - >= 1 → 2 decimals; >= 0.01 → 3 decimals; >= 0.001 → 4 decimals
 * - smaller positive values → enough decimals to show three significant digits (at least 4)
 * - zero, negative or NaN → '$0'
 *
 * @param cost Cost in USD
 * @returns Formatted cost string
 */
export function formatCost(cost: number): string {
  if (cost === undefined || cost === null) return '—'

  // Catches 0, negatives and NaN in a single test.
  if (!(cost > 0)) return '$0'

  let decimals: number
  if (cost >= 1) {
    decimals = 2
  } else if (cost >= 0.01) {
    decimals = 3
  } else if (cost >= 0.001) {
    decimals = 4
  } else {
    decimals = Math.max(4, Math.abs(Math.floor(Math.log10(cost))) + 3)
  }

  return `$${cost.toFixed(decimals)}`
}
|
|
|
|
/**
 * Get the list of models that are hosted by the platform (don't require user API keys).
 * These are the models for which the API key field is hidden in the hosted environment.
 *
 * @returns The hosted model ids reported by the definitions module
 */
export function getHostedModels(): string[] {
  const hostedModels = getHostedModelsFromDefinitions()
  return hostedModels
}
|
|
|
|
/**
 * Determine if model usage should be billed to the user.
 *
 * Usage is billed exactly when the model is one of the hosted models
 * (compared case-insensitively).
 *
 * @param model The model name
 * @returns true if the usage should be billed to the user
 */
export function shouldBillModelUsage(model: string): boolean {
  return getHostedModels().some((candidate) => candidate.toLowerCase() === model.toLowerCase())
}
|
|
|
|
/**
 * Get an API key for a specific provider, handling rotation and fallbacks.
 * For use server-side only.
 *
 * Resolution order:
 * 1. Ollama models → the placeholder `'empty'`
 * 2. vLLM models → the user key, or `'empty'` when none is given
 * 3. On the hosted platform, OpenAI / Anthropic / Google models that are in the
 *    hosted list → a rotating server key, falling back to the user key
 * 4. Everything else → the user key, which is then mandatory
 *
 * @param provider Provider id
 * @param model Model id
 * @param userProvidedKey Key supplied by the user, if any
 * @returns The API key to use
 * @throws Error when no server key and no user key is available
 */
export function getApiKey(provider: string, model: string, userProvidedKey?: string): string {
  // A model belongs to a local runtime when the provider says so, or when the
  // providers store lists it under that runtime.
  const isLocalModel = (runtime: 'ollama' | 'vllm'): boolean =>
    provider === runtime ||
    useProvidersStore.getState().providers[runtime].models.includes(model)

  if (isLocalModel('ollama')) {
    return 'empty' // Ollama uses 'empty' as a placeholder API key
  }

  if (isLocalModel('vllm')) {
    return userProvidedKey || 'empty' // vLLM uses 'empty' as a placeholder if no key provided
  }

  const isGeminiModel = provider === 'google'
  const usesServerRotation = provider === 'openai' || provider === 'anthropic' || isGeminiModel

  if (isHosted && usesServerRotation) {
    // Only use a server key if the model is explicitly in the hosted list.
    const isModelHosted = getHostedModels().some(
      (hostedModel) => hostedModel.toLowerCase() === model.toLowerCase()
    )

    if (isModelHosted) {
      try {
        // NOTE(review): loaded lazily via require — presumably to keep the
        // server-only key module out of other bundles; confirm before changing.
        const { getRotatingApiKey } = require('@/lib/core/config/api-keys')
        return getRotatingApiKey(isGeminiModel ? 'gemini' : provider)
      } catch (_error) {
        if (userProvidedKey) {
          return userProvidedKey
        }

        throw new Error(`No API key available for ${provider} ${model}`)
      }
    }
  }

  // For all other cases, require a user-provided key.
  if (!userProvidedKey) {
    throw new Error(`API key is required for ${provider} ${model}`)
  }

  return userProvidedKey
}
|
|
|
|
/**
 * Prepares tool configuration for provider requests with consistent tool usage control behavior.
 *
 * - Tools whose provider config has `usageControl === 'none'` are removed.
 * - If any provider tool has `usageControl === 'force'`, the first one is forced
 *   using the provider's own format (Anthropic `tool`, Google `toolConfig`,
 *   OpenAI-style `function` otherwise).
 * - Otherwise the choice is `'auto'` (plus an AUTO `toolConfig` for Google).
 *
 * @param tools Array of tools in provider-specific format
 * @param providerTools Original tool configurations with usage control settings
 * @param logger Logger instance to use for logging
 * @param provider Optional provider ID to adjust format for specific providers
 * @returns Object with prepared tools and tool_choice settings
 */
export function prepareToolsWithUsageControl(
  tools: any[] | undefined,
  providerTools: any[] | undefined,
  logger: any,
  provider?: string
): {
  tools: any[] | undefined
  toolChoice:
    | 'auto'
    | 'none'
    | { type: 'function'; function: { name: string } }
    | { type: 'tool'; name: string }
    | { type: 'any'; any: { model: string; name: string } }
    | undefined
  toolConfig?: {
    // Add toolConfig for Google's format
    functionCallingConfig: {
      mode: 'AUTO' | 'ANY' | 'NONE'
      allowedFunctionNames?: string[]
    }
  }
  hasFilteredTools: boolean
  forcedTools: string[] // Return all forced tool IDs
} {
  type ToolChoice =
    | 'auto'
    | 'none'
    | { type: 'function'; function: { name: string } }
    | { type: 'tool'; name: string }
    | { type: 'any'; any: { model: string; name: string } }
  type GoogleToolConfig = {
    functionCallingConfig: {
      mode: 'AUTO' | 'ANY' | 'NONE'
      allowedFunctionNames?: string[]
    }
  }

  // Nothing to prepare.
  if (!tools?.length) {
    return {
      tools: undefined,
      toolChoice: undefined,
      hasFilteredTools: false,
      forcedTools: [],
    }
  }

  // A tool is disabled when its matching provider config says usageControl='none'.
  const isDisabled = (tool: any): boolean => {
    const toolId = tool.function?.name || tool.name
    const config = providerTools?.find((candidate) => candidate.id === toolId)
    return config?.usageControl === 'none'
  }

  const filteredTools = tools.filter((tool) => !isDisabled(tool))
  const removedCount = tools.length - filteredTools.length
  const hasFilteredTools = removedCount > 0

  if (hasFilteredTools) {
    logger.info(`Filtered out ${removedCount} tools with usageControl='none'`)
  }

  if (filteredTools.length === 0) {
    logger.info('All tools were filtered out due to usageControl="none"')
    return {
      tools: undefined,
      toolChoice: undefined,
      hasFilteredTools: true,
      forcedTools: [],
    }
  }

  const forcedTools = (providerTools ?? []).filter((tool) => tool.usageControl === 'force')
  const forcedToolIds = forcedTools.map((tool) => tool.id)

  let toolChoice: ToolChoice = 'auto'
  // Google expresses tool forcing through a separate toolConfig object.
  let toolConfig: GoogleToolConfig | undefined

  if (forcedTools.length === 0) {
    if (provider === 'google') {
      toolConfig = { functionCallingConfig: { mode: 'AUTO' } }
    }
    logger.info('Setting tool_choice to auto - letting model decide which tools to use')
  } else {
    // Force the first tool that has usageControl='force'.
    const forcedTool = forcedTools[0]

    switch (provider) {
      case 'anthropic':
        toolChoice = { type: 'tool', name: forcedTool.id }
        break
      case 'google':
        // toolChoice stays 'auto'; the forcing lives in toolConfig.
        toolConfig = {
          functionCallingConfig: {
            mode: 'ANY',
            allowedFunctionNames: forcedTools.length === 1 ? [forcedTool.id] : forcedToolIds,
          },
        }
        break
      default:
        // Default OpenAI format
        toolChoice = { type: 'function', function: { name: forcedTool.id } }
    }

    logger.info(`Forcing use of tool: ${forcedTool.id}`)

    if (forcedTools.length > 1) {
      logger.info(
        `Multiple tools set to 'force' mode (${forcedToolIds.join(', ')}). Will cycle through them sequentially.`
      )
    }
  }

  return {
    tools: filteredTools,
    toolChoice,
    toolConfig,
    hasFilteredTools,
    forcedTools: forcedToolIds,
  }
}
|
|
|
|
/**
 * Checks if a forced tool has been used in a response and manages the tool_choice accordingly.
 *
 * When one of the currently forced tools appears among the response's tool
 * calls, it is recorded as used and the next not-yet-used forced tool becomes
 * the new choice. Once every forced tool has been used, the choice goes back to
 * automatic (`null` for Anthropic, an AUTO `toolConfig` for Google, `'auto'`
 * otherwise). If no forced tool was used, the original settings are returned
 * unchanged.
 *
 * @param toolCallsResponse Array of tool calls in the response
 * @param originalToolChoice The original tool_choice setting used in the request
 *   (for Google, the toolConfig object)
 * @param logger Logger instance to use for logging
 * @param provider Optional provider ID to adjust format for specific providers
 * @param forcedTools Array of all tool IDs that should be forced in sequence
 * @param usedForcedTools Array of tool IDs that have already been used
 * @returns Object containing tracking information and next tool choice
 */
export function trackForcedToolUsage(
  toolCallsResponse: any[] | undefined,
  originalToolChoice: any,
  logger: any,
  provider?: string,
  forcedTools: string[] = [],
  usedForcedTools: string[] = []
): {
  hasUsedForcedTool: boolean
  usedForcedTools: string[]
  nextToolChoice?:
    | 'auto'
    | { type: 'function'; function: { name: string } }
    | { type: 'tool'; name: string }
    | { type: 'any'; any: { model: string; name: string } }
    | null
  nextToolConfig?: {
    functionCallingConfig: {
      mode: 'AUTO' | 'ANY' | 'NONE'
      allowedFunctionNames?: string[]
    }
  }
} {
  type GoogleToolConfig = {
    functionCallingConfig: {
      mode: 'AUTO' | 'ANY' | 'NONE'
      allowedFunctionNames?: string[]
    }
  }

  const isGoogleFormat = provider === 'google'
  const updatedUsedForcedTools = [...usedForcedTools]

  // Result used whenever no forced tool was consumed: keep the original settings.
  const unchanged = {
    hasUsedForcedTool: false,
    usedForcedTools: updatedUsedForcedTools,
    nextToolChoice: originalToolChoice,
    nextToolConfig: isGoogleFormat ? originalToolChoice : undefined,
  }

  // Work out which tool name(s) the original choice was forcing.
  let forcedToolNames: string[] = []
  const googleAllowedNames = isGoogleFormat
    ? originalToolChoice?.functionCallingConfig?.allowedFunctionNames
    : undefined

  if (googleAllowedNames) {
    forcedToolNames = googleAllowedNames
  } else if (typeof originalToolChoice === 'object') {
    const forcesByFunction = originalToolChoice?.function?.name
    const forcesByTool = originalToolChoice?.type === 'tool' && originalToolChoice?.name
    const forcesByAny = originalToolChoice?.type === 'any' && originalToolChoice?.any?.name

    if (forcesByFunction || forcesByTool || forcesByAny) {
      forcedToolNames = [
        originalToolChoice?.function?.name ||
          originalToolChoice?.name ||
          originalToolChoice?.any?.name,
      ].filter(Boolean)
    }
  }

  if (forcedToolNames.length === 0 || !toolCallsResponse || toolCallsResponse.length === 0) {
    return unchanged
  }

  const calledNames = toolCallsResponse.map((call) => call.function?.name || call.name || call.id)
  const usedTools = forcedToolNames.filter((toolName) => calledNames.includes(toolName))

  if (usedTools.length === 0) {
    return unchanged
  }

  // At least one forced tool was used.
  updatedUsedForcedTools.push(...usedTools)
  const remainingTools = forcedTools.filter((tool) => !updatedUsedForcedTools.includes(tool))

  let nextToolChoice = originalToolChoice
  let nextToolConfig: GoogleToolConfig | undefined

  if (remainingTools.length > 0) {
    const nextToolToForce = remainingTools[0]

    switch (provider) {
      case 'anthropic':
        nextToolChoice = { type: 'tool', name: nextToolToForce }
        break
      case 'google':
        nextToolConfig = {
          functionCallingConfig: {
            mode: 'ANY',
            allowedFunctionNames: remainingTools.length === 1 ? [nextToolToForce] : remainingTools,
          },
        }
        break
      default:
        // Default OpenAI format
        nextToolChoice = { type: 'function', function: { name: nextToolToForce } }
    }

    logger.info(
      `Forced tool(s) ${usedTools.join(', ')} used, switching to next forced tool(s): ${remainingTools.join(', ')}`
    )
  } else {
    switch (provider) {
      case 'anthropic':
        // Anthropic: null signals that the parameter should be deleted/omitted.
        nextToolChoice = null
        break
      case 'google':
        nextToolConfig = { functionCallingConfig: { mode: 'AUTO' } }
        break
      default:
        nextToolChoice = 'auto'
    }

    logger.info('All forced tools have been used, switching to auto mode for future iterations')
  }

  return {
    hasUsedForcedTool: true,
    usedForcedTools: updatedUsedForcedTools,
    nextToolChoice,
    nextToolConfig: isGoogleFormat ? nextToolConfig : undefined,
  }
}
|
|
|
|
// Capability lists, computed once at module load from the definitions module.
// They are snapshots: later changes to the definitions are not reflected here.

/** Snapshot of `getModelsWithTempRange02()` (per the name, temperature range 0–2). */
export const MODELS_TEMP_RANGE_0_2 = getModelsWithTempRange02()

/** Snapshot of `getModelsWithTempRange01()` (per the name, temperature range 0–1). */
export const MODELS_TEMP_RANGE_0_1 = getModelsWithTempRange01()

/** Snapshot of `getModelsWithTemperatureSupport()`. */
export const MODELS_WITH_TEMPERATURE_SUPPORT = getModelsWithTemperatureSupport()

/** Snapshot of `getModelsWithReasoningEffort()`. */
export const MODELS_WITH_REASONING_EFFORT = getModelsWithReasoningEffort()

/** Snapshot of `getModelsWithVerbosity()`. */
export const MODELS_WITH_VERBOSITY = getModelsWithVerbosity()

/** Snapshot of `getProvidersWithToolUsageControl()`. */
export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = getProvidersWithToolUsageControl()
|
|
|
|
/**
 * Check if a model supports the temperature parameter.
 *
 * @param model Model id
 * @returns The answer given by the definitions module
 */
export function supportsTemperature(model: string): boolean {
  const supported = supportsTemperatureFromDefinitions(model)
  return supported
}
|
|
|
|
/**
 * Get the maximum temperature value for a model.
 *
 * @param model Model id
 * @returns Maximum temperature value (1 or 2) or undefined if temperature not supported
 */
export function getMaxTemperature(model: string): number | undefined {
  const maxTemperature = getMaxTempFromDefinitions(model)
  return maxTemperature
}
|
|
|
|
/**
 * Check if a provider supports tool usage control.
 *
 * @param provider Provider id
 * @returns The answer given by the definitions module
 */
export function supportsToolUsageControl(provider: string): boolean {
  const supported = supportsToolUsageControlFromDefinitions(provider)
  return supported
}
|
|
|
|
/**
 * Get reasoning effort values for a specific model.
 *
 * @param model Model id
 * @returns The valid options for that model, or null if the model doesn't support reasoning effort
 */
export function getReasoningEffortValuesForModel(model: string): string[] | null {
  const values = getReasoningEffortValuesForModelFromDefinitions(model)
  return values
}
|
|
|
|
/**
 * Get verbosity values for a specific model.
 *
 * @param model Model id
 * @returns The valid options for that model, or null if the model doesn't support verbosity
 */
export function getVerbosityValuesForModel(model: string): string[] | null {
  const values = getVerbosityValuesForModelFromDefinitions(model)
  return values
}
|
|
|
|
/**
 * Prepare tool execution parameters, separating tool parameters from system parameters.
 *
 * `toolParams` merges the LLM-generated arguments with the user-configured
 * params; user params win, except that empty/null/undefined user values are
 * ignored so a cleared field does not overwrite what the LLM produced.
 *
 * `executionParams` is `toolParams` plus system fields taken from the request:
 * `_context` (only when a workflowId is present), `envVars`,
 * `workflowVariables`, `blockData`, `blockNameMapping`, and `_toolSchema`.
 *
 * @param tool Tool with its user-configured `params` and its `parameters` schema
 * @param llmArgs Arguments generated by the LLM for this call
 * @param request Execution context of the current request
 * @returns The merged tool params and the full execution params
 */
export function prepareToolExecution(
  tool: { params?: Record<string, any>; parameters?: Record<string, any> },
  llmArgs: Record<string, any>,
  request: {
    workflowId?: string
    workspaceId?: string // Add workspaceId for MCP tools
    chatId?: string
    userId?: string
    environmentVariables?: Record<string, any>
    workflowVariables?: Record<string, any>
    blockData?: Record<string, any>
    blockNameMapping?: Record<string, string>
  }
): {
  toolParams: Record<string, any>
  executionParams: Record<string, any>
} {
  // Keep only the user params that carry an actual value.
  const userOverrides: Record<string, any> = {}
  for (const [key, value] of Object.entries(tool.params ?? {})) {
    const isCleared = value === undefined || value === null || value === ''
    if (!isCleared) {
      userOverrides[key] = value
    }
  }

  // User-provided params take precedence over LLM-generated params.
  const toolParams: Record<string, any> = { ...llmArgs, ...userOverrides }

  const executionParams: Record<string, any> = { ...toolParams }

  if (request.workflowId) {
    const context: Record<string, any> = { workflowId: request.workflowId }
    if (request.workspaceId) context.workspaceId = request.workspaceId
    if (request.chatId) context.chatId = request.chatId
    if (request.userId) context.userId = request.userId
    executionParams._context = context
  }

  if (request.environmentVariables) executionParams.envVars = request.environmentVariables
  if (request.workflowVariables) executionParams.workflowVariables = request.workflowVariables
  if (request.blockData) executionParams.blockData = request.blockData
  if (request.blockNameMapping) executionParams.blockNameMapping = request.blockNameMapping

  // Pass tool schema for MCP tools to skip discovery.
  if (tool.parameters) executionParams._toolSchema = tool.parameters

  return { toolParams, executionParams }
}
|