From c77e351067296267d456559fe5142fcaa4f109a1 Mon Sep 17 00:00:00 2001 From: Waleed Date: Wed, 4 Feb 2026 11:02:49 -0800 Subject: [PATCH] fix(providers): correct tool calling message format across all providers (#3132) * fix(providers): correct tool calling message format across all providers * fix(bedrock): correct timestamp char count in comment * chore(gemini): remove dead executeToolCall function * remove unused var --- apps/sim/lib/mcp/client.ts | 2 +- apps/sim/lib/mcp/types.ts | 17 +- apps/sim/providers/anthropic/index.ts | 147 ++++++++++----- apps/sim/providers/bedrock/utils.ts | 11 +- apps/sim/providers/deepseek/index.ts | 2 +- apps/sim/providers/gemini/core.ts | 249 ++++++++++++++++---------- apps/sim/providers/google/utils.ts | 12 ++ apps/sim/providers/groq/index.ts | 1 + apps/sim/providers/mistral/index.ts | 1 + 9 files changed, 300 insertions(+), 142 deletions(-) diff --git a/apps/sim/lib/mcp/client.ts b/apps/sim/lib/mcp/client.ts index b65e9a145..56375613f 100644 --- a/apps/sim/lib/mcp/client.ts +++ b/apps/sim/lib/mcp/client.ts @@ -156,7 +156,7 @@ export class McpClient { return result.tools.map((tool: Tool) => ({ name: tool.name, description: tool.description, - inputSchema: tool.inputSchema, + inputSchema: tool.inputSchema as McpTool['inputSchema'], serverId: this.config.id, serverName: this.config.name, })) diff --git a/apps/sim/lib/mcp/types.ts b/apps/sim/lib/mcp/types.ts index 2e5e37c2e..f9e7948f0 100644 --- a/apps/sim/lib/mcp/types.ts +++ b/apps/sim/lib/mcp/types.ts @@ -57,14 +57,29 @@ export interface McpSecurityPolicy { auditLevel: 'none' | 'basic' | 'detailed' } +/** + * JSON Schema property definition for tool parameters. + * Follows JSON Schema specification with description support. + */ +export interface McpToolSchemaProperty { + type: string + description?: string + items?: McpToolSchemaProperty + properties?: Record + required?: string[] + enum?: Array + default?: unknown +} + /** * JSON Schema for tool input parameters. 
* Aligns with MCP SDK's Tool.inputSchema structure. */ export interface McpToolSchema { type: 'object' - properties?: Record + properties?: Record required?: string[] + description?: string } /** diff --git a/apps/sim/providers/anthropic/index.ts b/apps/sim/providers/anthropic/index.ts index ebb111ffc..6ce89f589 100644 --- a/apps/sim/providers/anthropic/index.ts +++ b/apps/sim/providers/anthropic/index.ts @@ -6,7 +6,6 @@ import { MAX_TOOL_ITERATIONS } from '@/providers' import { checkForForcedToolUsage, createReadableStreamFromAnthropicStream, - generateToolUseId, } from '@/providers/anthropic/utils' import { getMaxOutputTokensForModel, @@ -433,11 +432,32 @@ export const anthropicProvider: ProviderConfig = { const executionResults = await Promise.allSettled(toolExecutionPromises) + // Collect all tool_use and tool_result blocks for batching + const toolUseBlocks: Array<{ + type: 'tool_use' + id: string + name: string + input: Record + }> = [] + const toolResultBlocks: Array<{ + type: 'tool_result' + tool_use_id: string + content: string + }> = [] + for (const settledResult of executionResults) { if (settledResult.status === 'rejected' || !settledResult.value) continue - const { toolName, toolArgs, toolParams, result, startTime, endTime, duration } = - settledResult.value + const { + toolUse, + toolName, + toolArgs, + toolParams, + result, + startTime, + endTime, + duration, + } = settledResult.value timeSegments.push({ type: 'tool', @@ -447,7 +467,7 @@ export const anthropicProvider: ProviderConfig = { duration: duration, }) - let resultContent: any + let resultContent: unknown if (result.success) { toolResults.push(result.output) resultContent = result.output @@ -469,29 +489,34 @@ export const anthropicProvider: ProviderConfig = { success: result.success, }) - const toolUseId = generateToolUseId(toolName) - - currentMessages.push({ - role: 'assistant', - content: [ - { - type: 'tool_use', - id: toolUseId, - name: toolName, - input: toolArgs, - } as any, - ], + // 
Add to batched arrays using the ORIGINAL ID from Claude's response + toolUseBlocks.push({ + type: 'tool_use', + id: toolUse.id, + name: toolName, + input: toolArgs, }) + toolResultBlocks.push({ + type: 'tool_result', + tool_use_id: toolUse.id, + content: JSON.stringify(resultContent), + }) + } + + // Add ONE assistant message with ALL tool_use blocks + if (toolUseBlocks.length > 0) { + currentMessages.push({ + role: 'assistant', + content: toolUseBlocks as unknown as Anthropic.Messages.ContentBlock[], + }) + } + + // Add ONE user message with ALL tool_result blocks + if (toolResultBlocks.length > 0) { currentMessages.push({ role: 'user', - content: [ - { - type: 'tool_result', - tool_use_id: toolUseId, - content: JSON.stringify(resultContent), - } as any, - ], + content: toolResultBlocks as unknown as Anthropic.Messages.ContentBlockParam[], }) } @@ -777,6 +802,8 @@ export const anthropicProvider: ProviderConfig = { const toolCallStartTime = Date.now() const toolName = toolUse.name const toolArgs = toolUse.input as Record + // Preserve the original tool_use ID from Claude's response + const toolUseId = toolUse.id try { const tool = request.tools?.find((t) => t.id === toolName) @@ -787,6 +814,7 @@ export const anthropicProvider: ProviderConfig = { const toolCallEndTime = Date.now() return { + toolUseId, toolName, toolArgs, toolParams, @@ -800,6 +828,7 @@ export const anthropicProvider: ProviderConfig = { logger.error('Error processing tool call:', { error, toolName }) return { + toolUseId, toolName, toolArgs, toolParams: {}, @@ -817,11 +846,32 @@ export const anthropicProvider: ProviderConfig = { const executionResults = await Promise.allSettled(toolExecutionPromises) + // Collect all tool_use and tool_result blocks for batching + const toolUseBlocks: Array<{ + type: 'tool_use' + id: string + name: string + input: Record + }> = [] + const toolResultBlocks: Array<{ + type: 'tool_result' + tool_use_id: string + content: string + }> = [] + for (const settledResult of 
executionResults) { if (settledResult.status === 'rejected' || !settledResult.value) continue - const { toolName, toolArgs, toolParams, result, startTime, endTime, duration } = - settledResult.value + const { + toolUseId, + toolName, + toolArgs, + toolParams, + result, + startTime, + endTime, + duration, + } = settledResult.value timeSegments.push({ type: 'tool', @@ -831,7 +881,7 @@ export const anthropicProvider: ProviderConfig = { duration: duration, }) - let resultContent: any + let resultContent: unknown if (result.success) { toolResults.push(result.output) resultContent = result.output @@ -853,29 +903,34 @@ export const anthropicProvider: ProviderConfig = { success: result.success, }) - const toolUseId = generateToolUseId(toolName) - - currentMessages.push({ - role: 'assistant', - content: [ - { - type: 'tool_use', - id: toolUseId, - name: toolName, - input: toolArgs, - } as any, - ], + // Add to batched arrays using the ORIGINAL ID from Claude's response + toolUseBlocks.push({ + type: 'tool_use', + id: toolUseId, + name: toolName, + input: toolArgs, }) + toolResultBlocks.push({ + type: 'tool_result', + tool_use_id: toolUseId, + content: JSON.stringify(resultContent), + }) + } + + // Add ONE assistant message with ALL tool_use blocks + if (toolUseBlocks.length > 0) { + currentMessages.push({ + role: 'assistant', + content: toolUseBlocks as unknown as Anthropic.Messages.ContentBlock[], + }) + } + + // Add ONE user message with ALL tool_result blocks + if (toolResultBlocks.length > 0) { currentMessages.push({ role: 'user', - content: [ - { - type: 'tool_result', - tool_use_id: toolUseId, - content: JSON.stringify(resultContent), - } as any, - ], + content: toolResultBlocks as unknown as Anthropic.Messages.ContentBlockParam[], }) } @@ -1061,7 +1116,7 @@ export const anthropicProvider: ProviderConfig = { startTime: tc.startTime, endTime: tc.endTime, duration: tc.duration, - result: tc.result, + result: tc.result as Record | undefined, })) : undefined, toolResults: 
toolResults.length > 0 ? toolResults : undefined, diff --git a/apps/sim/providers/bedrock/utils.ts b/apps/sim/providers/bedrock/utils.ts index 0b92f247b..9400d2378 100644 --- a/apps/sim/providers/bedrock/utils.ts +++ b/apps/sim/providers/bedrock/utils.ts @@ -67,8 +67,17 @@ export function checkForForcedToolUsage( return null } +/** + * Generates a unique tool use ID for Bedrock. + * AWS Bedrock requires toolUseId to be 1-64 characters, pattern [a-zA-Z0-9_-]+ + */ export function generateToolUseId(toolName: string): string { - return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}` + const timestamp = Date.now().toString(36) // Base36 timestamp (8 chars) + const random = Math.random().toString(36).substring(2, 7) // 5 random chars + const suffix = `-${timestamp}-${random}` // ~15 chars + const maxNameLength = 64 - suffix.length + const truncatedName = toolName.substring(0, maxNameLength).replace(/[^a-zA-Z0-9_-]/g, '_') + return `${truncatedName}${suffix}` } /** diff --git a/apps/sim/providers/deepseek/index.ts b/apps/sim/providers/deepseek/index.ts index 026342498..9ac7faa1d 100644 --- a/apps/sim/providers/deepseek/index.ts +++ b/apps/sim/providers/deepseek/index.ts @@ -76,7 +76,7 @@ export const deepseekProvider: ProviderConfig = { : undefined const payload: any = { - model: 'deepseek-chat', + model: request.model, messages: allMessages, } diff --git a/apps/sim/providers/gemini/core.ts b/apps/sim/providers/gemini/core.ts index 2dca22e5b..5050672ea 100644 --- a/apps/sim/providers/gemini/core.ts +++ b/apps/sim/providers/gemini/core.ts @@ -20,7 +20,7 @@ import { convertUsageMetadata, createReadableStreamFromGeminiStream, ensureStructResponse, - extractFunctionCallPart, + extractAllFunctionCallParts, extractTextContent, mapToThinkingLevel, } from '@/providers/google/utils' @@ -32,7 +32,7 @@ import { prepareToolsWithUsageControl, } from '@/providers/utils' import { executeTool } from '@/tools' -import type { ExecutionState, GeminiProviderType, 
GeminiUsage, ParsedFunctionCall } from './types' +import type { ExecutionState, GeminiProviderType, GeminiUsage } from './types' /** * Creates initial execution state @@ -79,101 +79,168 @@ function createInitialState( } /** - * Executes a tool call and updates state + * Executes multiple tool calls in parallel and updates state. + * Per Gemini docs, all function calls from a single response should be executed + * together, with one model message containing all function calls and one user + * message containing all function responses. */ -async function executeToolCall( - functionCallPart: Part, - functionCall: ParsedFunctionCall, +async function executeToolCallsBatch( + functionCallParts: Part[], request: ProviderRequest, state: ExecutionState, forcedTools: string[], logger: ReturnType ): Promise<{ success: boolean; state: ExecutionState }> { - const toolCallStartTime = Date.now() - const toolName = functionCall.name - - const tool = request.tools?.find((t) => t.id === toolName) - if (!tool) { - logger.warn(`Tool ${toolName} not found in registry, skipping`) + if (functionCallParts.length === 0) { return { success: false, state } } - try { - const { toolParams, executionParams } = prepareToolExecution(tool, functionCall.args, request) - const result = await executeTool(toolName, executionParams) - const toolCallEndTime = Date.now() - const duration = toolCallEndTime - toolCallStartTime + const executionPromises = functionCallParts.map(async (part) => { + const toolCallStartTime = Date.now() + const functionCall = part.functionCall! + const toolName = functionCall.name ?? '' + const args = (functionCall.args ?? {}) as Record - const resultContent: Record = result.success - ? 
ensureStructResponse(result.output) - : { error: true, message: result.error || 'Tool execution failed', tool: toolName } - - const toolCall: FunctionCallResponse = { - name: toolName, - arguments: toolParams, - startTime: new Date(toolCallStartTime).toISOString(), - endTime: new Date(toolCallEndTime).toISOString(), - duration, - result: resultContent, + const tool = request.tools?.find((t) => t.id === toolName) + if (!tool) { + logger.warn(`Tool ${toolName} not found in registry, skipping`) + return { + success: false, + part, + toolName, + args, + resultContent: { error: true, message: `Tool ${toolName} not found`, tool: toolName }, + toolParams: {}, + startTime: toolCallStartTime, + endTime: Date.now(), + duration: Date.now() - toolCallStartTime, + } } - const updatedContents: Content[] = [ - ...state.contents, - { - role: 'model', - parts: [functionCallPart], - }, - { - role: 'user', - parts: [ - { - functionResponse: { - name: functionCall.name, - response: resultContent, - }, - }, - ], - }, - ] + try { + const { toolParams, executionParams } = prepareToolExecution(tool, args, request) + const result = await executeTool(toolName, executionParams) + const toolCallEndTime = Date.now() + const duration = toolCallEndTime - toolCallStartTime - const forcedToolCheck = checkForForcedToolUsage( - [{ name: functionCall.name, args: functionCall.args }], - state.currentToolConfig, - forcedTools, - state.usedForcedTools - ) + const resultContent: Record = result.success + ? ensureStructResponse(result.output) + : { error: true, message: result.error || 'Tool execution failed', tool: toolName } - return { - success: true, - state: { - ...state, - contents: updatedContents, - toolCalls: [...state.toolCalls, toolCall], - toolResults: result.success - ? 
[...state.toolResults, result.output as Record] - : state.toolResults, - toolsTime: state.toolsTime + duration, - timeSegments: [ - ...state.timeSegments, - { - type: 'tool', - name: toolName, - startTime: toolCallStartTime, - endTime: toolCallEndTime, - duration, - }, - ], - usedForcedTools: forcedToolCheck?.usedForcedTools ?? state.usedForcedTools, - currentToolConfig: forcedToolCheck?.nextToolConfig ?? state.currentToolConfig, - }, + return { + success: result.success, + part, + toolName, + args, + resultContent, + toolParams, + result, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration, + } + } catch (error) { + const toolCallEndTime = Date.now() + logger.error('Error processing function call:', { + error: error instanceof Error ? error.message : String(error), + functionName: toolName, + }) + return { + success: false, + part, + toolName, + args, + resultContent: { + error: true, + message: error instanceof Error ? error.message : 'Tool execution failed', + tool: toolName, + }, + toolParams: {}, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, + } } - } catch (error) { - logger.error('Error processing function call:', { - error: error instanceof Error ? 
error.message : String(error), - functionName: toolName, - }) + }) + + const results = await Promise.all(executionPromises) + + // Check if at least one tool was found (not all failed due to missing tools) + const hasValidResults = results.some((r) => r.result !== undefined) + if (!hasValidResults && results.every((r) => !r.success)) { return { success: false, state } } + + // Build batched messages per Gemini spec: + // ONE model message with ALL function call parts + // ONE user message with ALL function responses + const modelParts: Part[] = results.map((r) => r.part) + const userParts: Part[] = results.map((r) => ({ + functionResponse: { + name: r.toolName, + response: r.resultContent, + }, + })) + + const updatedContents: Content[] = [ + ...state.contents, + { role: 'model', parts: modelParts }, + { role: 'user', parts: userParts }, + ] + + // Collect all tool calls and results + const newToolCalls: FunctionCallResponse[] = [] + const newToolResults: Record[] = [] + const newTimeSegments: ExecutionState['timeSegments'] = [] + let totalToolsTime = 0 + + for (const r of results) { + newToolCalls.push({ + name: r.toolName, + arguments: r.toolParams, + startTime: new Date(r.startTime).toISOString(), + endTime: new Date(r.endTime).toISOString(), + duration: r.duration, + result: r.resultContent, + }) + + if (r.success && r.result?.output) { + newToolResults.push(r.result.output as Record) + } + + newTimeSegments.push({ + type: 'tool', + name: r.toolName, + startTime: r.startTime, + endTime: r.endTime, + duration: r.duration, + }) + + totalToolsTime += r.duration + } + + // Check forced tool usage for all executed tools + const executedToolsInfo = results.map((r) => ({ name: r.toolName, args: r.args })) + const forcedToolCheck = checkForForcedToolUsage( + executedToolsInfo, + state.currentToolConfig, + forcedTools, + state.usedForcedTools + ) + + return { + success: true, + state: { + ...state, + contents: updatedContents, + toolCalls: [...state.toolCalls, 
...newToolCalls], + toolResults: [...state.toolResults, ...newToolResults], + toolsTime: state.toolsTime + totalToolsTime, + timeSegments: [...state.timeSegments, ...newTimeSegments], + usedForcedTools: forcedToolCheck?.usedForcedTools ?? state.usedForcedTools, + currentToolConfig: forcedToolCheck?.nextToolConfig ?? state.currentToolConfig, + }, + } } /** @@ -506,27 +573,25 @@ export async function executeGeminiRequest( // Tool execution loop const functionCalls = response.functionCalls if (functionCalls?.length) { - logger.info(`Received function call from Gemini: ${functionCalls[0].name}`) + const functionNames = functionCalls.map((fc) => fc.name).join(', ') + logger.info(`Received ${functionCalls.length} function call(s) from Gemini: ${functionNames}`) while (state.iterationCount < MAX_TOOL_ITERATIONS) { - const functionCallPart = extractFunctionCallPart(currentResponse.candidates?.[0]) - if (!functionCallPart?.functionCall) { + // Extract ALL function call parts from the response (Gemini can return multiple) + const functionCallParts = extractAllFunctionCallParts(currentResponse.candidates?.[0]) + if (functionCallParts.length === 0) { content = extractTextContent(currentResponse.candidates?.[0]) break } - const functionCall: ParsedFunctionCall = { - name: functionCallPart.functionCall.name ?? '', - args: (functionCallPart.functionCall.args ?? {}) as Record, - } - + const callNames = functionCallParts.map((p) => p.functionCall?.name ?? 
'unknown').join(', ') logger.info( - `Processing function call: ${functionCall.name} (iteration ${state.iterationCount + 1})` + `Processing ${functionCallParts.length} function call(s): ${callNames} (iteration ${state.iterationCount + 1})` ) - const { success, state: updatedState } = await executeToolCall( - functionCallPart, - functionCall, + // Execute ALL function calls in this batch + const { success, state: updatedState } = await executeToolCallsBatch( + functionCallParts, request, state, forcedTools, diff --git a/apps/sim/providers/google/utils.ts b/apps/sim/providers/google/utils.ts index c5040aab4..0a23b50da 100644 --- a/apps/sim/providers/google/utils.ts +++ b/apps/sim/providers/google/utils.ts @@ -109,6 +109,7 @@ export function extractFunctionCall(candidate: Candidate | undefined): ParsedFun /** * Extracts the full Part containing the function call (preserves thoughtSignature) + * @deprecated Use extractAllFunctionCallParts for proper multi-tool handling */ export function extractFunctionCallPart(candidate: Candidate | undefined): Part | null { if (!candidate?.content?.parts) return null @@ -122,6 +123,17 @@ export function extractFunctionCallPart(candidate: Candidate | undefined): Part return null } +/** + * Extracts ALL Parts containing function calls from a candidate. + * Gemini can return multiple function calls in a single response, + * and all should be executed before continuing the conversation. + */ +export function extractAllFunctionCallParts(candidate: Candidate | undefined): Part[] { + if (!candidate?.content?.parts) return [] + + return candidate.content.parts.filter((part) => part.functionCall) +} + /** * Converts usage metadata from SDK response to our format. 
* Per Gemini docs, total = promptTokenCount + candidatesTokenCount + toolUsePromptTokenCount + thoughtsTokenCount diff --git a/apps/sim/providers/groq/index.ts b/apps/sim/providers/groq/index.ts index 7be9b7386..bf253bc7d 100644 --- a/apps/sim/providers/groq/index.ts +++ b/apps/sim/providers/groq/index.ts @@ -320,6 +320,7 @@ export const groqProvider: ProviderConfig = { currentMessages.push({ role: 'tool', tool_call_id: toolCall.id, + name: toolName, content: JSON.stringify(resultContent), }) } diff --git a/apps/sim/providers/mistral/index.ts b/apps/sim/providers/mistral/index.ts index f99a3e210..fb3e701ed 100644 --- a/apps/sim/providers/mistral/index.ts +++ b/apps/sim/providers/mistral/index.ts @@ -383,6 +383,7 @@ export const mistralProvider: ProviderConfig = { currentMessages.push({ role: 'tool', tool_call_id: toolCall.id, + name: toolName, content: JSON.stringify(resultContent), }) }