Removed duplicate logic to map model name onto provider

This commit is contained in:
Waleed Latif
2025-02-05 18:27:11 -08:00
parent ac079cc295
commit 4df1c8268c
2 changed files with 43 additions and 21 deletions

View File

@@ -13,6 +13,7 @@ import { generateRouterPrompt } from '@/blocks/blocks/router'
import { BlockOutput } from '@/blocks/types'
import { BlockConfig } from '@/blocks/types'
import { executeProviderRequest } from '@/providers/service'
import { getProviderFromModel } from '@/providers/utils'
import { SerializedBlock, SerializedWorkflow } from '@/serializer/types'
import { executeTool, getTool, tools } from '@/tools'
import { BlockLog, ExecutionContext, ExecutionResult, Tool } from './types'
@@ -262,16 +263,7 @@ export class Executor {
// });
const model = inputs.model || 'gpt-4o'
const providerId =
model.startsWith('gpt') || model.startsWith('o1')
? 'openai'
: model.startsWith('claude')
? 'anthropic'
: model.startsWith('gemini')
? 'google'
: model.startsWith('grok')
? 'xai'
: 'deepseek'
const providerId = getProviderFromModel(model)
// Format tools if they exist
const tools = Array.isArray(inputs.tools)
@@ -629,18 +621,8 @@ export class Executor {
temperature: resolvedInputs.temperature || 0,
}
// Determine provider based on model
const model = routerConfig.model || 'gpt-4o'
const providerId =
model.startsWith('gpt') || model.startsWith('o1')
? 'openai'
: model.startsWith('claude')
? 'anthropic'
: model.startsWith('gemini')
? 'google'
: model.startsWith('grok')
? 'xai'
: 'deepseek'
const providerId = getProviderFromModel(model)
const response = await executeProviderRequest(providerId, {
model: routerConfig.model,

40
providers/utils.ts Normal file
View File

@@ -0,0 +1,40 @@
import { MODEL_TOOLS, ModelType } from '@/blocks/consts'
import { ProviderId } from './registry'
/**
* Determines the provider ID based on the model name.
* Uses the existing MODEL_TOOLS mapping and falls back to pattern matching if needed.
*
* @param model - The model name/identifier
* @returns The corresponding provider ID
*/
export function getProviderFromModel(model: string): ProviderId {
const normalizedModel = model.toLowerCase()
// First try to match exactly from our MODEL_TOOLS mapping
if (normalizedModel in MODEL_TOOLS) {
const toolId = MODEL_TOOLS[normalizedModel as ModelType]
// Extract provider ID from tool ID (e.g., 'openai_chat' -> 'openai')
return toolId.split('_')[0] as ProviderId
}
// If no exact match, use pattern matching as fallback
if (normalizedModel.startsWith('gpt') || normalizedModel.startsWith('o1')) {
return 'openai'
}
if (normalizedModel.startsWith('claude')) {
return 'anthropic'
}
if (normalizedModel.startsWith('gemini')) {
return 'google'
}
if (normalizedModel.startsWith('grok')) {
return 'xai'
}
// Default to deepseek for any other models
return 'deepseek'
}