fix(providers): correct tool calling message format across all providers

waleed
2026-02-04 01:20:49 -08:00
parent 0bc245b7a9
commit c86e779160
9 changed files with 322 additions and 63 deletions

View File

@@ -155,7 +155,7 @@ export class McpClient {
return result.tools.map((tool: Tool) => ({
name: tool.name,
description: tool.description,
inputSchema: tool.inputSchema,
inputSchema: tool.inputSchema as McpTool['inputSchema'],
serverId: this.config.id,
serverName: this.config.name,
}))

View File

@@ -57,14 +57,29 @@ export interface McpSecurityPolicy {
auditLevel: 'none' | 'basic' | 'detailed'
}
/**
* JSON Schema property definition for tool parameters.
* Follows the JSON Schema specification, with support for descriptions.
*/
export interface McpToolSchemaProperty {
type: string
description?: string
items?: McpToolSchemaProperty
properties?: Record<string, McpToolSchemaProperty>
required?: string[]
enum?: Array<string | number | boolean>
default?: unknown
}
/**
* JSON Schema for tool input parameters.
* Aligns with the MCP SDK's Tool.inputSchema structure.
*/
export interface McpToolSchema {
type: 'object'
properties?: Record<string, unknown>
properties?: Record<string, McpToolSchemaProperty>
required?: string[]
description?: string
}
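
For illustration, a schema that satisfies the tightened types might look like the following; the tool and its fields are hypothetical, not taken from the codebase:

const exampleSchema: McpToolSchema = {
  type: 'object',
  description: 'Arguments for a hypothetical search tool',
  properties: {
    query: { type: 'string', description: 'Free-text search query' },
    limit: { type: 'number', description: 'Maximum number of results', default: 10 },
    tags: {
      type: 'array',
      description: 'Optional tag filters',
      items: { type: 'string', enum: ['docs', 'code', 'issues'] },
    },
  },
  required: ['query'],
}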
/**

View File

@@ -6,7 +6,6 @@ import { MAX_TOOL_ITERATIONS } from '@/providers'
import {
checkForForcedToolUsage,
createReadableStreamFromAnthropicStream,
generateToolUseId,
} from '@/providers/anthropic/utils'
import {
getMaxOutputTokensForModel,
@@ -433,11 +432,32 @@ export const anthropicProvider: ProviderConfig = {
const executionResults = await Promise.allSettled(toolExecutionPromises)
// Collect all tool_use and tool_result blocks for batching
const toolUseBlocks: Array<{
type: 'tool_use'
id: string
name: string
input: Record<string, unknown>
}> = []
const toolResultBlocks: Array<{
type: 'tool_result'
tool_use_id: string
content: string
}> = []
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolName, toolArgs, toolParams, result, startTime, endTime, duration } =
settledResult.value
const {
toolUse,
toolName,
toolArgs,
toolParams,
result,
startTime,
endTime,
duration,
} = settledResult.value
timeSegments.push({
type: 'tool',
@@ -447,7 +467,7 @@ export const anthropicProvider: ProviderConfig = {
duration: duration,
})
let resultContent: any
let resultContent: unknown
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
@@ -469,29 +489,34 @@ export const anthropicProvider: ProviderConfig = {
success: result.success,
})
const toolUseId = generateToolUseId(toolName)
currentMessages.push({
role: 'assistant',
content: [
{
type: 'tool_use',
id: toolUseId,
name: toolName,
input: toolArgs,
} as any,
],
// Add to batched arrays using the ORIGINAL ID from Claude's response
toolUseBlocks.push({
type: 'tool_use',
id: toolUse.id,
name: toolName,
input: toolArgs,
})
toolResultBlocks.push({
type: 'tool_result',
tool_use_id: toolUse.id,
content: JSON.stringify(resultContent),
})
}
// Add ONE assistant message with ALL tool_use blocks
if (toolUseBlocks.length > 0) {
currentMessages.push({
role: 'assistant',
content: toolUseBlocks as unknown as Anthropic.Messages.ContentBlock[],
})
}
// Add ONE user message with ALL tool_result blocks
if (toolResultBlocks.length > 0) {
currentMessages.push({
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: toolUseId,
content: JSON.stringify(resultContent),
} as any,
],
content: toolResultBlocks as unknown as Anthropic.Messages.ContentBlockParam[],
})
}
@@ -777,6 +802,8 @@ export const anthropicProvider: ProviderConfig = {
const toolCallStartTime = Date.now()
const toolName = toolUse.name
const toolArgs = toolUse.input as Record<string, any>
// Preserve the original tool_use ID from Claude's response
const toolUseId = toolUse.id
try {
const tool = request.tools?.find((t) => t.id === toolName)
@@ -787,6 +814,7 @@ export const anthropicProvider: ProviderConfig = {
const toolCallEndTime = Date.now()
return {
toolUseId,
toolName,
toolArgs,
toolParams,
@@ -800,6 +828,7 @@ export const anthropicProvider: ProviderConfig = {
logger.error('Error processing tool call:', { error, toolName })
return {
toolUseId,
toolName,
toolArgs,
toolParams: {},
@@ -817,11 +846,32 @@ export const anthropicProvider: ProviderConfig = {
const executionResults = await Promise.allSettled(toolExecutionPromises)
// Collect all tool_use and tool_result blocks for batching
const toolUseBlocks: Array<{
type: 'tool_use'
id: string
name: string
input: Record<string, unknown>
}> = []
const toolResultBlocks: Array<{
type: 'tool_result'
tool_use_id: string
content: string
}> = []
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolName, toolArgs, toolParams, result, startTime, endTime, duration } =
settledResult.value
const {
toolUseId,
toolName,
toolArgs,
toolParams,
result,
startTime,
endTime,
duration,
} = settledResult.value
timeSegments.push({
type: 'tool',
@@ -831,7 +881,7 @@ export const anthropicProvider: ProviderConfig = {
duration: duration,
})
let resultContent: any
let resultContent: unknown
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
@@ -853,29 +903,34 @@ export const anthropicProvider: ProviderConfig = {
success: result.success,
})
const toolUseId = generateToolUseId(toolName)
currentMessages.push({
role: 'assistant',
content: [
{
type: 'tool_use',
id: toolUseId,
name: toolName,
input: toolArgs,
} as any,
],
// Add to batched arrays using the ORIGINAL ID from Claude's response
toolUseBlocks.push({
type: 'tool_use',
id: toolUseId,
name: toolName,
input: toolArgs,
})
toolResultBlocks.push({
type: 'tool_result',
tool_use_id: toolUseId,
content: JSON.stringify(resultContent),
})
}
// Add ONE assistant message with ALL tool_use blocks
if (toolUseBlocks.length > 0) {
currentMessages.push({
role: 'assistant',
content: toolUseBlocks as unknown as Anthropic.Messages.ContentBlock[],
})
}
// Add ONE user message with ALL tool_result blocks
if (toolResultBlocks.length > 0) {
currentMessages.push({
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: toolUseId,
content: JSON.stringify(resultContent),
} as any,
],
content: toolResultBlocks as unknown as Anthropic.Messages.ContentBlockParam[],
})
}
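
To illustrate the shape this batching produces (IDs, tool names, and payloads below are placeholders, not real response data), a Claude turn with two parallel tool calls now yields exactly one assistant message and one user message, with each tool_result keyed to the original tool_use ID:

const assistantMessage = {
  role: 'assistant' as const,
  content: [
    { type: 'tool_use', id: 'toolu_abc123', name: 'get_weather', input: { city: 'Paris' } },
    { type: 'tool_use', id: 'toolu_def456', name: 'get_time', input: { tz: 'Europe/Paris' } },
  ],
}
const userMessage = {
  role: 'user' as const,
  content: [
    { type: 'tool_result', tool_use_id: 'toolu_abc123', content: '{"tempC":18}' },
    { type: 'tool_result', tool_use_id: 'toolu_def456', content: '{"time":"14:05"}' },
  ],
}

The previous code emitted one assistant/user pair per tool call and regenerated the IDs; the API expects every tool_result to reference a tool_use id from the immediately preceding assistant message, which is why the original IDs from Claude's response have to be preserved.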
@@ -1061,7 +1116,7 @@ export const anthropicProvider: ProviderConfig = {
startTime: tc.startTime,
endTime: tc.endTime,
duration: tc.duration,
result: tc.result,
result: tc.result as Record<string, unknown> | undefined,
}))
: undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,

View File

@@ -67,8 +67,17 @@ export function checkForForcedToolUsage(
return null
}
/**
* Generates a unique tool use ID for Bedrock.
* AWS Bedrock requires toolUseId to be 1-64 characters matching the pattern [a-zA-Z0-9_-]+.
*/
export function generateToolUseId(toolName: string): string {
return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}`
const timestamp = Date.now().toString(36) // Base36 timestamp (8 chars)
const random = Math.random().toString(36).substring(2, 7) // 5 random chars
const suffix = `-${timestamp}-${random}` // ~15 chars
const maxNameLength = 64 - suffix.length
const truncatedName = toolName.substring(0, maxNameLength).replace(/[^a-zA-Z0-9_-]/g, '_')
return `${truncatedName}${suffix}`
}
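
A quick sanity sketch of the budget (not part of the diff): the suffix is '-' plus an 8-character base36 timestamp plus '-' plus 5 random characters, currently 15 characters, leaving up to 49 characters of sanitized tool name, so the result always fits Bedrock's constraints:

// Hypothetical usage; the long tool name is made up for the example.
const id = generateToolUseId('browser.navigate to page with a very long descriptive tool name')
console.assert(id.length <= 64 && /^[a-zA-Z0-9_-]+$/.test(id))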
/**

View File

@@ -76,7 +76,7 @@ export const deepseekProvider: ProviderConfig = {
: undefined
const payload: any = {
model: 'deepseek-chat',
model: request.model,
messages: allMessages,
}

View File

@@ -20,7 +20,7 @@ import {
convertUsageMetadata,
createReadableStreamFromGeminiStream,
ensureStructResponse,
extractFunctionCallPart,
extractAllFunctionCallParts,
extractTextContent,
mapToThinkingLevel,
} from '@/providers/google/utils'
@@ -176,6 +176,174 @@ async function executeToolCall(
}
}
/**
* Executes multiple tool calls in parallel and updates state.
* Per Gemini docs, all function calls from a single response should be executed
* together, with one model message containing all function calls and one user
* message containing all function responses.
*/
async function executeToolCallsBatch(
functionCallParts: Part[],
request: ProviderRequest,
state: ExecutionState,
forcedTools: string[],
logger: ReturnType<typeof createLogger>
): Promise<{ success: boolean; state: ExecutionState }> {
if (functionCallParts.length === 0) {
return { success: false, state }
}
const batchStartTime = Date.now()
// Execute all tool calls in parallel
const executionPromises = functionCallParts.map(async (part) => {
const toolCallStartTime = Date.now()
const functionCall = part.functionCall!
const toolName = functionCall.name ?? ''
const args = (functionCall.args ?? {}) as Record<string, unknown>
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) {
logger.warn(`Tool ${toolName} not found in registry, skipping`)
return {
success: false,
part,
toolName,
args,
resultContent: { error: true, message: `Tool ${toolName} not found`, tool: toolName },
toolParams: {},
startTime: toolCallStartTime,
endTime: Date.now(),
duration: Date.now() - toolCallStartTime,
}
}
try {
const { toolParams, executionParams } = prepareToolExecution(tool, args, request)
const result = await executeTool(toolName, executionParams)
const toolCallEndTime = Date.now()
const duration = toolCallEndTime - toolCallStartTime
const resultContent: Record<string, unknown> = result.success
? ensureStructResponse(result.output)
: { error: true, message: result.error || 'Tool execution failed', tool: toolName }
return {
success: result.success,
part,
toolName,
args,
resultContent,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing function call:', {
error: error instanceof Error ? error.message : String(error),
functionName: toolName,
})
return {
success: false,
part,
toolName,
args,
resultContent: {
error: true,
message: error instanceof Error ? error.message : 'Tool execution failed',
tool: toolName,
},
toolParams: {},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const results = await Promise.all(executionPromises)
// Check if at least one tool was found (not all failed due to missing tools)
const hasValidResults = results.some((r) => r.result !== undefined)
if (!hasValidResults && results.every((r) => !r.success)) {
return { success: false, state }
}
// Build batched messages per Gemini spec:
// ONE model message with ALL function call parts
// ONE user message with ALL function responses
const modelParts: Part[] = results.map((r) => r.part)
const userParts: Part[] = results.map((r) => ({
functionResponse: {
name: r.toolName,
response: r.resultContent,
},
}))
const updatedContents: Content[] = [
...state.contents,
{ role: 'model', parts: modelParts },
{ role: 'user', parts: userParts },
]
// Collect all tool calls and results
const newToolCalls: FunctionCallResponse[] = []
const newToolResults: Record<string, unknown>[] = []
const newTimeSegments: ExecutionState['timeSegments'] = []
let totalToolsTime = 0
for (const r of results) {
newToolCalls.push({
name: r.toolName,
arguments: r.toolParams,
startTime: new Date(r.startTime).toISOString(),
endTime: new Date(r.endTime).toISOString(),
duration: r.duration,
result: r.resultContent,
})
if (r.success && r.result?.output) {
newToolResults.push(r.result.output as Record<string, unknown>)
}
newTimeSegments.push({
type: 'tool',
name: r.toolName,
startTime: r.startTime,
endTime: r.endTime,
duration: r.duration,
})
totalToolsTime += r.duration
}
// Check forced tool usage for all executed tools
const executedToolsInfo = results.map((r) => ({ name: r.toolName, args: r.args }))
const forcedToolCheck = checkForForcedToolUsage(
executedToolsInfo,
state.currentToolConfig,
forcedTools,
state.usedForcedTools
)
return {
success: true,
state: {
...state,
contents: updatedContents,
toolCalls: [...state.toolCalls, ...newToolCalls],
toolResults: [...state.toolResults, ...newToolResults],
toolsTime: state.toolsTime + totalToolsTime,
timeSegments: [...state.timeSegments, ...newTimeSegments],
usedForcedTools: forcedToolCheck?.usedForcedTools ?? state.usedForcedTools,
currentToolConfig: forcedToolCheck?.nextToolConfig ?? state.currentToolConfig,
},
}
}
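
For reference, the two Content entries appended by this batching have roughly the following shape for a two-call turn (tool names and payloads are illustrative):

// Illustrative only; the real code re-uses the original Parts from the response,
// so fields such as thoughtSignature are preserved on the model side.
const appended: Content[] = [
  {
    role: 'model',
    parts: [
      { functionCall: { name: 'get_weather', args: { city: 'Paris' } } },
      { functionCall: { name: 'get_time', args: { tz: 'Europe/Paris' } } },
    ],
  },
  {
    role: 'user',
    parts: [
      { functionResponse: { name: 'get_weather', response: { tempC: 18 } } },
      { functionResponse: { name: 'get_time', response: { time: '14:05' } } },
    ],
  },
]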
/**
* Updates state with model response metadata
*/
@@ -506,27 +674,25 @@ export async function executeGeminiRequest(
// Tool execution loop
const functionCalls = response.functionCalls
if (functionCalls?.length) {
logger.info(`Received function call from Gemini: ${functionCalls[0].name}`)
const functionNames = functionCalls.map((fc) => fc.name).join(', ')
logger.info(`Received ${functionCalls.length} function call(s) from Gemini: ${functionNames}`)
while (state.iterationCount < MAX_TOOL_ITERATIONS) {
const functionCallPart = extractFunctionCallPart(currentResponse.candidates?.[0])
if (!functionCallPart?.functionCall) {
// Extract ALL function call parts from the response (Gemini can return multiple)
const functionCallParts = extractAllFunctionCallParts(currentResponse.candidates?.[0])
if (functionCallParts.length === 0) {
content = extractTextContent(currentResponse.candidates?.[0])
break
}
const functionCall: ParsedFunctionCall = {
name: functionCallPart.functionCall.name ?? '',
args: (functionCallPart.functionCall.args ?? {}) as Record<string, unknown>,
}
const callNames = functionCallParts.map((p) => p.functionCall?.name ?? 'unknown').join(', ')
logger.info(
`Processing function call: ${functionCall.name} (iteration ${state.iterationCount + 1})`
`Processing ${functionCallParts.length} function call(s): ${callNames} (iteration ${state.iterationCount + 1})`
)
const { success, state: updatedState } = await executeToolCall(
functionCallPart,
functionCall,
// Execute ALL function calls in this batch
const { success, state: updatedState } = await executeToolCallsBatch(
functionCallParts,
request,
state,
forcedTools,

View File

@@ -109,6 +109,7 @@ export function extractFunctionCall(candidate: Candidate | undefined): ParsedFun
/**
* Extracts the full Part containing the function call (preserves thoughtSignature)
* @deprecated Use extractAllFunctionCallParts for proper multi-tool handling
*/
export function extractFunctionCallPart(candidate: Candidate | undefined): Part | null {
if (!candidate?.content?.parts) return null
@@ -122,6 +123,17 @@ export function extractFunctionCallPart(candidate: Candidate | undefined): Part
return null
}
/**
* Extracts ALL Parts containing function calls from a candidate.
* Gemini can return multiple function calls in a single response,
* and all should be executed before continuing the conversation.
*/
export function extractAllFunctionCallParts(candidate: Candidate | undefined): Part[] {
if (!candidate?.content?.parts) return []
return candidate.content.parts.filter((part) => part.functionCall)
}
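
A minimal usage sketch, assuming `response` is the GenerateContentResponse already in scope in the request loop:

// The deprecated extractFunctionCallPart returned only the first functionCall Part,
// silently dropping any parallel tool calls in the same response.
const callParts = extractAllFunctionCallParts(response.candidates?.[0])
const names = callParts.map((p) => p.functionCall?.name) // e.g. ['get_weather', 'get_time']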
/**
* Converts usage metadata from SDK response to our format.
* Per Gemini docs, total = promptTokenCount + candidatesTokenCount + toolUsePromptTokenCount + thoughtsTokenCount

View File

@@ -320,6 +320,7 @@ export const groqProvider: ProviderConfig = {
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
name: toolName,
content: JSON.stringify(resultContent),
})
}

View File

@@ -383,6 +383,7 @@ export const mistralProvider: ProviderConfig = {
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
name: toolName,
content: JSON.stringify(resultContent),
})
}