Mirror of https://github.com/simstudioai/sim.git (synced 2026-02-04 19:55:08 -05:00)
improvement(openai): migrate to responses api (#3135)
* Migrate openai to use responses api
* Consolidate azure
* Fix streaming
* Bug fixes
* Bug fixes
* Fix responseformat
* Refactor
* Fix bugs
* Fix
* Fix azure openai response format with tool calls
* Fixes
* Fixes
* Fix temp
Committed by GitHub. Parent: 793adda986. Commit: 1933e1aad5.
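The diff below repeats the same migration in several places: requests move from the Chat Completions endpoint to the Responses endpoint, the conversation goes under `input` instead of `messages`, and `max_tokens` becomes `max_output_tokens`. A minimal TypeScript sketch of that request-shape change, assuming a plain `fetch` call with placeholder key and model values (illustrative only, not code from this commit):

// Before: Chat Completions payload shape.
const messages = [
  { role: 'system', content: 'You are a helpful assistant.' },
  { role: 'user', content: 'Hello' },
]
const chatCompletionsBody = {
  model: 'gpt-4o',
  messages, // conversation under `messages`
  temperature: 0.2,
  max_tokens: 10000, // output cap
  stream: true,
}

// After: Responses API payload shape used throughout this commit.
const responsesBody = {
  model: 'gpt-4o',
  input: messages, // conversation under `input`
  temperature: 0.2,
  max_output_tokens: 10000, // renamed output cap
  stream: true,
}

async function callResponses(apiKey: string) {
  // OpenAI endpoint: https://api.openai.com/v1/responses
  // Azure endpoint (per this diff): `${endpoint}/openai/v1/responses?api-version=${apiVersion}` with an `api-key` header instead of Authorization.
  const res = await fetch('https://api.openai.com/v1/responses', {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'OpenAI-Beta': 'responses=v1',
      Authorization: `Bearer ${apiKey}`,
    },
    body: JSON.stringify({ ...responsesBody, stream: false }),
  })
  if (!res.ok) throw new Error(`Responses API error (${res.status}): ${await res.text()}`)
  // Non-streaming responses return an `output` array; the route below extracts text from it.
  return res.json()
}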
@@ -3,7 +3,6 @@ import { userStats, workflow } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { eq, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import OpenAI, { AzureOpenAI } from 'openai'
import { getBYOKKey } from '@/lib/api-key/byok'
import { getSession } from '@/lib/auth'
import { logModelUsage } from '@/lib/billing/core/usage-log'
@@ -12,6 +11,7 @@ import { env } from '@/lib/core/config/env'
import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags'
import { generateRequestId } from '@/lib/core/utils/request'
import { verifyWorkspaceMembership } from '@/app/api/workflows/utils'
import { extractResponseText, parseResponsesUsage } from '@/providers/openai/utils'
import { getModelPricing } from '@/providers/utils'

export const dynamic = 'force-dynamic'
@@ -28,18 +28,6 @@ const openaiApiKey = env.OPENAI_API_KEY

const useWandAzure = azureApiKey && azureEndpoint && azureApiVersion

const client = useWandAzure
  ? new AzureOpenAI({
      apiKey: azureApiKey,
      apiVersion: azureApiVersion,
      endpoint: azureEndpoint,
    })
  : openaiApiKey
    ? new OpenAI({
        apiKey: openaiApiKey,
      })
    : null

if (!useWandAzure && !openaiApiKey) {
  logger.warn(
    'Neither Azure OpenAI nor OpenAI API key found. Wand generation API will not function.'
@@ -202,20 +190,18 @@ export async function POST(req: NextRequest) {
    }

    let isBYOK = false
    let activeClient = client
    let byokApiKey: string | null = null
    let activeOpenAIKey = openaiApiKey

    if (workspaceId && !useWandAzure) {
      const byokResult = await getBYOKKey(workspaceId, 'openai')
      if (byokResult) {
        isBYOK = true
        byokApiKey = byokResult.apiKey
        activeClient = new OpenAI({ apiKey: byokResult.apiKey })
        activeOpenAIKey = byokResult.apiKey
        logger.info(`[${requestId}] Using BYOK OpenAI key for wand generation`)
      }
    }

    if (!activeClient) {
    if (!useWandAzure && !activeOpenAIKey) {
      logger.error(`[${requestId}] AI client not initialized. Missing API key.`)
      return NextResponse.json(
        { success: false, error: 'Wand generation service is not configured.' },
@@ -276,17 +262,18 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
    )

    const apiUrl = useWandAzure
      ? `${azureEndpoint}/openai/deployments/${wandModelName}/chat/completions?api-version=${azureApiVersion}`
      : 'https://api.openai.com/v1/chat/completions'
      ? `${azureEndpoint?.replace(/\/$/, '')}/openai/v1/responses?api-version=${azureApiVersion}`
      : 'https://api.openai.com/v1/responses'

    const headers: Record<string, string> = {
      'Content-Type': 'application/json',
      'OpenAI-Beta': 'responses=v1',
    }

    if (useWandAzure) {
      headers['api-key'] = azureApiKey!
    } else {
      headers.Authorization = `Bearer ${byokApiKey || openaiApiKey}`
      headers.Authorization = `Bearer ${activeOpenAIKey}`
    }

    logger.debug(`[${requestId}] Making streaming request to: ${apiUrl}`)
@@ -296,11 +283,10 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
        headers,
        body: JSON.stringify({
          model: useWandAzure ? wandModelName : 'gpt-4o',
          messages: messages,
          input: messages,
          temperature: 0.2,
          max_tokens: 10000,
          max_output_tokens: 10000,
          stream: true,
          stream_options: { include_usage: true },
        }),
      })
@@ -327,16 +313,29 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
|
||||
return
|
||||
}
|
||||
|
||||
let finalUsage: any = null
|
||||
let usageRecorded = false
|
||||
|
||||
const recordUsage = async () => {
|
||||
if (usageRecorded || !finalUsage) {
|
||||
return
|
||||
}
|
||||
|
||||
usageRecorded = true
|
||||
await updateUserStatsForWand(session.user.id, finalUsage, requestId, isBYOK)
|
||||
}
|
||||
|
||||
try {
|
||||
let buffer = ''
|
||||
let chunkCount = 0
|
||||
let finalUsage: any = null
|
||||
let activeEventType: string | undefined
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
|
||||
if (done) {
|
||||
logger.info(`[${requestId}] Stream completed. Total chunks: ${chunkCount}`)
|
||||
await recordUsage()
|
||||
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
|
||||
controller.close()
|
||||
break
|
||||
@@ -348,47 +347,90 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
const data = line.slice(6).trim()
|
||||
const trimmed = line.trim()
|
||||
if (!trimmed) {
|
||||
continue
|
||||
}
|
||||
|
||||
if (data === '[DONE]') {
|
||||
logger.info(`[${requestId}] Received [DONE] signal`)
|
||||
if (trimmed.startsWith('event:')) {
|
||||
activeEventType = trimmed.slice(6).trim()
|
||||
continue
|
||||
}
|
||||
|
||||
if (finalUsage) {
|
||||
await updateUserStatsForWand(session.user.id, finalUsage, requestId, isBYOK)
|
||||
if (!trimmed.startsWith('data:')) {
|
||||
continue
|
||||
}
|
||||
|
||||
const data = trimmed.slice(5).trim()
|
||||
if (data === '[DONE]') {
|
||||
logger.info(`[${requestId}] Received [DONE] signal`)
|
||||
|
||||
await recordUsage()
|
||||
|
||||
controller.enqueue(
|
||||
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
|
||||
)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
|
||||
let parsed: any
|
||||
try {
|
||||
parsed = JSON.parse(data)
|
||||
} catch (parseError) {
|
||||
logger.debug(`[${requestId}] Skipped non-JSON line: ${data.substring(0, 100)}`)
|
||||
continue
|
||||
}
|
||||
|
||||
const eventType = parsed?.type ?? activeEventType
|
||||
|
||||
if (
|
||||
eventType === 'response.error' ||
|
||||
eventType === 'error' ||
|
||||
eventType === 'response.failed'
|
||||
) {
|
||||
throw new Error(parsed?.error?.message || 'Responses stream error')
|
||||
}
|
||||
|
||||
if (
|
||||
eventType === 'response.output_text.delta' ||
|
||||
eventType === 'response.output_json.delta'
|
||||
) {
|
||||
let content = ''
|
||||
if (typeof parsed.delta === 'string') {
|
||||
content = parsed.delta
|
||||
} else if (parsed.delta && typeof parsed.delta.text === 'string') {
|
||||
content = parsed.delta.text
|
||||
} else if (parsed.delta && parsed.delta.json !== undefined) {
|
||||
content = JSON.stringify(parsed.delta.json)
|
||||
} else if (parsed.json !== undefined) {
|
||||
content = JSON.stringify(parsed.json)
|
||||
} else if (typeof parsed.text === 'string') {
|
||||
content = parsed.text
|
||||
}
|
||||
|
||||
if (content) {
|
||||
chunkCount++
|
||||
if (chunkCount === 1) {
|
||||
logger.info(`[${requestId}] Received first content chunk`)
|
||||
}
|
||||
|
||||
controller.enqueue(
|
||||
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
|
||||
encoder.encode(`data: ${JSON.stringify({ chunk: content })}\n\n`)
|
||||
)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data)
|
||||
const content = parsed.choices?.[0]?.delta?.content
|
||||
|
||||
if (content) {
|
||||
chunkCount++
|
||||
if (chunkCount === 1) {
|
||||
logger.info(`[${requestId}] Received first content chunk`)
|
||||
}
|
||||
|
||||
controller.enqueue(
|
||||
encoder.encode(`data: ${JSON.stringify({ chunk: content })}\n\n`)
|
||||
)
|
||||
if (eventType === 'response.completed') {
|
||||
const usage = parseResponsesUsage(parsed?.response?.usage ?? parsed?.usage)
|
||||
if (usage) {
|
||||
finalUsage = {
|
||||
prompt_tokens: usage.promptTokens,
|
||||
completion_tokens: usage.completionTokens,
|
||||
total_tokens: usage.totalTokens,
|
||||
}
|
||||
|
||||
if (parsed.usage) {
|
||||
finalUsage = parsed.usage
|
||||
logger.info(
|
||||
`[${requestId}] Received usage data: ${JSON.stringify(parsed.usage)}`
|
||||
)
|
||||
}
|
||||
} catch (parseError) {
|
||||
logger.debug(
|
||||
`[${requestId}] Skipped non-JSON line: ${data.substring(0, 100)}`
|
||||
logger.info(
|
||||
`[${requestId}] Received usage data: ${JSON.stringify(finalUsage)}`
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -401,6 +443,12 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
|
||||
stack: streamError?.stack,
|
||||
})
|
||||
|
||||
try {
|
||||
await recordUsage()
|
||||
} catch (usageError) {
|
||||
logger.warn(`[${requestId}] Failed to record usage after stream error`, usageError)
|
||||
}
|
||||
|
||||
const errorData = `data: ${JSON.stringify({ error: 'Streaming failed', done: true })}\n\n`
|
||||
controller.enqueue(encoder.encode(errorData))
|
||||
controller.close()
|
||||
@@ -424,8 +472,6 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
|
||||
message: error?.message || 'Unknown error',
|
||||
code: error?.code,
|
||||
status: error?.status,
|
||||
responseStatus: error?.response?.status,
|
||||
responseData: error?.response?.data ? safeStringify(error.response.data) : undefined,
|
||||
stack: error?.stack,
|
||||
useWandAzure,
|
||||
model: useWandAzure ? wandModelName : 'gpt-4o',
|
||||
@@ -440,14 +486,43 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
|
||||
}
|
||||
}
|
||||
|
||||
const completion = await activeClient.chat.completions.create({
|
||||
model: useWandAzure ? wandModelName : 'gpt-4o',
|
||||
messages: messages,
|
||||
temperature: 0.3,
|
||||
max_tokens: 10000,
|
||||
const apiUrl = useWandAzure
|
||||
? `${azureEndpoint?.replace(/\/$/, '')}/openai/v1/responses?api-version=${azureApiVersion}`
|
||||
: 'https://api.openai.com/v1/responses'
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
'OpenAI-Beta': 'responses=v1',
|
||||
}
|
||||
|
||||
if (useWandAzure) {
|
||||
headers['api-key'] = azureApiKey!
|
||||
} else {
|
||||
headers.Authorization = `Bearer ${activeOpenAIKey}`
|
||||
}
|
||||
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
model: useWandAzure ? wandModelName : 'gpt-4o',
|
||||
input: messages,
|
||||
temperature: 0.2,
|
||||
max_output_tokens: 10000,
|
||||
}),
|
||||
})
|
||||
|
||||
const generatedContent = completion.choices[0]?.message?.content?.trim()
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
const apiError = new Error(
|
||||
`API request failed: ${response.status} ${response.statusText} - ${errorText}`
|
||||
)
|
||||
;(apiError as any).status = response.status
|
||||
throw apiError
|
||||
}
|
||||
|
||||
const completion = await response.json()
|
||||
const generatedContent = extractResponseText(completion.output)?.trim()
|
||||
|
||||
if (!generatedContent) {
|
||||
logger.error(
|
||||
@@ -461,8 +536,18 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
|
||||
|
||||
logger.info(`[${requestId}] Wand generation successful`)
|
||||
|
||||
if (completion.usage) {
|
||||
await updateUserStatsForWand(session.user.id, completion.usage, requestId, isBYOK)
|
||||
const usage = parseResponsesUsage(completion.usage)
|
||||
if (usage) {
|
||||
await updateUserStatsForWand(
|
||||
session.user.id,
|
||||
{
|
||||
prompt_tokens: usage.promptTokens,
|
||||
completion_tokens: usage.completionTokens,
|
||||
total_tokens: usage.totalTokens,
|
||||
},
|
||||
requestId,
|
||||
isBYOK
|
||||
)
|
||||
}
|
||||
|
||||
return NextResponse.json({ success: true, content: generatedContent })
|
||||
@@ -472,10 +557,6 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
|
||||
message: error?.message || 'Unknown error',
|
||||
code: error?.code,
|
||||
status: error?.status,
|
||||
responseStatus: error instanceof OpenAI.APIError ? error.status : error?.response?.status,
|
||||
responseData: (error as any)?.response?.data
|
||||
? safeStringify((error as any).response.data)
|
||||
: undefined,
|
||||
stack: error?.stack,
|
||||
useWandAzure,
|
||||
model: useWandAzure ? wandModelName : 'gpt-4o',
|
||||
@@ -484,26 +565,19 @@ Use this context to calculate relative dates like "yesterday", "last week", "beg
|
||||
})
|
||||
|
||||
let clientErrorMessage = 'Wand generation failed. Please try again later.'
|
||||
let status = 500
|
||||
let status = typeof (error as any)?.status === 'number' ? (error as any).status : 500
|
||||
|
||||
if (error instanceof OpenAI.APIError) {
|
||||
status = error.status || 500
|
||||
logger.error(
|
||||
`[${requestId}] ${useWandAzure ? 'Azure OpenAI' : 'OpenAI'} API Error: ${status} - ${error.message}`
|
||||
)
|
||||
|
||||
if (status === 401) {
|
||||
clientErrorMessage = 'Authentication failed. Please check your API key configuration.'
|
||||
} else if (status === 429) {
|
||||
clientErrorMessage = 'Rate limit exceeded. Please try again later.'
|
||||
} else if (status >= 500) {
|
||||
clientErrorMessage =
|
||||
'The wand generation service is currently unavailable. Please try again later.'
|
||||
}
|
||||
} else if (useWandAzure && error.message?.includes('DeploymentNotFound')) {
|
||||
if (useWandAzure && error?.message?.includes('DeploymentNotFound')) {
|
||||
clientErrorMessage =
|
||||
'Azure OpenAI deployment not found. Please check your model deployment configuration.'
|
||||
status = 404
|
||||
} else if (status === 401) {
|
||||
clientErrorMessage = 'Authentication failed. Please check your API key configuration.'
|
||||
} else if (status === 429) {
|
||||
clientErrorMessage = 'Rate limit exceeded. Please try again later.'
|
||||
} else if (status >= 500) {
|
||||
clientErrorMessage =
|
||||
'The wand generation service is currently unavailable. Please try again later.'
|
||||
}
|
||||
|
||||
return NextResponse.json(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { extractResponseText } from '@/providers/openai/utils'
|
||||
|
||||
const logger = createLogger('SimAgentUtils')
|
||||
|
||||
@@ -12,47 +12,65 @@ const openaiApiKey = env.OPENAI_API_KEY
|
||||
|
||||
const useChatTitleAzure = azureApiKey && azureEndpoint && azureApiVersion
|
||||
|
||||
const client = useChatTitleAzure
|
||||
? new AzureOpenAI({
|
||||
apiKey: azureApiKey,
|
||||
apiVersion: azureApiVersion,
|
||||
endpoint: azureEndpoint,
|
||||
})
|
||||
: openaiApiKey
|
||||
? new OpenAI({
|
||||
apiKey: openaiApiKey,
|
||||
})
|
||||
: null
|
||||
|
||||
/**
|
||||
* Generates a short title for a chat based on the first message
|
||||
* @param message First user message in the chat
|
||||
* @returns A short title or null if API key is not available
|
||||
*/
|
||||
export async function generateChatTitle(message: string): Promise<string | null> {
|
||||
if (!client) {
|
||||
if (!useChatTitleAzure && !openaiApiKey) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await client.chat.completions.create({
|
||||
model: useChatTitleAzure ? chatTitleModelName : 'gpt-4o',
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content:
|
||||
'Generate a very short title (3-5 words max) for a chat that starts with this message. The title should be concise and descriptive. Do not wrap the title in quotes.',
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: message,
|
||||
},
|
||||
],
|
||||
max_tokens: 20,
|
||||
temperature: 0.2,
|
||||
const apiUrl = useChatTitleAzure
|
||||
? `${azureEndpoint?.replace(/\/$/, '')}/openai/v1/responses?api-version=${azureApiVersion}`
|
||||
: 'https://api.openai.com/v1/responses'
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
'OpenAI-Beta': 'responses=v1',
|
||||
}
|
||||
|
||||
if (useChatTitleAzure) {
|
||||
headers['api-key'] = azureApiKey!
|
||||
} else {
|
||||
headers.Authorization = `Bearer ${openaiApiKey}`
|
||||
}
|
||||
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
model: useChatTitleAzure ? chatTitleModelName : 'gpt-4o',
|
||||
input: [
|
||||
{
|
||||
role: 'system',
|
||||
content:
|
||||
'Generate a very short title (3-5 words max) for a chat that starts with this message. The title should be concise and descriptive. Do not wrap the title in quotes.',
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: message,
|
||||
},
|
||||
],
|
||||
max_output_tokens: 20,
|
||||
temperature: 0.2,
|
||||
}),
|
||||
})
|
||||
|
||||
const title = response.choices[0]?.message?.content?.trim() || null
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
logger.error('Error generating chat title:', {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
error: errorText,
|
||||
})
|
||||
return null
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
const title = extractResponseText(data.output)?.trim() || null
|
||||
return title
|
||||
} catch (error) {
|
||||
logger.error('Error generating chat title:', error)
|
||||
|
||||
@@ -1,26 +1,9 @@
|
||||
import { createLogger } from '@sim/logger'
|
||||
import { AzureOpenAI } from 'openai'
|
||||
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import {
|
||||
checkForForcedToolUsage,
|
||||
createReadableStreamFromAzureOpenAIStream,
|
||||
} from '@/providers/azure-openai/utils'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
ProviderRequest,
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
import { executeResponsesProviderRequest } from '@/providers/openai/core'
|
||||
import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types'
|
||||
|
||||
const logger = createLogger('AzureOpenAIProvider')
|
||||
|
||||
@@ -38,16 +21,6 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
executeRequest: async (
|
||||
request: ProviderRequest
|
||||
): Promise<ProviderResponse | StreamingExecution> => {
|
||||
logger.info('Preparing Azure OpenAI request', {
|
||||
model: request.model,
|
||||
hasSystemPrompt: !!request.systemPrompt,
|
||||
hasMessages: !!request.messages?.length,
|
||||
hasTools: !!request.tools?.length,
|
||||
toolCount: request.tools?.length || 0,
|
||||
hasResponseFormat: !!request.responseFormat,
|
||||
stream: !!request.stream,
|
||||
})
|
||||
|
||||
const azureEndpoint = request.azureEndpoint || env.AZURE_OPENAI_ENDPOINT
|
||||
const azureApiVersion =
|
||||
request.azureApiVersion || env.AZURE_OPENAI_API_VERSION || '2024-07-01-preview'
|
||||
@@ -58,520 +31,24 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
)
|
||||
}
|
||||
|
||||
const azureOpenAI = new AzureOpenAI({
|
||||
apiKey: request.apiKey,
|
||||
apiVersion: azureApiVersion,
|
||||
endpoint: azureEndpoint,
|
||||
})
|
||||
|
||||
const allMessages = []
|
||||
|
||||
if (request.systemPrompt) {
|
||||
allMessages.push({
|
||||
role: 'system',
|
||||
content: request.systemPrompt,
|
||||
})
|
||||
if (!request.apiKey) {
|
||||
throw new Error('API key is required for Azure OpenAI')
|
||||
}
|
||||
|
||||
if (request.context) {
|
||||
allMessages.push({
|
||||
role: 'user',
|
||||
content: request.context,
|
||||
})
|
||||
}
|
||||
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
}
|
||||
|
||||
const tools = request.tools?.length
|
||||
? request.tools.map((tool) => ({
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tool.id,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
},
|
||||
}))
|
||||
: undefined
|
||||
|
||||
const deploymentName = request.model.replace('azure/', '')
|
||||
const payload: any = {
|
||||
model: deploymentName,
|
||||
messages: allMessages,
|
||||
}
|
||||
const apiUrl = `${azureEndpoint.replace(/\/$/, '')}/openai/v1/responses?api-version=${azureApiVersion}`
|
||||
|
||||
if (request.temperature !== undefined) payload.temperature = request.temperature
|
||||
if (request.maxTokens != null) payload.max_completion_tokens = request.maxTokens
|
||||
|
||||
if (request.reasoningEffort !== undefined) payload.reasoning_effort = request.reasoningEffort
|
||||
if (request.verbosity !== undefined) payload.verbosity = request.verbosity
|
||||
|
||||
if (request.responseFormat) {
|
||||
payload.response_format = {
|
||||
type: 'json_schema',
|
||||
json_schema: {
|
||||
name: request.responseFormat.name || 'response_schema',
|
||||
schema: request.responseFormat.schema || request.responseFormat,
|
||||
strict: request.responseFormat.strict !== false,
|
||||
},
|
||||
}
|
||||
|
||||
logger.info('Added JSON schema response format to Azure OpenAI request')
|
||||
}
|
||||
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
|
||||
if (tools?.length) {
|
||||
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'azure-openai')
|
||||
const { tools: filteredTools, toolChoice } = preparedTools
|
||||
|
||||
if (filteredTools?.length && toolChoice) {
|
||||
payload.tools = filteredTools
|
||||
payload.tool_choice = toolChoice
|
||||
|
||||
logger.info('Azure OpenAI request configuration:', {
|
||||
toolCount: filteredTools.length,
|
||||
toolChoice:
|
||||
typeof toolChoice === 'string'
|
||||
? toolChoice
|
||||
: toolChoice.type === 'function'
|
||||
? `force:${toolChoice.function.name}`
|
||||
: toolChoice.type === 'tool'
|
||||
? `force:${toolChoice.name}`
|
||||
: toolChoice.type === 'any'
|
||||
? `force:${toolChoice.any?.name || 'unknown'}`
|
||||
: 'unknown',
|
||||
model: deploymentName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info('Using streaming response for Azure OpenAI request')
|
||||
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
const streamResponse = await azureOpenAI.chat.completions.create(streamingParams)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromAzureOpenAIStream(streamResponse, (content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
input: usage.prompt_tokens,
|
||||
output: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
if (streamingResult.execution.output.providerTiming) {
|
||||
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
|
||||
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
|
||||
streamEndTime
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: { input: 0, output: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
timeSegments: [
|
||||
{
|
||||
type: 'model',
|
||||
name: 'Streaming response',
|
||||
startTime: providerStartTime,
|
||||
endTime: Date.now(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
},
|
||||
],
|
||||
},
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
},
|
||||
},
|
||||
} as StreamingExecution
|
||||
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
const initialCallTime = Date.now()
|
||||
const originalToolChoice = payload.tool_choice
|
||||
const forcedTools = preparedTools?.forcedTools || []
|
||||
let usedForcedTools: string[] = []
|
||||
|
||||
let currentResponse = await azureOpenAI.chat.completions.create(payload)
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
let content = currentResponse.choices[0]?.message?.content || ''
|
||||
const tokens = {
|
||||
input: currentResponse.usage?.prompt_tokens || 0,
|
||||
output: currentResponse.usage?.completion_tokens || 0,
|
||||
total: currentResponse.usage?.total_tokens || 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
let hasUsedForcedTool = false
|
||||
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
name: 'Initial response',
|
||||
startTime: initialCallTime,
|
||||
endTime: initialCallTime + firstResponseTime,
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
]
|
||||
|
||||
const firstCheckResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
originalToolChoice,
|
||||
logger,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = firstCheckResult.hasUsedForcedTool
|
||||
usedForcedTools = firstCheckResult.usedForcedTools
|
||||
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
try {
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
|
||||
if (!tool) return null
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
const result = await executeTool(toolName, executionParams)
|
||||
const toolCallEndTime = Date.now()
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams,
|
||||
result,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
} catch (error) {
|
||||
const toolCallEndTime = Date.now()
|
||||
logger.error('Error processing tool call:', { error, toolName })
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams: {},
|
||||
result: {
|
||||
success: false,
|
||||
output: undefined,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
},
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const executionResults = await Promise.allSettled(toolExecutionPromises)
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: toolCallsInResponse.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
},
|
||||
})),
|
||||
})
|
||||
|
||||
for (const settledResult of executionResults) {
|
||||
if (settledResult.status === 'rejected' || !settledResult.value) continue
|
||||
|
||||
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
|
||||
settledResult.value
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: startTime,
|
||||
endTime: endTime,
|
||||
duration: duration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(startTime).toISOString(),
|
||||
endTime: new Date(endTime).toISOString(),
|
||||
duration: duration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
}
|
||||
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
}
|
||||
|
||||
if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) {
|
||||
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
|
||||
|
||||
if (remainingTools.length > 0) {
|
||||
nextPayload.tool_choice = {
|
||||
type: 'function',
|
||||
function: { name: remainingTools[0] },
|
||||
}
|
||||
logger.info(`Forcing next tool: ${remainingTools[0]}`)
|
||||
} else {
|
||||
nextPayload.tool_choice = 'auto'
|
||||
logger.info('All forced tools have been used, switching to auto tool_choice')
|
||||
}
|
||||
}
|
||||
|
||||
const nextModelStartTime = Date.now()
|
||||
currentResponse = await azureOpenAI.chat.completions.create(nextPayload)
|
||||
|
||||
const nextCheckResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
nextPayload.tool_choice,
|
||||
logger,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = nextCheckResult.hasUsedForcedTool
|
||||
usedForcedTools = nextCheckResult.usedForcedTools
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
startTime: nextModelStartTime,
|
||||
endTime: nextModelEndTime,
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
modelTime += thisModelTime
|
||||
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
if (currentResponse.usage) {
|
||||
tokens.input += currentResponse.usage.prompt_tokens || 0
|
||||
tokens.output += currentResponse.usage.completion_tokens || 0
|
||||
tokens.total += currentResponse.usage.total_tokens || 0
|
||||
}
|
||||
|
||||
iterationCount++
|
||||
}
|
||||
|
||||
if (request.stream) {
|
||||
logger.info('Using streaming for final response after tool processing')
|
||||
|
||||
const accumulatedCost = calculateCost(request.model, tokens.input, tokens.output)
|
||||
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto',
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
const streamResponse = await azureOpenAI.chat.completions.create(streamingParams)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromAzureOpenAIStream(streamResponse, (content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
input: tokens.input + usage.prompt_tokens,
|
||||
output: tokens.output + usage.completion_tokens,
|
||||
total: tokens.total + usage.total_tokens,
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
input: tokens.input,
|
||||
output: tokens.output,
|
||||
total: tokens.total,
|
||||
},
|
||||
toolCalls:
|
||||
toolCalls.length > 0
|
||||
? {
|
||||
list: toolCalls,
|
||||
count: toolCalls.length,
|
||||
}
|
||||
: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
modelTime: modelTime,
|
||||
toolsTime: toolsTime,
|
||||
firstResponseTime: firstResponseTime,
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
},
|
||||
},
|
||||
} as StreamingExecution
|
||||
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
|
||||
return {
|
||||
content,
|
||||
model: request.model,
|
||||
tokens,
|
||||
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
|
||||
toolResults: toolResults.length > 0 ? toolResults : undefined,
|
||||
timing: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
duration: totalDuration,
|
||||
modelTime: modelTime,
|
||||
toolsTime: toolsTime,
|
||||
firstResponseTime: firstResponseTime,
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
|
||||
logger.error('Error in Azure OpenAI request:', {
|
||||
error,
|
||||
duration: totalDuration,
|
||||
})
|
||||
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
duration: totalDuration,
|
||||
}
|
||||
|
||||
throw enhancedError
|
||||
}
|
||||
return executeResponsesProviderRequest(request, {
|
||||
providerId: 'azure-openai',
|
||||
providerLabel: 'Azure OpenAI',
|
||||
modelName: deploymentName,
|
||||
endpoint: apiUrl,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'OpenAI-Beta': 'responses=v1',
|
||||
'api-key': request.apiKey,
|
||||
},
|
||||
logger,
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
import type { Logger } from '@sim/logger'
|
||||
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
|
||||
import type { CompletionUsage } from 'openai/resources/completions'
|
||||
import type { Stream } from 'openai/streaming'
|
||||
import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils'
|
||||
|
||||
/**
|
||||
* Creates a ReadableStream from an Azure OpenAI streaming response.
|
||||
* Uses the shared OpenAI-compatible streaming utility.
|
||||
*/
|
||||
export function createReadableStreamFromAzureOpenAIStream(
|
||||
azureOpenAIStream: Stream<ChatCompletionChunk>,
|
||||
onComplete?: (content: string, usage: CompletionUsage) => void
|
||||
): ReadableStream {
|
||||
return createOpenAICompatibleStream(azureOpenAIStream, 'Azure OpenAI', onComplete)
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a forced tool was used in an Azure OpenAI response.
|
||||
* Uses the shared OpenAI-compatible forced tool usage helper.
|
||||
*/
|
||||
export function checkForForcedToolUsage(
|
||||
response: any,
|
||||
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
|
||||
_logger: Logger,
|
||||
forcedTools: string[],
|
||||
usedForcedTools: string[]
|
||||
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
|
||||
return checkForForcedToolUsageOpenAI(
|
||||
response,
|
||||
toolChoice,
|
||||
'Azure OpenAI',
|
||||
forcedTools,
|
||||
usedForcedTools,
|
||||
_logger
|
||||
)
|
||||
}
|
||||
apps/sim/providers/openai/core.ts (new file, 807 lines)
@@ -0,0 +1,807 @@
|
||||
import type { Logger } from '@sim/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import type { Message, ProviderRequest, ProviderResponse, TimeSegment } from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
import {
|
||||
buildResponsesInputFromMessages,
|
||||
convertResponseOutputToInputItems,
|
||||
convertToolsToResponses,
|
||||
createReadableStreamFromResponses,
|
||||
extractResponseText,
|
||||
extractResponseToolCalls,
|
||||
parseResponsesUsage,
|
||||
type ResponsesInputItem,
|
||||
type ResponsesToolCall,
|
||||
toResponsesToolChoice,
|
||||
} from './utils'
|
||||
|
||||
type PreparedTools = ReturnType<typeof prepareToolsWithUsageControl>
|
||||
type ToolChoice = PreparedTools['toolChoice']
|
||||
|
||||
/**
|
||||
* Recursively enforces OpenAI strict mode requirements on a JSON schema.
|
||||
* - Sets additionalProperties: false on all object types.
|
||||
* - Ensures required includes ALL property keys.
|
||||
*/
|
||||
function enforceStrictSchema(schema: any): any {
|
||||
if (!schema || typeof schema !== 'object') return schema
|
||||
|
||||
const result = { ...schema }
|
||||
|
||||
// If this is an object type, enforce strict requirements
|
||||
if (result.type === 'object') {
|
||||
result.additionalProperties = false
|
||||
|
||||
// Recursively process properties and ensure required includes all keys
|
||||
if (result.properties && typeof result.properties === 'object') {
|
||||
const propKeys = Object.keys(result.properties)
|
||||
result.required = propKeys // Strict mode requires ALL properties
|
||||
result.properties = Object.fromEntries(
|
||||
Object.entries(result.properties).map(([key, value]) => [key, enforceStrictSchema(value)])
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle array items
|
||||
if (result.type === 'array' && result.items) {
|
||||
result.items = enforceStrictSchema(result.items)
|
||||
}
|
||||
|
||||
// Handle anyOf, oneOf, allOf
|
||||
for (const keyword of ['anyOf', 'oneOf', 'allOf']) {
|
||||
if (Array.isArray(result[keyword])) {
|
||||
result[keyword] = result[keyword].map(enforceStrictSchema)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle $defs / definitions
|
||||
for (const defKey of ['$defs', 'definitions']) {
|
||||
if (result[defKey] && typeof result[defKey] === 'object') {
|
||||
result[defKey] = Object.fromEntries(
|
||||
Object.entries(result[defKey]).map(([key, value]) => [key, enforceStrictSchema(value)])
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
export interface ResponsesProviderConfig {
|
||||
providerId: string
|
||||
providerLabel: string
|
||||
modelName: string
|
||||
endpoint: string
|
||||
headers: Record<string, string>
|
||||
logger: Logger
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes a Responses API request with tool-loop handling and streaming support.
|
||||
*/
|
||||
export async function executeResponsesProviderRequest(
|
||||
request: ProviderRequest,
|
||||
config: ResponsesProviderConfig
|
||||
): Promise<ProviderResponse | StreamingExecution> {
|
||||
const { logger } = config
|
||||
|
||||
logger.info(`Preparing ${config.providerLabel} request`, {
|
||||
model: request.model,
|
||||
hasSystemPrompt: !!request.systemPrompt,
|
||||
hasMessages: !!request.messages?.length,
|
||||
hasTools: !!request.tools?.length,
|
||||
toolCount: request.tools?.length || 0,
|
||||
hasResponseFormat: !!request.responseFormat,
|
||||
stream: !!request.stream,
|
||||
})
|
||||
|
||||
const allMessages: Message[] = []
|
||||
|
||||
if (request.systemPrompt) {
|
||||
allMessages.push({
|
||||
role: 'system',
|
||||
content: request.systemPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
if (request.context) {
|
||||
allMessages.push({
|
||||
role: 'user',
|
||||
content: request.context,
|
||||
})
|
||||
}
|
||||
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
}
|
||||
|
||||
const initialInput = buildResponsesInputFromMessages(allMessages)
|
||||
|
||||
const basePayload: Record<string, any> = {
|
||||
model: config.modelName,
|
||||
}
|
||||
|
||||
if (request.temperature !== undefined) basePayload.temperature = request.temperature
|
||||
if (request.maxTokens != null) basePayload.max_output_tokens = request.maxTokens
|
||||
|
||||
if (request.reasoningEffort !== undefined) {
|
||||
basePayload.reasoning = {
|
||||
effort: request.reasoningEffort,
|
||||
summary: 'auto',
|
||||
}
|
||||
}
|
||||
|
||||
if (request.verbosity !== undefined) {
|
||||
basePayload.text = {
|
||||
...(basePayload.text ?? {}),
|
||||
verbosity: request.verbosity,
|
||||
}
|
||||
}
|
||||
|
||||
// Store response format config - for Azure with tools, we defer applying it until after tool calls complete
|
||||
let deferredTextFormat: { type: string; name: string; schema: any; strict: boolean } | undefined
|
||||
const hasTools = !!request.tools?.length
|
||||
const isAzure = config.providerId === 'azure-openai'
|
||||
|
||||
if (request.responseFormat) {
|
||||
const isStrict = request.responseFormat.strict !== false
|
||||
const rawSchema = request.responseFormat.schema || request.responseFormat
|
||||
// OpenAI strict mode requires additionalProperties: false on ALL nested objects
|
||||
const cleanedSchema = isStrict ? enforceStrictSchema(rawSchema) : rawSchema
|
||||
|
||||
const textFormat = {
|
||||
type: 'json_schema' as const,
|
||||
name: request.responseFormat.name || 'response_schema',
|
||||
schema: cleanedSchema,
|
||||
strict: isStrict,
|
||||
}
|
||||
|
||||
// Azure OpenAI has issues combining tools + response_format in the same request
|
||||
// Defer the format until after tool calls complete for Azure
|
||||
if (isAzure && hasTools) {
|
||||
deferredTextFormat = textFormat
|
||||
logger.info(
|
||||
`Deferring JSON schema response format for ${config.providerLabel} (will apply after tool calls complete)`
|
||||
)
|
||||
} else {
|
||||
basePayload.text = {
|
||||
...(basePayload.text ?? {}),
|
||||
format: textFormat,
|
||||
}
|
||||
logger.info(`Added JSON schema response format to ${config.providerLabel} request`)
|
||||
}
|
||||
}
|
||||
|
||||
const tools = request.tools?.length
|
||||
? request.tools.map((tool) => ({
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tool.id,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
},
|
||||
}))
|
||||
: undefined
|
||||
|
||||
let preparedTools: PreparedTools | null = null
|
||||
let responsesToolChoice: ReturnType<typeof toResponsesToolChoice> | undefined
|
||||
let trackingToolChoice: ToolChoice | undefined
|
||||
|
||||
if (tools?.length) {
|
||||
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, config.providerId)
|
||||
const { tools: filteredTools, toolChoice } = preparedTools
|
||||
trackingToolChoice = toolChoice
|
||||
|
||||
if (filteredTools?.length) {
|
||||
const convertedTools = convertToolsToResponses(filteredTools)
|
||||
if (!convertedTools.length) {
|
||||
throw new Error('All tools have empty names')
|
||||
}
|
||||
|
||||
basePayload.tools = convertedTools
|
||||
basePayload.parallel_tool_calls = true
|
||||
}
|
||||
|
||||
if (toolChoice) {
|
||||
responsesToolChoice = toResponsesToolChoice(toolChoice)
|
||||
if (responsesToolChoice) {
|
||||
basePayload.tool_choice = responsesToolChoice
|
||||
}
|
||||
|
||||
logger.info(`${config.providerLabel} request configuration:`, {
|
||||
toolCount: filteredTools?.length || 0,
|
||||
toolChoice:
|
||||
typeof toolChoice === 'string'
|
||||
? toolChoice
|
||||
: toolChoice.type === 'function'
|
||||
? `force:${toolChoice.function?.name}`
|
||||
: toolChoice.type === 'tool'
|
||||
? `force:${toolChoice.name}`
|
||||
: toolChoice.type === 'any'
|
||||
? `force:${toolChoice.any?.name || 'unknown'}`
|
||||
: 'unknown',
|
||||
model: config.modelName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const createRequestBody = (input: ResponsesInputItem[], overrides: Record<string, any> = {}) => ({
|
||||
...basePayload,
|
||||
input,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const parseErrorResponse = async (response: Response): Promise<string> => {
|
||||
const text = await response.text()
|
||||
try {
|
||||
const payload = JSON.parse(text)
|
||||
return payload?.error?.message || text
|
||||
} catch {
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
const postResponses = async (body: Record<string, any>) => {
|
||||
const response = await fetch(config.endpoint, {
|
||||
method: 'POST',
|
||||
headers: config.headers,
|
||||
body: JSON.stringify(body),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const message = await parseErrorResponse(response)
|
||||
throw new Error(`${config.providerLabel} API error (${response.status}): ${message}`)
|
||||
}
|
||||
|
||||
return response.json()
|
||||
}
|
||||
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info(`Using streaming response for ${config.providerLabel} request`)
|
||||
|
||||
const streamResponse = await fetch(config.endpoint, {
|
||||
method: 'POST',
|
||||
headers: config.headers,
|
||||
body: JSON.stringify(createRequestBody(initialInput, { stream: true })),
|
||||
})
|
||||
|
||||
if (!streamResponse.ok) {
|
||||
const message = await parseErrorResponse(streamResponse)
|
||||
throw new Error(`${config.providerLabel} API error (${streamResponse.status}): ${message}`)
|
||||
}
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromResponses(streamResponse, (content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
input: usage?.promptTokens || 0,
|
||||
output: usage?.completionTokens || 0,
|
||||
total: usage?.totalTokens || 0,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage?.promptTokens || 0,
|
||||
usage?.completionTokens || 0
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
if (streamingResult.execution.output.providerTiming) {
|
||||
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
|
||||
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
|
||||
streamEndTime
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: { input: 0, output: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
timeSegments: [
|
||||
{
|
||||
type: 'model',
|
||||
name: 'Streaming response',
|
||||
startTime: providerStartTime,
|
||||
endTime: Date.now(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
},
|
||||
],
|
||||
},
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
},
|
||||
},
|
||||
} as StreamingExecution
|
||||
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
const initialCallTime = Date.now()
|
||||
const forcedTools = preparedTools?.forcedTools || []
|
||||
let usedForcedTools: string[] = []
|
||||
let hasUsedForcedTool = false
|
||||
let currentToolChoice = responsesToolChoice
|
||||
let currentTrackingToolChoice = trackingToolChoice
|
||||
|
||||
const checkForForcedToolUsage = (
|
||||
toolCallsInResponse: ResponsesToolCall[],
|
||||
toolChoice: ToolChoice | undefined
|
||||
) => {
|
||||
if (typeof toolChoice === 'object' && toolCallsInResponse.length > 0) {
|
||||
const result = trackForcedToolUsage(
|
||||
toolCallsInResponse,
|
||||
toolChoice,
|
||||
logger,
|
||||
config.providerId,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = result.hasUsedForcedTool
|
||||
usedForcedTools = result.usedForcedTools
|
||||
}
|
||||
}
|
||||
|
||||
const currentInput: ResponsesInputItem[] = [...initialInput]
|
||||
let currentResponse = await postResponses(
|
||||
createRequestBody(currentInput, { tool_choice: currentToolChoice })
|
||||
)
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
const initialUsage = parseResponsesUsage(currentResponse.usage)
|
||||
const tokens = {
|
||||
input: initialUsage?.promptTokens || 0,
|
||||
output: initialUsage?.completionTokens || 0,
|
||||
total: initialUsage?.totalTokens || 0,
|
||||
}
|
||||
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
let iterationCount = 0
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
let content = extractResponseText(currentResponse.output) || ''
|
||||
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
name: 'Initial response',
|
||||
startTime: initialCallTime,
|
||||
endTime: initialCallTime + firstResponseTime,
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
]
|
||||
|
||||
checkForForcedToolUsage(
|
||||
extractResponseToolCalls(currentResponse.output),
|
||||
currentTrackingToolChoice
|
||||
)
|
||||
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
const responseText = extractResponseText(currentResponse.output)
|
||||
if (responseText) {
|
||||
content = responseText
|
||||
}
|
||||
|
||||
const toolCallsInResponse = extractResponseToolCalls(currentResponse.output)
|
||||
if (!toolCallsInResponse.length) {
|
||||
break
|
||||
}
|
||||
|
||||
const outputInputItems = convertResponseOutputToInputItems(currentResponse.output)
|
||||
if (outputInputItems.length) {
|
||||
currentInput.push(...outputInputItems)
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing ${toolCallsInResponse.length} tool calls in parallel (iteration ${
iterationCount + 1
}/${MAX_TOOL_ITERATIONS})`
)

const toolsStartTime = Date.now()

const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.name

try {
const toolArgs = toolCall.arguments ? JSON.parse(toolCall.arguments) : {}
const tool = request.tools?.find((t) => t.id === toolName)

if (!tool) {
return null
}

const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams)
const toolCallEndTime = Date.now()

return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call:', { error, toolName })

return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})

const executionResults = await Promise.allSettled(toolExecutionPromises)

for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue

const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value

timeSegments.push({
type: 'tool',
name: toolName,
startTime: startTime,
endTime: endTime,
duration: duration,
})

let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
tool: toolName,
}
}

toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})

currentInput.push({
type: 'function_call_output',
call_id: toolCall.id,
output: JSON.stringify(resultContent),
})
}

const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime

if (typeof currentToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) {
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))

if (remainingTools.length > 0) {
currentToolChoice = {
type: 'function',
name: remainingTools[0],
}
currentTrackingToolChoice = {
type: 'function',
function: { name: remainingTools[0] },
}
logger.info(`Forcing next tool: ${remainingTools[0]}`)
} else {
currentToolChoice = 'auto'
currentTrackingToolChoice = 'auto'
logger.info('All forced tools have been used, switching to auto tool_choice')
}
}

const nextModelStartTime = Date.now()

currentResponse = await postResponses(
createRequestBody(currentInput, { tool_choice: currentToolChoice })
)

checkForForcedToolUsage(
extractResponseToolCalls(currentResponse.output),
currentTrackingToolChoice
)

const latestText = extractResponseText(currentResponse.output)
if (latestText) {
content = latestText
}

const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime

timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})

modelTime += thisModelTime

const usage = parseResponsesUsage(currentResponse.usage)
if (usage) {
tokens.input += usage.promptTokens
tokens.output += usage.completionTokens
tokens.total += usage.totalTokens
}

iterationCount++
}

// For Azure with deferred format: make a final call with the response format applied
// This happens whenever we have a deferred format, even if no tools were called
// (the initial call was made without the format, so we need to apply it now)
let appliedDeferredFormat = false
if (deferredTextFormat) {
logger.info(
`Applying deferred JSON schema response format for ${config.providerLabel} (iterationCount: ${iterationCount})`
)

const finalFormatStartTime = Date.now()

// Determine what input to use for the formatted call
let formattedInput: ResponsesInputItem[]

if (iterationCount > 0) {
// Tools were called - include the conversation history with tool results
const lastOutputItems = convertResponseOutputToInputItems(currentResponse.output)
if (lastOutputItems.length) {
currentInput.push(...lastOutputItems)
}
formattedInput = currentInput
} else {
// No tools were called - just retry the initial call with format applied
// Don't include the model's previous unformatted response
formattedInput = initialInput
}

// Make final call with the response format - build payload without tools
const finalPayload: Record<string, any> = {
model: config.modelName,
input: formattedInput,
text: {
...(basePayload.text ?? {}),
format: deferredTextFormat,
},
}

// Copy over non-tool related settings
if (request.temperature !== undefined) finalPayload.temperature = request.temperature
if (request.maxTokens != null) finalPayload.max_output_tokens = request.maxTokens
if (request.reasoningEffort !== undefined) {
finalPayload.reasoning = {
effort: request.reasoningEffort,
summary: 'auto',
}
}
if (request.verbosity !== undefined) {
finalPayload.text = {
...finalPayload.text,
verbosity: request.verbosity,
}
}

currentResponse = await postResponses(finalPayload)

const finalFormatEndTime = Date.now()
const finalFormatDuration = finalFormatEndTime - finalFormatStartTime

timeSegments.push({
type: 'model',
name: 'Final formatted response',
startTime: finalFormatStartTime,
endTime: finalFormatEndTime,
duration: finalFormatDuration,
})

modelTime += finalFormatDuration

const finalUsage = parseResponsesUsage(currentResponse.usage)
if (finalUsage) {
tokens.input += finalUsage.promptTokens
tokens.output += finalUsage.completionTokens
tokens.total += finalUsage.totalTokens
}

// Update content with the formatted response
const formattedText = extractResponseText(currentResponse.output)
if (formattedText) {
content = formattedText
}

appliedDeferredFormat = true
}

// Skip streaming if we already applied deferred format - we have the formatted content
// Making another streaming call would lose the formatted response
if (request.stream && !appliedDeferredFormat) {
logger.info('Using streaming for final response after tool processing')

const accumulatedCost = calculateCost(request.model, tokens.input, tokens.output)

// For Azure with deferred format in streaming mode, include the format in the streaming call
const streamOverrides: Record<string, any> = { stream: true, tool_choice: 'auto' }
if (deferredTextFormat) {
streamOverrides.text = {
...(basePayload.text ?? {}),
format: deferredTextFormat,
}
}

const streamResponse = await fetch(config.endpoint, {
method: 'POST',
headers: config.headers,
body: JSON.stringify(createRequestBody(currentInput, streamOverrides)),
})

if (!streamResponse.ok) {
const message = await parseErrorResponse(streamResponse)
throw new Error(`${config.providerLabel} API error (${streamResponse.status}): ${message}`)
}

const streamingResult = {
stream: createReadableStreamFromResponses(streamResponse, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
input: tokens.input + (usage?.promptTokens || 0),
output: tokens.output + (usage?.completionTokens || 0),
total: tokens.total + (usage?.totalTokens || 0),
}

const streamCost = calculateCost(
request.model,
usage?.promptTokens || 0,
usage?.completionTokens || 0
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: {
input: tokens.input,
output: tokens.output,
total: tokens.total,
},
toolCalls:
toolCalls.length > 0
? {
list: toolCalls,
count: toolCalls.length,
}
: undefined,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
cost: {
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
},
} as StreamingExecution

return streamingResult as StreamingExecution
}

const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime

return {
content,
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
}
} catch (error) {
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime

logger.error(`Error in ${config.providerLabel} request:`, {
error,
duration: totalDuration,
})

const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
}

throw enhancedError
}
}
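For orientation, here is a hedged sketch of what the follow-up request posted by the loop above looks like on the wire. The real body is assembled by createRequestBody from basePayload (both defined earlier in this file); the model name, call id, and tool values below are purely illustrative:

// Illustrative follow-up body after one tool round (all values hypothetical)
const followUpBody = {
  model: 'gpt-4o', // config.modelName in the code above
  input: [
    { role: 'user', content: 'What is the weather in Paris?' },
    { type: 'function_call', call_id: 'call_123', name: 'get_weather', arguments: '{"city":"Paris"}' },
    { type: 'function_call_output', call_id: 'call_123', output: '{"tempC":21}' },
  ],
  tool_choice: 'auto', // or { type: 'function', name: '...' } while forced tools remain
}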
@@ -1,25 +1,11 @@
import { createLogger } from '@sim/logger'
import OpenAI from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import { createReadableStreamFromOpenAIStream } from '@/providers/openai/utils'
import type {
ProviderConfig,
ProviderRequest,
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types'
import { executeResponsesProviderRequest } from './core'

const logger = createLogger('OpenAIProvider')
const responsesEndpoint = 'https://api.openai.com/v1/responses'

export const openaiProvider: ProviderConfig = {
id: 'openai',
@@ -32,534 +18,21 @@ export const openaiProvider: ProviderConfig = {
executeRequest: async (
request: ProviderRequest
): Promise<ProviderResponse | StreamingExecution> => {
logger.info('Preparing OpenAI request', {
model: request.model,
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
toolCount: request.tools?.length || 0,
hasResponseFormat: !!request.responseFormat,
stream: !!request.stream,
})

if (!request.apiKey) {
throw new Error('API key is required for OpenAI')
}

return executeResponsesProviderRequest(request, {
providerId: 'openai',
providerLabel: 'OpenAI',
modelName: request.model,
endpoint: responsesEndpoint,
headers: {
Authorization: `Bearer ${request.apiKey}`,
'Content-Type': 'application/json',
'OpenAI-Beta': 'responses=v1',
},
logger,
})

const openai = new OpenAI({ apiKey: request.apiKey })

const allMessages = []

if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}

if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}

if (request.messages) {
allMessages.push(...request.messages)
}

const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
: undefined

const payload: any = {
model: request.model,
messages: allMessages,
}

if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens != null) payload.max_completion_tokens = request.maxTokens

if (request.reasoningEffort !== undefined) payload.reasoning_effort = request.reasoningEffort
if (request.verbosity !== undefined) payload.verbosity = request.verbosity

if (request.responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: request.responseFormat.name || 'response_schema',
schema: request.responseFormat.schema || request.responseFormat,
strict: request.responseFormat.strict !== false,
},
}

logger.info('Added JSON schema response format to request')
}

let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null

if (tools?.length) {
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'openai')
const { tools: filteredTools, toolChoice } = preparedTools

if (filteredTools?.length && toolChoice) {
payload.tools = filteredTools
payload.tool_choice = toolChoice

logger.info('OpenAI request configuration:', {
toolCount: filteredTools.length,
toolChoice:
typeof toolChoice === 'string'
? toolChoice
: toolChoice.type === 'function'
? `force:${toolChoice.function.name}`
: toolChoice.type === 'tool'
? `force:${toolChoice.name}`
: toolChoice.type === 'any'
? `force:${toolChoice.any?.name || 'unknown'}`
: 'unknown',
model: request.model,
})
}
}

const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()

try {
if (request.stream && (!tools || tools.length === 0)) {
logger.info('Using streaming response for OpenAI request')

const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
stream: true,
stream_options: { include_usage: true },
}
const streamResponse = await openai.chat.completions.create(streamingParams)

const streamingResult = {
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
input: usage.prompt_tokens,
output: usage.completion_tokens,
total: usage.total_tokens,
}

const costResult = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}

const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()

if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime

if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
streamEndTime
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
streamEndTime - providerStartTime
}
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: { input: 0, output: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
timeSegments: [
{
type: 'model',
name: 'Streaming response',
startTime: providerStartTime,
endTime: Date.now(),
duration: Date.now() - providerStartTime,
},
],
},
cost: { input: 0, output: 0, total: 0 },
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
},
} as StreamingExecution

return streamingResult as StreamingExecution
}

const initialCallTime = Date.now()

const originalToolChoice = payload.tool_choice

const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []

/**
* Helper function to check for forced tool usage in responses
*/
const checkForForcedToolUsage = (
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }
) => {
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'openai',
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
}
}

let currentResponse = await openai.chat.completions.create(payload)
const firstResponseTime = Date.now() - initialCallTime

let content = currentResponse.choices[0]?.message?.content || ''
const tokens = {
input: currentResponse.usage?.prompt_tokens || 0,
output: currentResponse.usage?.completion_tokens || 0,
total: currentResponse.usage?.total_tokens || 0,
}
const toolCalls = []
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0

let modelTime = firstResponseTime
let toolsTime = 0

let hasUsedForcedTool = false

const timeSegments: TimeSegment[] = [
{
type: 'model',
name: 'Initial response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
]

checkForForcedToolUsage(currentResponse, originalToolChoice)

while (iterationCount < MAX_TOOL_ITERATIONS) {
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}

const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}

logger.info(
`Processing ${toolCallsInResponse.length} tool calls in parallel (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)

const toolsStartTime = Date.now()

const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.function.name

try {
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)

if (!tool) {
return null
}

const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams)
const toolCallEndTime = Date.now()

return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call:', { error, toolName })

return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})

const executionResults = await Promise.allSettled(toolExecutionPromises)

currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})

for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue

const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value

timeSegments.push({
type: 'tool',
name: toolName,
startTime: startTime,
endTime: endTime,
duration: duration,
})

let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
tool: toolName,
}
}

toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})

currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
}

const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime

const nextPayload = {
...payload,
messages: currentMessages,
}

if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) {
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))

if (remainingTools.length > 0) {
nextPayload.tool_choice = {
type: 'function',
function: { name: remainingTools[0] },
}
logger.info(`Forcing next tool: ${remainingTools[0]}`)
} else {
nextPayload.tool_choice = 'auto'
logger.info('All forced tools have been used, switching to auto tool_choice')
}
}

const nextModelStartTime = Date.now()

currentResponse = await openai.chat.completions.create(nextPayload)

checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)

const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime

timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})

modelTime += thisModelTime

if (currentResponse.usage) {
tokens.input += currentResponse.usage.prompt_tokens || 0
tokens.output += currentResponse.usage.completion_tokens || 0
tokens.total += currentResponse.usage.total_tokens || 0
}

iterationCount++
}

if (request.stream) {
logger.info('Using streaming for final response after tool processing')

const accumulatedCost = calculateCost(request.model, tokens.input, tokens.output)

const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
messages: currentMessages,
tool_choice: 'auto',
stream: true,
stream_options: { include_usage: true },
}
const streamResponse = await openai.chat.completions.create(streamingParams)

const streamingResult = {
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
input: tokens.input + usage.prompt_tokens,
output: tokens.output + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}

const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: {
input: tokens.input,
output: tokens.output,
total: tokens.total,
},
toolCalls:
toolCalls.length > 0
? {
list: toolCalls,
count: toolCalls.length,
}
: undefined,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
cost: {
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
},
} as StreamingExecution

return streamingResult as StreamingExecution
}

const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime

return {
content,
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
}
} catch (error) {
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime

logger.error('Error in OpenAI request:', {
error,
duration: totalDuration,
})

const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
}

throw enhancedError
}
},
}

@@ -1,15 +1,465 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import type { Stream } from 'openai/streaming'
import { createOpenAICompatibleStream } from '@/providers/utils'
import { createLogger } from '@sim/logger'
import type { Message } from '@/providers/types'

const logger = createLogger('ResponsesUtils')

export interface ResponsesUsageTokens {
promptTokens: number
completionTokens: number
totalTokens: number
cachedTokens: number
reasoningTokens: number
}

export interface ResponsesToolCall {
id: string
name: string
arguments: string
}

export type ResponsesInputItem =
| {
role: 'system' | 'user' | 'assistant'
content: string
}
| {
type: 'function_call'
call_id: string
name: string
arguments: string
}
| {
type: 'function_call_output'
call_id: string
output: string
}

export interface ResponsesToolDefinition {
type: 'function'
name: string
description?: string
parameters?: Record<string, any>
}

/**
* Creates a ReadableStream from an OpenAI streaming response.
* Uses the shared OpenAI-compatible streaming utility.
* Converts chat-style messages into Responses API input items.
*/
export function createReadableStreamFromOpenAIStream(
openaiStream: Stream<ChatCompletionChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
return createOpenAICompatibleStream(openaiStream, 'OpenAI', onComplete)
}

export function buildResponsesInputFromMessages(messages: Message[]): ResponsesInputItem[] {
const input: ResponsesInputItem[] = []

for (const message of messages) {
if (message.role === 'tool' && message.tool_call_id) {
input.push({
type: 'function_call_output',
call_id: message.tool_call_id,
output: message.content ?? '',
})
continue
}

if (
message.content &&
(message.role === 'system' || message.role === 'user' || message.role === 'assistant')
) {
input.push({
role: message.role,
content: message.content,
})
}

if (message.tool_calls?.length) {
for (const toolCall of message.tool_calls) {
input.push({
type: 'function_call',
call_id: toolCall.id,
name: toolCall.function.name,
arguments: toolCall.function.arguments,
})
}
}
}

return input
}

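As a quick illustration of the conversion (message values are hypothetical, and the cast to Message[] is only for the sketch), a chat-style history with one tool round maps to Responses input items like this:

const input = buildResponsesInputFromMessages([
  { role: 'user', content: 'What is 2 + 2?' },
  {
    role: 'assistant',
    content: null,
    tool_calls: [{ id: 'call_1', type: 'function', function: { name: 'calc', arguments: '{"expr":"2+2"}' } }],
  },
  { role: 'tool', tool_call_id: 'call_1', content: '4' },
] as unknown as Message[])
// => [
//   { role: 'user', content: 'What is 2 + 2?' },
//   { type: 'function_call', call_id: 'call_1', name: 'calc', arguments: '{"expr":"2+2"}' },
//   { type: 'function_call_output', call_id: 'call_1', output: '4' },
// ]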
/**
* Converts tool definitions to the Responses API format.
*/
export function convertToolsToResponses(tools: any[]): ResponsesToolDefinition[] {
return tools
.map((tool) => {
const name = tool.function?.name ?? tool.name
if (!name) {
return null
}

return {
type: 'function' as const,
name,
description: tool.function?.description ?? tool.description,
parameters: tool.function?.parameters ?? tool.parameters,
}
})
.filter(Boolean) as ResponsesToolDefinition[]
}

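For instance, a chat-completions-style tool definition (values are illustrative) flattens to the Responses shape:

convertToolsToResponses([
  {
    type: 'function',
    function: {
      name: 'get_weather',
      description: 'Look up current weather',
      parameters: { type: 'object', properties: { city: { type: 'string' } } },
    },
  },
])
// => [{ type: 'function', name: 'get_weather', description: 'Look up current weather',
//       parameters: { type: 'object', properties: { city: { type: 'string' } } } }]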
/**
* Converts tool_choice to the Responses API format.
*/
export function toResponsesToolChoice(
toolChoice:
| 'auto'
| 'none'
| { type: 'function'; function?: { name: string }; name?: string }
| { type: 'tool'; name: string }
| { type: 'any'; any: { model: string; name: string } }
| undefined
): 'auto' | 'none' | { type: 'function'; name: string } | undefined {
if (!toolChoice) {
return undefined
}

if (typeof toolChoice === 'string') {
return toolChoice
}

if (toolChoice.type === 'function') {
const name = toolChoice.name ?? toolChoice.function?.name
return name ? { type: 'function', name } : undefined
}

return 'auto'
}

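A few hedged examples of the mapping (tool names are illustrative):

toResponsesToolChoice('auto') // 'auto'
toResponsesToolChoice({ type: 'function', function: { name: 'get_weather' } }) // { type: 'function', name: 'get_weather' }
toResponsesToolChoice({ type: 'tool', name: 'get_weather' }) // 'auto' (non-function forms fall back to auto)
toResponsesToolChoice(undefined) // undefined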
function extractTextFromMessageItem(item: any): string {
if (!item) {
return ''
}

if (typeof item.content === 'string') {
return item.content
}

if (!Array.isArray(item.content)) {
return ''
}

const textParts: string[] = []
for (const part of item.content) {
if (!part || typeof part !== 'object') {
continue
}

if ((part.type === 'output_text' || part.type === 'text') && typeof part.text === 'string') {
textParts.push(part.text)
continue
}

if (part.type === 'output_json') {
if (typeof part.text === 'string') {
textParts.push(part.text)
} else if (part.json !== undefined) {
textParts.push(JSON.stringify(part.json))
}
}
}

return textParts.join('')
}

/**
* Extracts plain text from Responses API output items.
*/
export function extractResponseText(output: unknown): string {
if (!Array.isArray(output)) {
return ''
}

const textParts: string[] = []
for (const item of output) {
if (item?.type !== 'message') {
continue
}

const text = extractTextFromMessageItem(item)
if (text) {
textParts.push(text)
}
}

return textParts.join('')
}

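For example, given an illustrative output array (not taken from a real response):

extractResponseText([
  { type: 'reasoning', summary: [] },
  {
    type: 'message',
    content: [
      { type: 'output_text', text: 'Hello, ' },
      { type: 'output_text', text: 'world!' },
    ],
  },
]) // => 'Hello, world!'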
/**
* Converts Responses API output items into input items for subsequent calls.
*/
export function convertResponseOutputToInputItems(output: unknown): ResponsesInputItem[] {
if (!Array.isArray(output)) {
return []
}

const items: ResponsesInputItem[] = []
for (const item of output) {
if (!item || typeof item !== 'object') {
continue
}

if (item.type === 'message') {
const text = extractTextFromMessageItem(item)
if (text) {
items.push({
role: 'assistant',
content: text,
})
}

const toolCalls = Array.isArray(item.tool_calls) ? item.tool_calls : []
for (const toolCall of toolCalls) {
const callId = toolCall?.id
const name = toolCall?.function?.name ?? toolCall?.name
if (!callId || !name) {
continue
}

const argumentsValue =
typeof toolCall?.function?.arguments === 'string'
? toolCall.function.arguments
: JSON.stringify(toolCall?.function?.arguments ?? {})

items.push({
type: 'function_call',
call_id: callId,
name,
arguments: argumentsValue,
})
}

continue
}

if (item.type === 'function_call') {
const callId = item.call_id ?? item.id
const name = item.name ?? item.function?.name
if (!callId || !name) {
continue
}

const argumentsValue =
typeof item.arguments === 'string' ? item.arguments : JSON.stringify(item.arguments ?? {})

items.push({
type: 'function_call',
call_id: callId,
name,
arguments: argumentsValue,
})
}
}

return items
}

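A small sketch of the conversion (output items are illustrative):

convertResponseOutputToInputItems([
  { type: 'message', content: [{ type: 'output_text', text: 'Let me check.' }] },
  { type: 'function_call', call_id: 'call_9', name: 'get_weather', arguments: '{"city":"Paris"}' },
])
// => [
//   { role: 'assistant', content: 'Let me check.' },
//   { type: 'function_call', call_id: 'call_9', name: 'get_weather', arguments: '{"city":"Paris"}' },
// ]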
/**
* Extracts tool calls from Responses API output items.
*/
export function extractResponseToolCalls(output: unknown): ResponsesToolCall[] {
if (!Array.isArray(output)) {
return []
}

const toolCalls: ResponsesToolCall[] = []

for (const item of output) {
if (!item || typeof item !== 'object') {
continue
}

if (item.type === 'function_call') {
const callId = item.call_id ?? item.id
const name = item.name ?? item.function?.name
if (!callId || !name) {
continue
}

const argumentsValue =
typeof item.arguments === 'string' ? item.arguments : JSON.stringify(item.arguments ?? {})

toolCalls.push({
id: callId,
name,
arguments: argumentsValue,
})
continue
}

if (item.type === 'message' && Array.isArray(item.tool_calls)) {
for (const toolCall of item.tool_calls) {
const callId = toolCall?.id
const name = toolCall?.function?.name ?? toolCall?.name
if (!callId || !name) {
continue
}

const argumentsValue =
typeof toolCall?.function?.arguments === 'string'
? toolCall.function.arguments
: JSON.stringify(toolCall?.function?.arguments ?? {})

toolCalls.push({
id: callId,
name,
arguments: argumentsValue,
})
}
}
}

return toolCalls
}

/**
* Maps Responses API usage data to prompt/completion token counts.
*
* Note: output_tokens is expected to include reasoning tokens; fall back to reasoning_tokens
* when output_tokens is missing or zero.
*/
export function parseResponsesUsage(usage: any): ResponsesUsageTokens | undefined {
if (!usage || typeof usage !== 'object') {
return undefined
}

const inputTokens = Number(usage.input_tokens ?? 0)
const outputTokens = Number(usage.output_tokens ?? 0)
const cachedTokens = Number(usage.input_tokens_details?.cached_tokens ?? 0)
const reasoningTokens = Number(usage.output_tokens_details?.reasoning_tokens ?? 0)
const completionTokens = Math.max(outputTokens, reasoningTokens)
const totalTokens = inputTokens + completionTokens

return {
promptTokens: inputTokens,
completionTokens,
totalTokens,
cachedTokens,
reasoningTokens,
}
}

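A worked example of the mapping (numbers are made up):

parseResponsesUsage({
  input_tokens: 120,
  output_tokens: 45,
  input_tokens_details: { cached_tokens: 80 },
  output_tokens_details: { reasoning_tokens: 12 },
})
// => { promptTokens: 120, completionTokens: 45, totalTokens: 165, cachedTokens: 80, reasoningTokens: 12 }
// completionTokens = max(45, 12) = 45; totalTokens = 120 + 45 = 165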
/**
* Creates a ReadableStream from a Responses API SSE stream.
*/
export function createReadableStreamFromResponses(
response: Response,
onComplete?: (content: string, usage?: ResponsesUsageTokens) => void
): ReadableStream<Uint8Array> {
let fullContent = ''
let finalUsage: ResponsesUsageTokens | undefined
let activeEventType: string | undefined
const encoder = new TextEncoder()

return new ReadableStream<Uint8Array>({
async start(controller) {
const reader = response.body?.getReader()
if (!reader) {
controller.close()
return
}

const decoder = new TextDecoder()
let buffer = ''

try {
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}

buffer += decoder.decode(value, { stream: true })
const lines = buffer.split('\n')
buffer = lines.pop() || ''

for (const line of lines) {
const trimmed = line.trim()
if (!trimmed) {
continue
}

if (trimmed.startsWith('event:')) {
activeEventType = trimmed.slice(6).trim()
continue
}

if (!trimmed.startsWith('data:')) {
continue
}

const data = trimmed.slice(5).trim()
if (data === '[DONE]') {
continue
}

let event: any
try {
event = JSON.parse(data)
} catch (error) {
logger.debug('Skipping non-JSON response stream chunk', {
data: data.slice(0, 200),
error,
})
continue
}

const eventType = event?.type ?? activeEventType

if (
eventType === 'response.error' ||
eventType === 'error' ||
eventType === 'response.failed'
) {
const message = event?.error?.message || 'Responses API stream error'
controller.error(new Error(message))
return
}

if (
eventType === 'response.output_text.delta' ||
eventType === 'response.output_json.delta'
) {
let deltaText = ''
if (typeof event.delta === 'string') {
deltaText = event.delta
} else if (event.delta && typeof event.delta.text === 'string') {
deltaText = event.delta.text
} else if (event.delta && event.delta.json !== undefined) {
deltaText = JSON.stringify(event.delta.json)
} else if (event.json !== undefined) {
deltaText = JSON.stringify(event.json)
} else if (typeof event.text === 'string') {
deltaText = event.text
}

if (deltaText.length > 0) {
fullContent += deltaText
controller.enqueue(encoder.encode(deltaText))
}
}

if (eventType === 'response.completed') {
finalUsage = parseResponsesUsage(event?.response?.usage ?? event?.usage)
}
}
}

if (onComplete) {
onComplete(fullContent, finalUsage)
}

controller.close()
} catch (error) {
controller.error(error)
} finally {
reader.releaseLock()
}
},
})
}
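A minimal consumption sketch, assuming a fetch to the Responses endpoint with stream: true (apiKey and the request body are placeholders):

const res = await fetch('https://api.openai.com/v1/responses', {
  method: 'POST',
  headers: { Authorization: `Bearer ${apiKey}`, 'Content-Type': 'application/json' },
  body: JSON.stringify({ model: 'gpt-4o', input: [{ role: 'user', content: 'Hi' }], stream: true }),
})
const stream = createReadableStreamFromResponses(res, (content, usage) => {
  console.log('streamed content length:', content.length, 'usage:', usage)
})
// stream is a ReadableStream<Uint8Array> of text deltas, suitable for piping to a client response.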