fix(providers): fix xAI response format and tool calls not working when used together (#455)

* fix xAI response format and tool calls not working when used together

* remove extraneous comments
This commit is contained in:
Waleed Latif
2025-06-03 15:16:50 -07:00
committed by GitHub
parent f3a405364f
commit fc7171b038

View File

@@ -44,14 +44,20 @@ export const xAIProvider: ProviderConfig = {
throw new Error('API key is required for xAI')
}
// Initialize OpenAI client for xAI
const xai = new OpenAI({
apiKey: request.apiKey,
baseURL: 'https://api.x.ai/v1',
})
// Prepare messages
const allMessages = []
logger.info('XAI Provider - Initial request configuration:', {
hasTools: !!request.tools?.length,
toolCount: request.tools?.length || 0,
hasResponseFormat: !!request.responseFormat,
model: request.model || 'grok-3-latest',
streaming: !!request.stream,
})
const allMessages: any[] = []
if (request.systemPrompt) {
allMessages.push({
@@ -83,34 +89,41 @@ export const xAIProvider: ProviderConfig = {
}))
: undefined
// Build the request payload
const payload: any = {
// Log tools and response format conflict detection
if (tools?.length && request.responseFormat) {
logger.warn(
'XAI Provider - Detected both tools and response format. Using tools first, then response format for final response.'
)
}
// Build the base request payload
const basePayload: any = {
model: request.model || 'grok-3-latest',
messages: allMessages,
}
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
if (request.temperature !== undefined) basePayload.temperature = request.temperature
if (request.maxTokens !== undefined) basePayload.max_tokens = request.maxTokens
if (request.responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: request.responseFormat.name || 'structured_response',
schema: request.responseFormat.schema || request.responseFormat,
strict: request.responseFormat.strict !== false,
},
// Function to create response format configuration
const createResponseFormatPayload = (messages: any[] = allMessages) => {
const payload = {
...basePayload,
messages,
}
if (allMessages.length > 0 && allMessages[0].role === 'system') {
allMessages[0].content = `${allMessages[0].content}\n\nYou MUST respond with a valid JSON object. DO NOT include any other text, explanations, or markdown formatting in your response - ONLY the JSON object.`
} else {
allMessages.unshift({
role: 'system',
content:
'You MUST respond with a valid JSON object. DO NOT include any other text, explanations, or markdown formatting in your response - ONLY the JSON object.',
})
if (request.responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: request.responseFormat.name || 'structured_response',
schema: request.responseFormat.schema || request.responseFormat,
strict: request.responseFormat.strict !== false,
},
}
}
return payload
}
// Handle tools and tool usage control
@@ -118,42 +131,29 @@ export const xAIProvider: ProviderConfig = {
if (tools?.length) {
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'xai')
const { tools: filteredTools, toolChoice } = preparedTools
if (filteredTools?.length && toolChoice) {
payload.tools = filteredTools
payload.tool_choice = toolChoice
logger.info('XAI request configuration:', {
toolCount: filteredTools.length,
toolChoice:
typeof toolChoice === 'string'
? toolChoice
: toolChoice.type === 'function'
? `force:${toolChoice.function.name}`
: toolChoice.type === 'tool'
? `force:${toolChoice.name}`
: toolChoice.type === 'any'
? `force:${toolChoice.any?.name || 'unknown'}`
: 'unknown',
model: request.model || 'grok-3-latest',
})
}
}
// EARLY STREAMING: if caller requested streaming and there are no tools to execute,
// we can directly stream the completion.
// we can directly stream the completion with response format if needed
if (request.stream && (!tools || tools.length === 0)) {
logger.info('Using streaming response for XAI request (no tools)')
logger.info('XAI Provider - Using direct streaming (no tools)')
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
const streamResponse = await xai.chat.completions.create({
...payload,
stream: true,
})
// Use response format payload if needed, otherwise use base payload
const streamingPayload = request.responseFormat
? createResponseFormatPayload()
: { ...basePayload, stream: true }
if (!request.responseFormat) {
streamingPayload.stream = true
} else {
streamingPayload.stream = true
}
const streamResponse = await xai.chat.completions.create(streamingPayload)
// Start collecting token usage
const tokenUsage = {
@@ -217,14 +217,29 @@ export const xAIProvider: ProviderConfig = {
// Make the initial API request
const initialCallTime = Date.now()
// For the initial request with tools, we NEVER include response_format
// This is the key fix: tools and response_format cannot be used together with xAI
const initialPayload = { ...basePayload }
// Track the original tool_choice for forced tool tracking
const originalToolChoice = payload.tool_choice
let originalToolChoice: any
// Track forced tools and their usage
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
let currentResponse = await xai.chat.completions.create(payload)
if (preparedTools?.tools?.length && preparedTools.toolChoice) {
const { tools: filteredTools, toolChoice } = preparedTools
initialPayload.tools = filteredTools
initialPayload.tool_choice = toolChoice
originalToolChoice = toolChoice
} else if (request.responseFormat) {
// Only add response format if there are no tools
const responseFormatPayload = createResponseFormatPayload()
Object.assign(initialPayload, responseFormatPayload)
}
let currentResponse = await xai.chat.completions.create(initialPayload)
const firstResponseTime = Date.now() - initialCallTime
let content = currentResponse.choices[0]?.message?.content || ''
@@ -278,7 +293,9 @@ export const xAIProvider: ProviderConfig = {
}
// Check if a forced tool was used in the first response
checkForForcedToolUsage(currentResponse, originalToolChoice)
if (originalToolChoice) {
checkForForcedToolUsage(currentResponse, originalToolChoice)
}
try {
while (iterationCount < MAX_ITERATIONS) {
@@ -297,7 +314,10 @@ export const xAIProvider: ProviderConfig = {
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
if (!tool) {
logger.warn('XAI Provider - Tool not found:', { toolName })
continue
}
const toolCallStartTime = Date.now()
const mergedArgs = {
@@ -309,7 +329,13 @@ export const xAIProvider: ProviderConfig = {
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
if (!result.success) continue
if (!result.success) {
logger.warn('XAI Provider - Tool execution failed:', {
toolName,
error: result.error,
})
continue
}
// Add to time segments
timeSegments.push({
@@ -351,7 +377,10 @@ export const xAIProvider: ProviderConfig = {
content: JSON.stringify(result.output),
})
} catch (error) {
logger.error('Error processing tool call:', { error })
logger.error('XAI Provider - Error processing tool call:', {
error: error instanceof Error ? error.message : String(error),
toolCall: toolCall.function.name,
})
}
}
@@ -359,10 +388,8 @@ export const xAIProvider: ProviderConfig = {
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
const nextPayload = {
...payload,
messages: currentMessages,
}
// After tool calls, create next payload based on whether we need more tools or final response
let nextPayload: any
// Update tool_choice based on which forced tools have been used
if (
@@ -374,16 +401,41 @@ export const xAIProvider: ProviderConfig = {
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
if (remainingTools.length > 0) {
// Force the next tool
nextPayload.tool_choice = {
type: 'function',
function: { name: remainingTools[0] },
// Force the next tool - continue with tools, no response format
nextPayload = {
...basePayload,
messages: currentMessages,
tools: preparedTools?.tools,
tool_choice: {
type: 'function',
function: { name: remainingTools[0] },
},
}
logger.info(`Forcing next tool: ${remainingTools[0]}`)
} else {
// All forced tools have been used, switch to auto
nextPayload.tool_choice = 'auto'
logger.info('All forced tools have been used, switching to auto tool_choice')
// All forced tools have been used, check if we need response format for final response
if (request.responseFormat) {
nextPayload = createResponseFormatPayload(currentMessages)
} else {
nextPayload = {
...basePayload,
messages: currentMessages,
tool_choice: 'auto',
tools: preparedTools?.tools,
}
}
}
} else {
// Normal tool processing - check if this might be the final response
if (request.responseFormat) {
// Use response format for what might be the final response
nextPayload = createResponseFormatPayload(currentMessages)
} else {
nextPayload = {
...basePayload,
messages: currentMessages,
tools: preparedTools?.tools,
tool_choice: 'auto',
}
}
}
@@ -393,7 +445,9 @@ export const xAIProvider: ProviderConfig = {
currentResponse = await xai.chat.completions.create(nextPayload)
// Check if any forced tools were used in this response
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') {
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
}
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
@@ -423,23 +477,35 @@ export const xAIProvider: ProviderConfig = {
iterationCount++
}
} catch (error) {
logger.error('Error in xAI request:', { error })
logger.error('XAI Provider - Error in tool processing loop:', {
error: error instanceof Error ? error.message : String(error),
iterationCount,
})
}
// After all tool processing complete, if streaming was requested and we have messages, use streaming for the final response
if (request.stream && iterationCount > 0) {
logger.info('Using streaming for final XAI response after tool calls')
// For final streaming response, choose between tools (auto) or response_format (never both)
let finalStreamingPayload: any
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
// This prevents the API from trying to force tool usage again in the final streaming response
const streamingPayload = {
...payload,
messages: currentMessages,
tool_choice: 'auto', // Always use 'auto' for the streaming response after tool calls
stream: true,
if (request.responseFormat) {
// Use response format, no tools
finalStreamingPayload = {
...createResponseFormatPayload(currentMessages),
stream: true,
}
} else {
// Use tools with auto choice
finalStreamingPayload = {
...basePayload,
messages: currentMessages,
tool_choice: 'auto',
tools: preparedTools?.tools,
stream: true,
}
}
const streamResponse = await xai.chat.completions.create(streamingPayload)
const streamResponse = await xai.chat.completions.create(finalStreamingPayload)
// Create a StreamingExecution response with all collected data
const streamingResult = {
@@ -498,6 +564,14 @@ export const xAIProvider: ProviderConfig = {
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
logger.info('XAI Provider - Request completed:', {
totalDuration,
iterationCount: iterationCount + 1,
toolCallCount: toolCalls.length,
hasContent: !!content,
contentLength: content?.length || 0,
})
return {
content,
model: request.model,
@@ -521,7 +595,12 @@ export const xAIProvider: ProviderConfig = {
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
logger.error('Error in xAI request:', { error, duration: totalDuration })
logger.error('XAI Provider - Request failed:', {
error: error instanceof Error ? error.message : String(error),
duration: totalDuration,
hasTools: !!tools?.length,
hasResponseFormat: !!request.responseFormat,
})
// Create a new error with timing information
const enhancedError = new Error(error instanceof Error ? error.message : String(error))