From 4df1c8268cadfdff7e6089b6b8c7f4679a03065e Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Wed, 5 Feb 2025 18:27:11 -0800 Subject: [PATCH] Removed duplicate logic to map model name onto provider --- executor/index.ts | 24 +++--------------------- providers/utils.ts | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 21 deletions(-) create mode 100644 providers/utils.ts diff --git a/executor/index.ts b/executor/index.ts index d3079376c..072c67ebb 100644 --- a/executor/index.ts +++ b/executor/index.ts @@ -13,6 +13,7 @@ import { generateRouterPrompt } from '@/blocks/blocks/router' import { BlockOutput } from '@/blocks/types' import { BlockConfig } from '@/blocks/types' import { executeProviderRequest } from '@/providers/service' +import { getProviderFromModel } from '@/providers/utils' import { SerializedBlock, SerializedWorkflow } from '@/serializer/types' import { executeTool, getTool, tools } from '@/tools' import { BlockLog, ExecutionContext, ExecutionResult, Tool } from './types' @@ -262,16 +263,7 @@ export class Executor { // }); const model = inputs.model || 'gpt-4o' - const providerId = - model.startsWith('gpt') || model.startsWith('o1') - ? 'openai' - : model.startsWith('claude') - ? 'anthropic' - : model.startsWith('gemini') - ? 'google' - : model.startsWith('grok') - ? 'xai' - : 'deepseek' + const providerId = getProviderFromModel(model) // Format tools if they exist const tools = Array.isArray(inputs.tools) @@ -629,18 +621,8 @@ export class Executor { temperature: resolvedInputs.temperature || 0, } - // Determine provider based on model const model = routerConfig.model || 'gpt-4o' - const providerId = - model.startsWith('gpt') || model.startsWith('o1') - ? 'openai' - : model.startsWith('claude') - ? 'anthropic' - : model.startsWith('gemini') - ? 'google' - : model.startsWith('grok') - ? 'xai' - : 'deepseek' + const providerId = getProviderFromModel(model) const response = await executeProviderRequest(providerId, { model: routerConfig.model, diff --git a/providers/utils.ts b/providers/utils.ts new file mode 100644 index 000000000..1bc9d9095 --- /dev/null +++ b/providers/utils.ts @@ -0,0 +1,40 @@ +import { MODEL_TOOLS, ModelType } from '@/blocks/consts' +import { ProviderId } from './registry' + +/** + * Determines the provider ID based on the model name. + * Uses the existing MODEL_TOOLS mapping and falls back to pattern matching if needed. + * + * @param model - The model name/identifier + * @returns The corresponding provider ID + */ +export function getProviderFromModel(model: string): ProviderId { + const normalizedModel = model.toLowerCase() + + // First try to match exactly from our MODEL_TOOLS mapping + if (normalizedModel in MODEL_TOOLS) { + const toolId = MODEL_TOOLS[normalizedModel as ModelType] + // Extract provider ID from tool ID (e.g., 'openai_chat' -> 'openai') + return toolId.split('_')[0] as ProviderId + } + + // If no exact match, use pattern matching as fallback + if (normalizedModel.startsWith('gpt') || normalizedModel.startsWith('o1')) { + return 'openai' + } + + if (normalizedModel.startsWith('claude')) { + return 'anthropic' + } + + if (normalizedModel.startsWith('gemini')) { + return 'google' + } + + if (normalizedModel.startsWith('grok')) { + return 'xai' + } + + // Default to deepseek for any other models + return 'deepseek' +}