mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-09 23:17:59 -05:00
fix(providers): fixed xai response format + tool calls not working when used together (#455)
* fix xai response format + tool calls not working when used together
* removed extraneous comments
This commit is contained in:
@@ -44,14 +44,20 @@ export const xAIProvider: ProviderConfig = {
|
||||
throw new Error('API key is required for xAI')
|
||||
}
|
||||
|
||||
// Initialize OpenAI client for xAI
|
||||
const xai = new OpenAI({
|
||||
apiKey: request.apiKey,
|
||||
baseURL: 'https://api.x.ai/v1',
|
||||
})
|
||||
|
||||
// Prepare messages
|
||||
const allMessages = []
|
||||
logger.info('XAI Provider - Initial request configuration:', {
|
||||
hasTools: !!request.tools?.length,
|
||||
toolCount: request.tools?.length || 0,
|
||||
hasResponseFormat: !!request.responseFormat,
|
||||
model: request.model || 'grok-3-latest',
|
||||
streaming: !!request.stream,
|
||||
})
|
||||
|
||||
const allMessages: any[] = []
|
||||
|
||||
if (request.systemPrompt) {
|
||||
allMessages.push({
|
||||
@@ -83,34 +89,41 @@ export const xAIProvider: ProviderConfig = {
|
||||
}))
|
||||
: undefined
|
||||
|
||||
// Build the request payload
|
||||
const payload: any = {
|
||||
// Log tools and response format conflict detection
|
||||
if (tools?.length && request.responseFormat) {
|
||||
logger.warn(
|
||||
'XAI Provider - Detected both tools and response format. Using tools first, then response format for final response.'
|
||||
)
|
||||
}
|
||||
|
||||
// Build the base request payload
|
||||
const basePayload: any = {
|
||||
model: request.model || 'grok-3-latest',
|
||||
messages: allMessages,
|
||||
}
|
||||
|
||||
if (request.temperature !== undefined) payload.temperature = request.temperature
|
||||
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
|
||||
if (request.temperature !== undefined) basePayload.temperature = request.temperature
|
||||
if (request.maxTokens !== undefined) basePayload.max_tokens = request.maxTokens
|
||||
|
||||
if (request.responseFormat) {
|
||||
payload.response_format = {
|
||||
type: 'json_schema',
|
||||
json_schema: {
|
||||
name: request.responseFormat.name || 'structured_response',
|
||||
schema: request.responseFormat.schema || request.responseFormat,
|
||||
strict: request.responseFormat.strict !== false,
|
||||
},
|
||||
// Function to create response format configuration
|
||||
const createResponseFormatPayload = (messages: any[] = allMessages) => {
|
||||
const payload = {
|
||||
...basePayload,
|
||||
messages,
|
||||
}
|
||||
|
||||
if (allMessages.length > 0 && allMessages[0].role === 'system') {
|
||||
allMessages[0].content = `${allMessages[0].content}\n\nYou MUST respond with a valid JSON object. DO NOT include any other text, explanations, or markdown formatting in your response - ONLY the JSON object.`
|
||||
} else {
|
||||
allMessages.unshift({
|
||||
role: 'system',
|
||||
content:
|
||||
'You MUST respond with a valid JSON object. DO NOT include any other text, explanations, or markdown formatting in your response - ONLY the JSON object.',
|
||||
})
|
||||
if (request.responseFormat) {
|
||||
payload.response_format = {
|
||||
type: 'json_schema',
|
||||
json_schema: {
|
||||
name: request.responseFormat.name || 'structured_response',
|
||||
schema: request.responseFormat.schema || request.responseFormat,
|
||||
strict: request.responseFormat.strict !== false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// Handle tools and tool usage control
|
||||
@@ -118,42 +131,29 @@ export const xAIProvider: ProviderConfig = {
|
||||
|
||||
if (tools?.length) {
|
||||
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'xai')
|
||||
const { tools: filteredTools, toolChoice } = preparedTools
|
||||
|
||||
if (filteredTools?.length && toolChoice) {
|
||||
payload.tools = filteredTools
|
||||
payload.tool_choice = toolChoice
|
||||
|
||||
logger.info('XAI request configuration:', {
|
||||
toolCount: filteredTools.length,
|
||||
toolChoice:
|
||||
typeof toolChoice === 'string'
|
||||
? toolChoice
|
||||
: toolChoice.type === 'function'
|
||||
? `force:${toolChoice.function.name}`
|
||||
: toolChoice.type === 'tool'
|
||||
? `force:${toolChoice.name}`
|
||||
: toolChoice.type === 'any'
|
||||
? `force:${toolChoice.any?.name || 'unknown'}`
|
||||
: 'unknown',
|
||||
model: request.model || 'grok-3-latest',
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// EARLY STREAMING: if caller requested streaming and there are no tools to execute,
|
||||
// we can directly stream the completion.
|
||||
// we can directly stream the completion with response format if needed
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info('Using streaming response for XAI request (no tools)')
|
||||
logger.info('XAI Provider - Using direct streaming (no tools)')
|
||||
|
||||
// Start execution timer for the entire provider execution
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
const streamResponse = await xai.chat.completions.create({
|
||||
...payload,
|
||||
stream: true,
|
||||
})
|
||||
// Use response format payload if needed, otherwise use base payload
|
||||
const streamingPayload = request.responseFormat
|
||||
? createResponseFormatPayload()
|
||||
: { ...basePayload, stream: true }
|
||||
|
||||
if (!request.responseFormat) {
|
||||
streamingPayload.stream = true
|
||||
} else {
|
||||
streamingPayload.stream = true
|
||||
}
|
||||
|
||||
const streamResponse = await xai.chat.completions.create(streamingPayload)
|
||||
|
||||
// Start collecting token usage
|
||||
const tokenUsage = {
|
||||
@@ -217,14 +217,29 @@ export const xAIProvider: ProviderConfig = {
|
||||
// Make the initial API request
|
||||
const initialCallTime = Date.now()
|
||||
|
||||
// For the initial request with tools, we NEVER include response_format
|
||||
// This is the key fix: tools and response_format cannot be used together with xAI
|
||||
const initialPayload = { ...basePayload }
|
||||
|
||||
// Track the original tool_choice for forced tool tracking
|
||||
const originalToolChoice = payload.tool_choice
|
||||
let originalToolChoice: any
|
||||
|
||||
// Track forced tools and their usage
|
||||
const forcedTools = preparedTools?.forcedTools || []
|
||||
let usedForcedTools: string[] = []
|
||||
|
||||
let currentResponse = await xai.chat.completions.create(payload)
|
||||
if (preparedTools?.tools?.length && preparedTools.toolChoice) {
|
||||
const { tools: filteredTools, toolChoice } = preparedTools
|
||||
initialPayload.tools = filteredTools
|
||||
initialPayload.tool_choice = toolChoice
|
||||
originalToolChoice = toolChoice
|
||||
} else if (request.responseFormat) {
|
||||
// Only add response format if there are no tools
|
||||
const responseFormatPayload = createResponseFormatPayload()
|
||||
Object.assign(initialPayload, responseFormatPayload)
|
||||
}
|
||||
|
||||
let currentResponse = await xai.chat.completions.create(initialPayload)
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
let content = currentResponse.choices[0]?.message?.content || ''
|
||||
@@ -278,7 +293,9 @@ export const xAIProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
// Check if a forced tool was used in the first response
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
if (originalToolChoice) {
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
}
|
||||
|
||||
try {
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
@@ -297,7 +314,10 @@ export const xAIProvider: ProviderConfig = {
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) continue
|
||||
if (!tool) {
|
||||
logger.warn('XAI Provider - Tool not found:', { toolName })
|
||||
continue
|
||||
}
|
||||
|
||||
const toolCallStartTime = Date.now()
|
||||
const mergedArgs = {
|
||||
@@ -309,7 +329,13 @@ export const xAIProvider: ProviderConfig = {
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
if (!result.success) continue
|
||||
if (!result.success) {
|
||||
logger.warn('XAI Provider - Tool execution failed:', {
|
||||
toolName,
|
||||
error: result.error,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
@@ -351,7 +377,10 @@ export const xAIProvider: ProviderConfig = {
|
||||
content: JSON.stringify(result.output),
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error processing tool call:', { error })
|
||||
logger.error('XAI Provider - Error processing tool call:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
toolCall: toolCall.function.name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,10 +388,8 @@ export const xAIProvider: ProviderConfig = {
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
}
|
||||
// After tool calls, create next payload based on whether we need more tools or final response
|
||||
let nextPayload: any
|
||||
|
||||
// Update tool_choice based on which forced tools have been used
|
||||
if (
|
||||
@@ -374,16 +401,41 @@ export const xAIProvider: ProviderConfig = {
|
||||
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
|
||||
|
||||
if (remainingTools.length > 0) {
|
||||
// Force the next tool
|
||||
nextPayload.tool_choice = {
|
||||
type: 'function',
|
||||
function: { name: remainingTools[0] },
|
||||
// Force the next tool - continue with tools, no response format
|
||||
nextPayload = {
|
||||
...basePayload,
|
||||
messages: currentMessages,
|
||||
tools: preparedTools?.tools,
|
||||
tool_choice: {
|
||||
type: 'function',
|
||||
function: { name: remainingTools[0] },
|
||||
},
|
||||
}
|
||||
logger.info(`Forcing next tool: ${remainingTools[0]}`)
|
||||
} else {
|
||||
// All forced tools have been used, switch to auto
|
||||
nextPayload.tool_choice = 'auto'
|
||||
logger.info('All forced tools have been used, switching to auto tool_choice')
|
||||
// All forced tools have been used, check if we need response format for final response
|
||||
if (request.responseFormat) {
|
||||
nextPayload = createResponseFormatPayload(currentMessages)
|
||||
} else {
|
||||
nextPayload = {
|
||||
...basePayload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto',
|
||||
tools: preparedTools?.tools,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Normal tool processing - check if this might be the final response
|
||||
if (request.responseFormat) {
|
||||
// Use response format for what might be the final response
|
||||
nextPayload = createResponseFormatPayload(currentMessages)
|
||||
} else {
|
||||
nextPayload = {
|
||||
...basePayload,
|
||||
messages: currentMessages,
|
||||
tools: preparedTools?.tools,
|
||||
tool_choice: 'auto',
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,7 +445,9 @@ export const xAIProvider: ProviderConfig = {
|
||||
currentResponse = await xai.chat.completions.create(nextPayload)
|
||||
|
||||
// Check if any forced tools were used in this response
|
||||
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
|
||||
if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') {
|
||||
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
|
||||
}
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
@@ -423,23 +477,35 @@ export const xAIProvider: ProviderConfig = {
|
||||
iterationCount++
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error in xAI request:', { error })
|
||||
logger.error('XAI Provider - Error in tool processing loop:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
iterationCount,
|
||||
})
|
||||
}
|
||||
|
||||
// After all tool processing complete, if streaming was requested and we have messages, use streaming for the final response
|
||||
if (request.stream && iterationCount > 0) {
|
||||
logger.info('Using streaming for final XAI response after tool calls')
|
||||
// For final streaming response, choose between tools (auto) or response_format (never both)
|
||||
let finalStreamingPayload: any
|
||||
|
||||
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
|
||||
// This prevents the API from trying to force tool usage again in the final streaming response
|
||||
const streamingPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto', // Always use 'auto' for the streaming response after tool calls
|
||||
stream: true,
|
||||
if (request.responseFormat) {
|
||||
// Use response format, no tools
|
||||
finalStreamingPayload = {
|
||||
...createResponseFormatPayload(currentMessages),
|
||||
stream: true,
|
||||
}
|
||||
} else {
|
||||
// Use tools with auto choice
|
||||
finalStreamingPayload = {
|
||||
...basePayload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto',
|
||||
tools: preparedTools?.tools,
|
||||
stream: true,
|
||||
}
|
||||
}
|
||||
|
||||
const streamResponse = await xai.chat.completions.create(streamingPayload)
|
||||
const streamResponse = await xai.chat.completions.create(finalStreamingPayload)
|
||||
|
||||
// Create a StreamingExecution response with all collected data
|
||||
const streamingResult = {
|
||||
@@ -498,6 +564,14 @@ export const xAIProvider: ProviderConfig = {
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
|
||||
logger.info('XAI Provider - Request completed:', {
|
||||
totalDuration,
|
||||
iterationCount: iterationCount + 1,
|
||||
toolCallCount: toolCalls.length,
|
||||
hasContent: !!content,
|
||||
contentLength: content?.length || 0,
|
||||
})
|
||||
|
||||
return {
|
||||
content,
|
||||
model: request.model,
|
||||
@@ -521,7 +595,12 @@ export const xAIProvider: ProviderConfig = {
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
|
||||
logger.error('Error in xAI request:', { error, duration: totalDuration })
|
||||
logger.error('XAI Provider - Request failed:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
duration: totalDuration,
|
||||
hasTools: !!tools?.length,
|
||||
hasResponseFormat: !!request.responseFormat,
|
||||
})
|
||||
|
||||
// Create a new error with timing information
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
|
||||
Reference in New Issue
Block a user