Files
sim/apps/sim/providers/xai/index.ts
Theodore Li 158d5236bc feat(hosted key): Add exa hosted key (#3221)
* feat(hosted keys): Implement serper hosted key

* Handle required fields correctly for hosted keys

* Add rate limiting (3 tries, exponential backoff)

* Add custom pricing, switch to exa as first hosted key

* Add telemetry

* Consolidate byok type definitions

* Add warning comment if default calculation is used

* Record usage to user stats table

* Fix unit tests, use cost property

* Include more metadata in cost output

* Fix disabled tests

* Fix spacing

* Fix lint

* Move knowledge cost restructuring away from generic block handler

* Migrate knowledge unit tests

* Lint

* Fix broken tests

* Add user based hosted key throttling

* Refactor hosted key handling. Add optimistic handling of throttling for custom throttle rules.

* Remove research as hosted key. Recommend BYOK if throttling occurs

* Make adding api keys adjustable via env vars

* Remove vestigial fields from research

* Make billing actor id required for throttling

* Switch to round robin for api key distribution

* Add helper method for adding hosted key cost

* Strip leading double underscores to avoid breaking change

* Lint fix

* Remove falsy check in favor of explicit null check

* Add more detailed metrics for different throttling types

* Fix _costDollars field

* Handle hosted agent tool calls

* Fail loudly if cost field isn't found

* Remove any type

* Fix type error

* Fix lint

* Fix usage log double logging data

* Fix test

---------

Co-authored-by: Theodore Li <teddy@zenobiapay.com>
2026-03-07 13:06:57 -05:00

616 lines
20 KiB
TypeScript

import { createLogger } from '@sim/logger'
import OpenAI from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
Message,
ProviderConfig,
ProviderRequest,
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import { ProviderError } from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
sumToolCosts,
} from '@/providers/utils'
import {
checkForForcedToolUsage,
createReadableStreamFromXAIStream,
createResponseFormatPayload,
} from '@/providers/xai/utils'
import { executeTool } from '@/tools'
const logger = createLogger('XAIProvider')
/**
 * xAI (Grok) provider.
 *
 * Executes chat-completion requests against xAI's OpenAI-compatible API
 * (https://api.x.ai/v1) by reusing the OpenAI SDK with a custom base URL.
 *
 * Request flow inside `executeRequest`:
 *  1. No tools + streaming requested  -> direct streaming fast path.
 *  2. Otherwise: initial completion, then an iterative tool-execution loop
 *     (bounded by MAX_TOOL_ITERATIONS) that feeds tool results back to the
 *     model, honoring forced tool choices one at a time.
 *  3. Final answer is either streamed (if requested) or returned as a plain
 *     ProviderResponse.
 *
 * xAI cannot combine `tools` and `response_format` in one request, so tools
 * are exhausted first and the response format is applied only to the final
 * call (see the warning emitted below when both are present).
 */
export const xAIProvider: ProviderConfig = {
  id: 'xai',
  name: 'xAI',
  description: "xAI's Grok models",
  version: '1.0.0',
  models: getProviderModels('xai'),
  defaultModel: getProviderDefaultModel('xai'),

  /**
   * Executes a single provider request.
   *
   * @param request - Provider-agnostic request (model, messages, tools,
   *   response format, streaming flag, abort signal, ...).
   * @returns A ProviderResponse for non-streaming calls, or a
   *   StreamingExecution whose `execution.output` is mutated in place as the
   *   stream is consumed.
   * @throws Error when no API key is supplied; ProviderError (with timing
   *   metadata) when the underlying request fails.
   */
  executeRequest: async (
    request: ProviderRequest
  ): Promise<ProviderResponse | StreamingExecution> => {
    if (!request.apiKey) {
      throw new Error('API key is required for xAI')
    }

    // xAI speaks the OpenAI wire protocol; only the base URL differs.
    const xai = new OpenAI({
      apiKey: request.apiKey,
      baseURL: 'https://api.x.ai/v1',
    })

    logger.info('XAI Provider - Initial request configuration:', {
      hasTools: !!request.tools?.length,
      toolCount: request.tools?.length || 0,
      hasResponseFormat: !!request.responseFormat,
      model: request.model,
      streaming: !!request.stream,
    })

    // Assemble the message list: optional system prompt, optional extra user
    // context, then the conversation history.
    const allMessages: Message[] = []
    if (request.systemPrompt) {
      allMessages.push({
        role: 'system',
        content: request.systemPrompt,
      })
    }
    if (request.context) {
      allMessages.push({
        role: 'user',
        content: request.context,
      })
    }
    if (request.messages) {
      allMessages.push(...request.messages)
    }

    // Map provider-agnostic tool definitions to OpenAI function-tool schema.
    const tools = request.tools?.length
      ? request.tools.map((tool) => ({
          type: 'function',
          function: {
            name: tool.id,
            description: tool.description,
            parameters: tool.parameters,
          },
        }))
      : undefined

    // tools + response_format cannot be sent together (see module doc).
    if (tools?.length && request.responseFormat) {
      logger.warn(
        'XAI Provider - Detected both tools and response format. Using tools first, then response format for final response.'
      )
    }

    // Base payload shared by every completion call in this request.
    const basePayload: any = {
      model: request.model,
      messages: allMessages,
    }
    if (request.temperature !== undefined) basePayload.temperature = request.temperature
    if (request.maxTokens != null) basePayload.max_completion_tokens = request.maxTokens

    // Applies tool-usage policy (e.g. forced tools); null when no tools.
    let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
    if (tools?.length) {
      preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'xai')
    }

    // ---- Fast path: streaming with no tools configured -------------------
    if (request.stream && (!tools || tools.length === 0)) {
      logger.info('XAI Provider - Using direct streaming (no tools)')
      const providerStartTime = Date.now()
      const providerStartTimeISO = new Date(providerStartTime).toISOString()

      // stream_options.include_usage makes xAI emit token usage in the
      // final stream chunk, which the completion callback below relies on.
      const streamingParams: ChatCompletionCreateParamsStreaming = request.responseFormat
        ? {
            ...createResponseFormatPayload(basePayload, allMessages, request.responseFormat),
            stream: true,
            stream_options: { include_usage: true },
          }
        : { ...basePayload, stream: true, stream_options: { include_usage: true } }

      const streamResponse = await xai.chat.completions.create(
        streamingParams,
        request.abortSignal ? { signal: request.abortSignal } : undefined
      )

      // NOTE: the stream callback closes over `streamingResult` and mutates
      // its `execution.output` once the stream completes — the callback only
      // fires after this object literal is fully constructed.
      const streamingResult = {
        stream: createReadableStreamFromXAIStream(streamResponse, (content, usage) => {
          streamingResult.execution.output.content = content
          streamingResult.execution.output.tokens = {
            input: usage.prompt_tokens,
            output: usage.completion_tokens,
            total: usage.total_tokens,
          }
          const costResult = calculateCost(
            request.model,
            usage.prompt_tokens,
            usage.completion_tokens
          )
          streamingResult.execution.output.cost = {
            input: costResult.input,
            output: costResult.output,
            total: costResult.total,
          }
        }),
        execution: {
          success: true,
          // Placeholder output; filled in by the stream callback above.
          output: {
            content: '',
            model: request.model,
            tokens: { input: 0, output: 0, total: 0 },
            toolCalls: undefined,
            providerTiming: {
              startTime: providerStartTimeISO,
              endTime: new Date().toISOString(),
              duration: Date.now() - providerStartTime,
              timeSegments: [
                {
                  type: 'model',
                  name: 'Streaming response',
                  startTime: providerStartTime,
                  endTime: Date.now(),
                  duration: Date.now() - providerStartTime,
                },
              ],
            },
            cost: { input: 0, output: 0, total: 0 },
          },
          logs: [],
          metadata: {
            startTime: providerStartTimeISO,
            endTime: new Date().toISOString(),
            duration: Date.now() - providerStartTime,
          },
          isStreaming: true,
        },
      }
      return streamingResult as StreamingExecution
    }

    // ---- Standard path: tool loop and/or non-streaming -------------------
    const providerStartTime = Date.now()
    const providerStartTimeISO = new Date(providerStartTime).toISOString()
    try {
      const initialCallTime = Date.now()
      // xAI cannot use tools and response_format together in the same request
      const initialPayload = { ...basePayload }
      let originalToolChoice: any
      const forcedTools = preparedTools?.forcedTools || []
      let usedForcedTools: string[] = []

      // Tools take precedence over response format on the first call.
      if (preparedTools?.tools?.length && preparedTools.toolChoice) {
        const { tools: filteredTools, toolChoice } = preparedTools
        initialPayload.tools = filteredTools
        initialPayload.tool_choice = toolChoice
        originalToolChoice = toolChoice
      } else if (request.responseFormat) {
        const responseFormatPayload = createResponseFormatPayload(
          basePayload,
          allMessages,
          request.responseFormat
        )
        Object.assign(initialPayload, responseFormatPayload)
      }

      let currentResponse = await xai.chat.completions.create(
        initialPayload,
        request.abortSignal ? { signal: request.abortSignal } : undefined
      )
      const firstResponseTime = Date.now() - initialCallTime

      let content = currentResponse.choices[0]?.message?.content || ''
      // Token totals are accumulated across every model call in the loop.
      const tokens = {
        input: currentResponse.usage?.prompt_tokens || 0,
        output: currentResponse.usage?.completion_tokens || 0,
        total: currentResponse.usage?.total_tokens || 0,
      }
      const toolCalls = []
      const toolResults: Record<string, unknown>[] = []
      // Working copy of the conversation; grows with assistant tool_calls
      // and tool-result messages on each iteration.
      const currentMessages = [...allMessages]
      let iterationCount = 0
      let hasUsedForcedTool = false
      let modelTime = firstResponseTime
      let toolsTime = 0
      const timeSegments: TimeSegment[] = [
        {
          type: 'model',
          name: 'Initial response',
          startTime: initialCallTime,
          endTime: initialCallTime + firstResponseTime,
          duration: firstResponseTime,
        },
      ]

      // Track which forced tools the first response already consumed.
      if (originalToolChoice) {
        const result = checkForForcedToolUsage(
          currentResponse,
          originalToolChoice,
          forcedTools,
          usedForcedTools
        )
        hasUsedForcedTool = result.hasUsedForcedTool
        usedForcedTools = result.usedForcedTools
      }

      // Tool-execution loop: run requested tools, append results, re-query
      // the model. Exits when the model stops requesting tools or the
      // iteration cap is hit. Errors here are logged, not rethrown, so a
      // partial result can still be returned below.
      try {
        while (iterationCount < MAX_TOOL_ITERATIONS) {
          if (currentResponse.choices[0]?.message?.content) {
            content = currentResponse.choices[0].message.content
          }
          const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
          if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
            break
          }

          const toolsStartTime = Date.now()
          // Execute all requested tools in parallel; each promise resolves
          // to a result record (or null for unknown tools) — it never
          // rejects, since failures are caught and encoded in `result`.
          const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
            const toolCallStartTime = Date.now()
            const toolName = toolCall.function.name
            try {
              const toolArgs = JSON.parse(toolCall.function.arguments)
              const tool = request.tools?.find((t) => t.id === toolName)
              if (!tool) {
                logger.warn('XAI Provider - Tool not found:', { toolName })
                return null
              }
              const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
              const result = await executeTool(toolName, executionParams)
              const toolCallEndTime = Date.now()
              return {
                toolCall,
                toolName,
                toolParams,
                result,
                startTime: toolCallStartTime,
                endTime: toolCallEndTime,
                duration: toolCallEndTime - toolCallStartTime,
              }
            } catch (error) {
              // JSON.parse or execution failure: surface as a failed result
              // so the model can react, rather than aborting the loop.
              const toolCallEndTime = Date.now()
              logger.error('XAI Provider - Error processing tool call:', {
                error: error instanceof Error ? error.message : String(error),
                toolCall: toolCall.function.name,
              })
              return {
                toolCall,
                toolName,
                toolParams: {},
                result: {
                  success: false,
                  output: undefined,
                  error: error instanceof Error ? error.message : 'Tool execution failed',
                },
                startTime: toolCallStartTime,
                endTime: toolCallEndTime,
                duration: toolCallEndTime - toolCallStartTime,
              }
            }
          })
          const executionResults = await Promise.allSettled(toolExecutionPromises)

          // Echo the assistant's tool_calls message back into the
          // conversation, as the chat-completions protocol requires before
          // any role:'tool' result messages.
          currentMessages.push({
            role: 'assistant',
            content: null,
            tool_calls: toolCallsInResponse.map((tc) => ({
              id: tc.id,
              type: 'function',
              function: {
                name: tc.function.name,
                arguments: tc.function.arguments,
              },
            })),
          })

          // Record each tool outcome: timing segment, result payload for the
          // model, and bookkeeping for the final response.
          for (const settledResult of executionResults) {
            if (settledResult.status === 'rejected' || !settledResult.value) continue
            const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
              settledResult.value
            timeSegments.push({
              type: 'tool',
              name: toolName,
              startTime: startTime,
              endTime: endTime,
              duration: duration,
            })
            let resultContent: any
            if (result.success && result.output) {
              toolResults.push(result.output)
              resultContent = result.output
            } else {
              // Failed tools are reported back to the model as a structured
              // error object rather than being dropped.
              resultContent = {
                error: true,
                message: result.error || 'Tool execution failed',
                tool: toolName,
              }
              logger.warn('XAI Provider - Tool execution failed:', {
                toolName,
                error: result.error,
              })
            }
            toolCalls.push({
              name: toolName,
              arguments: toolParams,
              startTime: new Date(startTime).toISOString(),
              endTime: new Date(endTime).toISOString(),
              duration: duration,
              result: resultContent,
              success: result.success,
            })
            currentMessages.push({
              role: 'tool',
              tool_call_id: toolCall.id,
              content: JSON.stringify(resultContent),
            })
          }
          const thisToolsTime = Date.now() - toolsStartTime
          toolsTime += thisToolsTime

          // Build the next request. While forced tools remain, force the
          // next unused one; once all are used (or none were forced), fall
          // back to tool_choice 'auto' or the response-format payload.
          let nextPayload: any
          if (
            typeof originalToolChoice === 'object' &&
            hasUsedForcedTool &&
            forcedTools.length > 0
          ) {
            const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
            if (remainingTools.length > 0) {
              // Force the next not-yet-used tool, one per iteration.
              nextPayload = {
                ...basePayload,
                messages: currentMessages,
                tools: preparedTools?.tools,
                tool_choice: {
                  type: 'function',
                  function: { name: remainingTools[0] },
                },
              }
            } else {
              if (request.responseFormat) {
                nextPayload = createResponseFormatPayload(
                  basePayload,
                  allMessages,
                  request.responseFormat,
                  currentMessages
                )
              } else {
                nextPayload = {
                  ...basePayload,
                  messages: currentMessages,
                  tool_choice: 'auto',
                  tools: preparedTools?.tools,
                }
              }
            }
          } else {
            if (request.responseFormat) {
              nextPayload = createResponseFormatPayload(
                basePayload,
                allMessages,
                request.responseFormat,
                currentMessages
              )
            } else {
              nextPayload = {
                ...basePayload,
                messages: currentMessages,
                tools: preparedTools?.tools,
                tool_choice: 'auto',
              }
            }
          }

          const nextModelStartTime = Date.now()
          currentResponse = await xai.chat.completions.create(
            nextPayload,
            request.abortSignal ? { signal: request.abortSignal } : undefined
          )

          // An object-valued tool_choice means a tool was forced this round;
          // update the forced-tool bookkeeping from the new response.
          if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') {
            const result = checkForForcedToolUsage(
              currentResponse,
              nextPayload.tool_choice,
              forcedTools,
              usedForcedTools
            )
            hasUsedForcedTool = result.hasUsedForcedTool
            usedForcedTools = result.usedForcedTools
          }
          const nextModelEndTime = Date.now()
          const thisModelTime = nextModelEndTime - nextModelStartTime
          timeSegments.push({
            type: 'model',
            name: `Model response (iteration ${iterationCount + 1})`,
            startTime: nextModelStartTime,
            endTime: nextModelEndTime,
            duration: thisModelTime,
          })
          modelTime += thisModelTime
          if (currentResponse.choices[0]?.message?.content) {
            content = currentResponse.choices[0].message.content
          }
          if (currentResponse.usage) {
            tokens.input += currentResponse.usage.prompt_tokens || 0
            tokens.output += currentResponse.usage.completion_tokens || 0
            tokens.total += currentResponse.usage.total_tokens || 0
          }
          iterationCount++
        }
      } catch (error) {
        // Best-effort: log and continue so accumulated content/tool results
        // are still returned to the caller.
        logger.error('XAI Provider - Error in tool processing loop:', {
          error: error instanceof Error ? error.message : String(error),
          iterationCount,
        })
      }

      // ---- Streaming requested after the tool loop -----------------------
      if (request.stream) {
        let finalStreamingPayload: any
        if (request.responseFormat) {
          // The tool phase is over, so the response format can now be applied
          // (tools and response_format are mutually exclusive per request).
          finalStreamingPayload = {
            ...createResponseFormatPayload(
              basePayload,
              allMessages,
              request.responseFormat,
              currentMessages
            ),
            stream: true,
          }
        } else {
          finalStreamingPayload = {
            ...basePayload,
            messages: currentMessages,
            tool_choice: 'auto',
            tools: preparedTools?.tools,
            stream: true,
          }
        }
        const streamResponse = await xai.chat.completions.create(
          finalStreamingPayload as any,
          request.abortSignal ? { signal: request.abortSignal } : undefined
        )

        // Cost of all model calls made so far; the stream callback adds the
        // final streamed call's cost on top of this.
        const accumulatedCost = calculateCost(request.model, tokens.input, tokens.output)

        // Same self-referential closure pattern as the fast path above.
        const streamingResult = {
          stream: createReadableStreamFromXAIStream(streamResponse as any, (content, usage) => {
            streamingResult.execution.output.content = content
            streamingResult.execution.output.tokens = {
              input: tokens.input + usage.prompt_tokens,
              output: tokens.output + usage.completion_tokens,
              total: tokens.total + usage.total_tokens,
            }
            const streamCost = calculateCost(
              request.model,
              usage.prompt_tokens,
              usage.completion_tokens
            )
            const tc = sumToolCosts(toolResults)
            streamingResult.execution.output.cost = {
              input: accumulatedCost.input + streamCost.input,
              output: accumulatedCost.output + streamCost.output,
              // A zero tool cost is reported as undefined (absent), but is
              // still safe to add into the total below.
              toolCost: tc || undefined,
              total: accumulatedCost.total + streamCost.total + tc,
            }
          }),
          execution: {
            success: true,
            // Pre-stream snapshot; content/tokens/cost are finalized by the
            // stream callback once usage arrives.
            output: {
              content: '',
              model: request.model,
              tokens: {
                input: tokens.input,
                output: tokens.output,
                total: tokens.total,
              },
              toolCalls:
                toolCalls.length > 0
                  ? {
                      list: toolCalls,
                      count: toolCalls.length,
                    }
                  : undefined,
              providerTiming: {
                startTime: providerStartTimeISO,
                endTime: new Date().toISOString(),
                duration: Date.now() - providerStartTime,
                modelTime: modelTime,
                toolsTime: toolsTime,
                firstResponseTime: firstResponseTime,
                iterations: iterationCount + 1,
                timeSegments: timeSegments,
              },
              cost: {
                input: accumulatedCost.input,
                output: accumulatedCost.output,
                toolCost: undefined as number | undefined,
                total: accumulatedCost.total,
              },
            },
            logs: [],
            metadata: {
              startTime: providerStartTimeISO,
              endTime: new Date().toISOString(),
              duration: Date.now() - providerStartTime,
            },
            isStreaming: true,
          },
        }
        return streamingResult as StreamingExecution
      }

      // ---- Non-streaming result ------------------------------------------
      const providerEndTime = Date.now()
      const providerEndTimeISO = new Date(providerEndTime).toISOString()
      const totalDuration = providerEndTime - providerStartTime
      logger.info('XAI Provider - Request completed:', {
        totalDuration,
        iterationCount: iterationCount + 1,
        toolCallCount: toolCalls.length,
        hasContent: !!content,
        contentLength: content?.length || 0,
      })
      // NOTE(review): unlike the streaming paths, this return carries no
      // `cost` field — presumably cost is derived downstream from `tokens`;
      // verify against ProviderResponse consumers.
      return {
        content,
        model: request.model,
        tokens,
        toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
        toolResults: toolResults.length > 0 ? toolResults : undefined,
        timing: {
          startTime: providerStartTimeISO,
          endTime: providerEndTimeISO,
          duration: totalDuration,
          modelTime: modelTime,
          toolsTime: toolsTime,
          firstResponseTime: firstResponseTime,
          iterations: iterationCount + 1,
          timeSegments: timeSegments,
        },
      }
    } catch (error) {
      // Wrap any failure (initial call, final stream setup, etc.) in a
      // ProviderError that carries timing metadata for the caller.
      const providerEndTime = Date.now()
      const providerEndTimeISO = new Date(providerEndTime).toISOString()
      const totalDuration = providerEndTime - providerStartTime
      logger.error('XAI Provider - Request failed:', {
        error: error instanceof Error ? error.message : String(error),
        duration: totalDuration,
        hasTools: !!tools?.length,
        hasResponseFormat: !!request.responseFormat,
      })
      throw new ProviderError(error instanceof Error ? error.message : String(error), {
        startTime: providerStartTimeISO,
        endTime: providerEndTimeISO,
        duration: totalDuration,
      })
    }
  },
}