mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
* feat(hosted keys): Implement serper hosted key * Handle required fields correctly for hosted keys * Add rate limiting (3 tries, exponential backoff) * Add custom pricing, switch to exa as first hosted key * Add telemetry * Consolidate byok type definitions * Add warning comment if default calculation is used * Record usage to user stats table * Fix unit tests, use cost property * Include more metadata in cost output * Fix disabled tests * Fix spacing * Fix lint * Move knowledge cost restructuring away from generic block handler * Migrate knowledge unit tests * Lint * Fix broken tests * Add user based hosted key throttling * Refactor hosted key handling. Add optimistic handling of throttling for custom throttle rules. * Remove research as hosted key. Recommend BYOK if throtttling occurs * Make adding api keys adjustable via env vars * Remove vestigial fields from research * Make billing actor id required for throttling * Switch to round robin for api key distribution * Add helper method for adding hosted key cost * Strip leading double underscores to avoid breaking change * Lint fix * Remove falsy check in favor for explicit null check * Add more detailed metrics for different throttling types * Fix _costDollars field * Handle hosted agent tool calls * Fail loudly if cost field isn't found * Remove any type * Fix type error * Fix lint * Fix usage log double logging data * Fix test --------- Co-authored-by: Theodore Li <teddy@zenobiapay.com>
616 lines
20 KiB
TypeScript
616 lines
20 KiB
TypeScript
import { createLogger } from '@sim/logger'
|
|
import OpenAI from 'openai'
|
|
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
|
|
import type { StreamingExecution } from '@/executor/types'
|
|
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
|
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
|
import type {
|
|
Message,
|
|
ProviderConfig,
|
|
ProviderRequest,
|
|
ProviderResponse,
|
|
TimeSegment,
|
|
} from '@/providers/types'
|
|
import { ProviderError } from '@/providers/types'
|
|
import {
|
|
calculateCost,
|
|
prepareToolExecution,
|
|
prepareToolsWithUsageControl,
|
|
sumToolCosts,
|
|
} from '@/providers/utils'
|
|
import {
|
|
checkForForcedToolUsage,
|
|
createReadableStreamFromXAIStream,
|
|
createResponseFormatPayload,
|
|
} from '@/providers/xai/utils'
|
|
import { executeTool } from '@/tools'
|
|
|
|
// Module-level logger created under the 'XAIProvider' name; used for all log output in this file.
const logger = createLogger('XAIProvider')
|
|
|
|
/**
 * Provider configuration for xAI's Grok models.
 *
 * Requests are sent with the OpenAI SDK client configured with the xAI base URL
 * (`https://api.x.ai/v1`).
 */
export const xAIProvider: ProviderConfig = {
  id: 'xai',
  name: 'xAI',
  description: "xAI's Grok models",
  version: '1.0.0',
  models: getProviderModels('xai'),
  defaultModel: getProviderDefaultModel('xai'),

  /**
   * Executes a chat completion request against xAI.
   *
   * Three paths exist:
   * 1. `request.stream` with no tools: a single streaming call is issued and returned directly.
   * 2. Non-streaming: an initial call is made, then tool calls are executed and fed back to the
   *    model for up to `MAX_TOOL_ITERATIONS` rounds; the aggregated result is returned.
   * 3. `request.stream` with tools: the tool loop from (2) runs first, then one final streaming
   *    call is issued and returned together with the tool-call history.
   *
   * @param request - Provider request (model, messages, tools, response format, etc.).
   * @returns A `ProviderResponse` for non-streaming requests, or a `StreamingExecution` whose
   *   `execution.output` tokens and cost are filled in by the stream's completion callback.
   * @throws Error when `request.apiKey` is missing.
   * @throws ProviderError (carrying timing metadata) when paths (2) or (3) fail. Errors raised on
   *   path (1) are not wrapped and propagate as thrown by the SDK.
   */
  executeRequest: async (
    request: ProviderRequest
  ): Promise<ProviderResponse | StreamingExecution> => {
    if (!request.apiKey) {
      throw new Error('API key is required for xAI')
    }

    const xai = new OpenAI({
      apiKey: request.apiKey,
      baseURL: 'https://api.x.ai/v1',
    })

    // Shared per-call options so that the caller's abort signal, when present, reaches every
    // API call made below.
    const requestOptions = request.abortSignal ? { signal: request.abortSignal } : undefined

    logger.info('XAI Provider - Initial request configuration:', {
      hasTools: !!request.tools?.length,
      toolCount: request.tools?.length || 0,
      hasResponseFormat: !!request.responseFormat,
      model: request.model,
      streaming: !!request.stream,
    })

    // Conversation order: system prompt, then context (as a user message), then caller messages.
    const allMessages: Message[] = []

    if (request.systemPrompt) {
      allMessages.push({
        role: 'system',
        content: request.systemPrompt,
      })
    }

    if (request.context) {
      allMessages.push({
        role: 'user',
        content: request.context,
      })
    }

    if (request.messages) {
      allMessages.push(...request.messages)
    }

    // Convert the request's tool definitions into function-tool descriptors.
    const tools = request.tools?.length
      ? request.tools.map((tool) => ({
          type: 'function',
          function: {
            name: tool.id,
            description: tool.description,
            parameters: tool.parameters,
          },
        }))
      : undefined

    if (tools?.length && request.responseFormat) {
      logger.warn(
        'XAI Provider - Detected both tools and response format. Using tools first, then response format for final response.'
      )
    }

    const basePayload: any = {
      model: request.model,
      messages: allMessages,
    }

    if (request.temperature !== undefined) basePayload.temperature = request.temperature
    if (request.maxTokens != null) basePayload.max_completion_tokens = request.maxTokens

    let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null

    if (tools?.length) {
      preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'xai')
    }

    // Path 1: streaming without tools. A single streaming call is returned as-is.
    if (request.stream && (!tools || tools.length === 0)) {
      logger.info('XAI Provider - Using direct streaming (no tools)')

      const providerStartTime = Date.now()
      const providerStartTimeISO = new Date(providerStartTime).toISOString()

      const streamingParams: ChatCompletionCreateParamsStreaming = request.responseFormat
        ? {
            ...createResponseFormatPayload(basePayload, allMessages, request.responseFormat),
            stream: true,
            stream_options: { include_usage: true },
          }
        : { ...basePayload, stream: true, stream_options: { include_usage: true } }

      const streamResponse = await xai.chat.completions.create(streamingParams, requestOptions)

      const streamingResult = {
        // The callback runs when the stream helper reports completion and back-fills the
        // content, token counts and cost on the execution object created below.
        stream: createReadableStreamFromXAIStream(streamResponse, (content, usage) => {
          streamingResult.execution.output.content = content
          streamingResult.execution.output.tokens = {
            input: usage.prompt_tokens,
            output: usage.completion_tokens,
            total: usage.total_tokens,
          }

          const costResult = calculateCost(
            request.model,
            usage.prompt_tokens,
            usage.completion_tokens
          )
          streamingResult.execution.output.cost = {
            input: costResult.input,
            output: costResult.output,
            total: costResult.total,
          }
        }),
        execution: {
          success: true,
          output: {
            content: '',
            model: request.model,
            tokens: { input: 0, output: 0, total: 0 },
            toolCalls: undefined,
            // Timing values are captured when this object is built, i.e. once the stream has
            // been opened, not when it finishes.
            providerTiming: {
              startTime: providerStartTimeISO,
              endTime: new Date().toISOString(),
              duration: Date.now() - providerStartTime,
              timeSegments: [
                {
                  type: 'model',
                  name: 'Streaming response',
                  startTime: providerStartTime,
                  endTime: Date.now(),
                  duration: Date.now() - providerStartTime,
                },
              ],
            },
            cost: { input: 0, output: 0, total: 0 },
          },
          logs: [],
          metadata: {
            startTime: providerStartTimeISO,
            endTime: new Date().toISOString(),
            duration: Date.now() - providerStartTime,
          },
          isStreaming: true,
        },
      }

      return streamingResult as StreamingExecution
    }

    // Paths 2 and 3: initial call followed by the tool loop.
    const providerStartTime = Date.now()
    const providerStartTimeISO = new Date(providerStartTime).toISOString()

    try {
      const initialCallTime = Date.now()

      // xAI cannot use tools and response_format together in the same request
      const initialPayload = { ...basePayload }

      let originalToolChoice: any
      const forcedTools = preparedTools?.forcedTools || []
      let usedForcedTools: string[] = []

      if (preparedTools?.tools?.length && preparedTools.toolChoice) {
        const { tools: filteredTools, toolChoice } = preparedTools
        initialPayload.tools = filteredTools
        initialPayload.tool_choice = toolChoice
        originalToolChoice = toolChoice
      } else if (request.responseFormat) {
        const responseFormatPayload = createResponseFormatPayload(
          basePayload,
          allMessages,
          request.responseFormat
        )
        Object.assign(initialPayload, responseFormatPayload)
      }

      let currentResponse = await xai.chat.completions.create(initialPayload, requestOptions)
      const firstResponseTime = Date.now() - initialCallTime

      let content = currentResponse.choices[0]?.message?.content || ''
      // Running totals across every non-streaming model call.
      const tokens = {
        input: currentResponse.usage?.prompt_tokens || 0,
        output: currentResponse.usage?.completion_tokens || 0,
        total: currentResponse.usage?.total_tokens || 0,
      }
      const toolCalls = []
      const toolResults: Record<string, unknown>[] = []
      const currentMessages = [...allMessages]
      let iterationCount = 0

      let hasUsedForcedTool = false
      let modelTime = firstResponseTime
      let toolsTime = 0
      const timeSegments: TimeSegment[] = [
        {
          type: 'model',
          name: 'Initial response',
          startTime: initialCallTime,
          endTime: initialCallTime + firstResponseTime,
          duration: firstResponseTime,
        },
      ]

      if (originalToolChoice) {
        const result = checkForForcedToolUsage(
          currentResponse,
          originalToolChoice,
          forcedTools,
          usedForcedTools
        )
        hasUsedForcedTool = result.hasUsedForcedTool
        usedForcedTools = result.usedForcedTools
      }

      // The tool loop is best-effort: a failure inside it is logged and execution continues
      // with whatever content, tokens and tool calls were collected so far.
      try {
        while (iterationCount < MAX_TOOL_ITERATIONS) {
          if (currentResponse.choices[0]?.message?.content) {
            content = currentResponse.choices[0].message.content
          }

          const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
          if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
            break
          }

          const toolsStartTime = Date.now()
          // All tool calls of one model turn are executed concurrently.
          const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
            const toolCallStartTime = Date.now()
            const toolName = toolCall.function.name

            // Builds the failure record for this tool call so it is still reported back to the
            // model and listed in the tool-call history.
            const failedResult = (message: string) => {
              const toolCallEndTime = Date.now()
              return {
                toolCall,
                toolName,
                toolParams: {},
                result: {
                  success: false,
                  output: undefined,
                  error: message,
                },
                startTime: toolCallStartTime,
                endTime: toolCallEndTime,
                duration: toolCallEndTime - toolCallStartTime,
              }
            }

            try {
              const toolArgs = JSON.parse(toolCall.function.arguments)
              const tool = request.tools?.find((t) => t.id === toolName)

              if (!tool) {
                logger.warn('XAI Provider - Tool not found:', { toolName })
                // Return a failure instead of dropping the call: the assistant message pushed
                // below lists every tool_call id, and each id needs a matching tool message in
                // the follow-up request.
                return failedResult(`Tool not found: ${toolName}`)
              }

              const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
              const result = await executeTool(toolName, executionParams)
              const toolCallEndTime = Date.now()

              return {
                toolCall,
                toolName,
                toolParams,
                result,
                startTime: toolCallStartTime,
                endTime: toolCallEndTime,
                duration: toolCallEndTime - toolCallStartTime,
              }
            } catch (error) {
              logger.error('XAI Provider - Error processing tool call:', {
                error: error instanceof Error ? error.message : String(error),
                toolCall: toolCall.function.name,
              })

              return failedResult(error instanceof Error ? error.message : 'Tool execution failed')
            }
          })

          const executionResults = await Promise.allSettled(toolExecutionPromises)

          // Echo the assistant's tool-call turn back into the conversation before the results.
          currentMessages.push({
            role: 'assistant',
            content: null,
            tool_calls: toolCallsInResponse.map((tc) => ({
              id: tc.id,
              type: 'function',
              function: {
                name: tc.function.name,
                arguments: tc.function.arguments,
              },
            })),
          })

          for (const settledResult of executionResults) {
            // The mapper above catches its own errors, so a rejection is not expected here.
            if (settledResult.status === 'rejected') continue

            const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
              settledResult.value

            timeSegments.push({
              type: 'tool',
              name: toolName,
              startTime: startTime,
              endTime: endTime,
              duration: duration,
            })

            let resultContent: any
            if (result.success && result.output) {
              toolResults.push(result.output)
              resultContent = result.output
            } else {
              resultContent = {
                error: true,
                message: result.error || 'Tool execution failed',
                tool: toolName,
              }
              logger.warn('XAI Provider - Tool execution failed:', {
                toolName,
                error: result.error,
              })
            }

            toolCalls.push({
              name: toolName,
              arguments: toolParams,
              startTime: new Date(startTime).toISOString(),
              endTime: new Date(endTime).toISOString(),
              duration: duration,
              result: resultContent,
              success: result.success,
            })

            currentMessages.push({
              role: 'tool',
              tool_call_id: toolCall.id,
              content: JSON.stringify(resultContent),
            })
          }

          const thisToolsTime = Date.now() - toolsStartTime
          toolsTime += thisToolsTime

          // Decide what the next model call should do:
          //  - keep forcing the next not-yet-used forced tool, if any remain;
          //  - otherwise apply the response format when one was requested;
          //  - otherwise let the model choose tools freely.
          const isForcingTools =
            typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0
          const remainingForcedTools: typeof forcedTools = isForcingTools
            ? forcedTools.filter((tool) => !usedForcedTools.includes(tool))
            : []

          let nextPayload: any
          if (remainingForcedTools.length > 0) {
            nextPayload = {
              ...basePayload,
              messages: currentMessages,
              tools: preparedTools?.tools,
              tool_choice: {
                type: 'function',
                function: { name: remainingForcedTools[0] },
              },
            }
          } else if (request.responseFormat) {
            nextPayload = createResponseFormatPayload(
              basePayload,
              allMessages,
              request.responseFormat,
              currentMessages
            )
          } else {
            nextPayload = {
              ...basePayload,
              messages: currentMessages,
              tools: preparedTools?.tools,
              tool_choice: 'auto',
            }
          }

          const nextModelStartTime = Date.now()

          currentResponse = await xai.chat.completions.create(nextPayload, requestOptions)

          if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') {
            const result = checkForForcedToolUsage(
              currentResponse,
              nextPayload.tool_choice,
              forcedTools,
              usedForcedTools
            )
            hasUsedForcedTool = result.hasUsedForcedTool
            usedForcedTools = result.usedForcedTools
          }

          const nextModelEndTime = Date.now()
          const thisModelTime = nextModelEndTime - nextModelStartTime
          timeSegments.push({
            type: 'model',
            name: `Model response (iteration ${iterationCount + 1})`,
            startTime: nextModelStartTime,
            endTime: nextModelEndTime,
            duration: thisModelTime,
          })

          modelTime += thisModelTime

          if (currentResponse.choices[0]?.message?.content) {
            content = currentResponse.choices[0].message.content
          }

          if (currentResponse.usage) {
            tokens.input += currentResponse.usage.prompt_tokens || 0
            tokens.output += currentResponse.usage.completion_tokens || 0
            tokens.total += currentResponse.usage.total_tokens || 0
          }

          iterationCount++
        }
      } catch (error) {
        logger.error('XAI Provider - Error in tool processing loop:', {
          error: error instanceof Error ? error.message : String(error),
          iterationCount,
        })
      }

      // Path 3: streaming was requested alongside tools. Issue one final streaming call over the
      // conversation built up by the tool loop.
      if (request.stream) {
        let finalStreamingPayload: any

        if (request.responseFormat) {
          finalStreamingPayload = {
            ...createResponseFormatPayload(
              basePayload,
              allMessages,
              request.responseFormat,
              currentMessages
            ),
            stream: true,
            // Request usage on the stream so the completion callback can add this call's tokens
            // and cost to the accumulated totals (same option as the direct streaming path).
            stream_options: { include_usage: true },
          }
        } else {
          finalStreamingPayload = {
            ...basePayload,
            messages: currentMessages,
            // Only send tool_choice together with a non-empty tool list; usage control may
            // have filtered every tool out.
            ...(preparedTools?.tools?.length
              ? { tool_choice: 'auto', tools: preparedTools.tools }
              : {}),
            stream: true,
            stream_options: { include_usage: true },
          }
        }

        const streamResponse = await xai.chat.completions.create(
          finalStreamingPayload as any,
          requestOptions
        )

        // Cost of all non-streaming calls made so far; the streamed call is added on completion.
        const accumulatedCost = calculateCost(request.model, tokens.input, tokens.output)

        const streamingResult = {
          stream: createReadableStreamFromXAIStream(streamResponse as any, (content, usage) => {
            streamingResult.execution.output.content = content
            streamingResult.execution.output.tokens = {
              input: tokens.input + usage.prompt_tokens,
              output: tokens.output + usage.completion_tokens,
              total: tokens.total + usage.total_tokens,
            }

            const streamCost = calculateCost(
              request.model,
              usage.prompt_tokens,
              usage.completion_tokens
            )
            const toolCost = sumToolCosts(toolResults)
            streamingResult.execution.output.cost = {
              input: accumulatedCost.input + streamCost.input,
              output: accumulatedCost.output + streamCost.output,
              toolCost: toolCost || undefined,
              total: accumulatedCost.total + streamCost.total + toolCost,
            }
          }),
          execution: {
            success: true,
            output: {
              content: '',
              model: request.model,
              tokens: {
                input: tokens.input,
                output: tokens.output,
                total: tokens.total,
              },
              toolCalls:
                toolCalls.length > 0
                  ? {
                      list: toolCalls,
                      count: toolCalls.length,
                    }
                  : undefined,
              providerTiming: {
                startTime: providerStartTimeISO,
                endTime: new Date().toISOString(),
                duration: Date.now() - providerStartTime,
                modelTime: modelTime,
                toolsTime: toolsTime,
                firstResponseTime: firstResponseTime,
                iterations: iterationCount + 1,
                timeSegments: timeSegments,
              },
              cost: {
                input: accumulatedCost.input,
                output: accumulatedCost.output,
                toolCost: undefined as number | undefined,
                total: accumulatedCost.total,
              },
            },
            logs: [],
            metadata: {
              startTime: providerStartTimeISO,
              endTime: new Date().toISOString(),
              duration: Date.now() - providerStartTime,
            },
            isStreaming: true,
          },
        }

        return streamingResult as StreamingExecution
      }

      // Path 2: non-streaming result.
      const providerEndTime = Date.now()
      const providerEndTimeISO = new Date(providerEndTime).toISOString()
      const totalDuration = providerEndTime - providerStartTime

      logger.info('XAI Provider - Request completed:', {
        totalDuration,
        iterationCount: iterationCount + 1,
        toolCallCount: toolCalls.length,
        hasContent: !!content,
        contentLength: content?.length || 0,
      })

      return {
        content,
        model: request.model,
        tokens,
        toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
        toolResults: toolResults.length > 0 ? toolResults : undefined,
        timing: {
          startTime: providerStartTimeISO,
          endTime: providerEndTimeISO,
          duration: totalDuration,
          modelTime: modelTime,
          toolsTime: toolsTime,
          firstResponseTime: firstResponseTime,
          iterations: iterationCount + 1,
          timeSegments: timeSegments,
        },
      }
    } catch (error) {
      const providerEndTime = Date.now()
      const providerEndTimeISO = new Date(providerEndTime).toISOString()
      const totalDuration = providerEndTime - providerStartTime

      logger.error('XAI Provider - Request failed:', {
        error: error instanceof Error ? error.message : String(error),
        duration: totalDuration,
        hasTools: !!tools?.length,
        hasResponseFormat: !!request.responseFormat,
      })

      // Re-throw with timing attached so callers can report how long the failed attempt took.
      throw new ProviderError(error instanceof Error ? error.message : String(error), {
        startTime: providerStartTimeISO,
        endTime: providerEndTimeISO,
        duration: totalDuration,
      })
    }
  },
}
|