mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-08 22:48:14 -05:00
Generalized provider implementation so adding other providers will be easier
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import { ProviderConfig, FunctionCallResponse, ProviderToolConfig } from '../types'
|
||||
import { ProviderConfig, FunctionCallResponse, ProviderToolConfig, ProviderRequest } from '../types'
|
||||
import { ToolConfig } from '@/tools/types'
|
||||
|
||||
export const openaiProvider: ProviderConfig = {
|
||||
@@ -22,17 +22,8 @@ export const openaiProvider: ProviderConfig = {
|
||||
|
||||
return tools.map(tool => ({
|
||||
name: tool.id,
|
||||
description: tool.description || '',
|
||||
parameters: {
|
||||
...tool.parameters,
|
||||
properties: Object.entries(tool.parameters.properties).reduce((acc, [key, value]) => ({
|
||||
...acc,
|
||||
[key]: {
|
||||
...value,
|
||||
...(key in tool.params && { default: tool.params[key] })
|
||||
}
|
||||
}), {})
|
||||
}
|
||||
description: tool.description,
|
||||
parameters: tool.parameters
|
||||
}))
|
||||
},
|
||||
|
||||
@@ -42,19 +33,47 @@ export const openaiProvider: ProviderConfig = {
|
||||
throw new Error('No function call found in response')
|
||||
}
|
||||
|
||||
const args = typeof functionCall.arguments === 'string'
|
||||
? JSON.parse(functionCall.arguments)
|
||||
: functionCall.arguments
|
||||
|
||||
const tool = tools?.find(t => t.id === functionCall.name)
|
||||
const toolParams = tool?.params || {}
|
||||
|
||||
return {
|
||||
name: functionCall.name,
|
||||
arguments: {
|
||||
...toolParams, // First spread the stored params to ensure they're used as defaults
|
||||
...args // Then spread any overrides from the function call
|
||||
...toolParams,
|
||||
...JSON.parse(functionCall.arguments)
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
transformRequest: (request: ProviderRequest, functions?: any) => {
|
||||
return {
|
||||
model: request.model || 'gpt-4o',
|
||||
messages: [
|
||||
{ role: 'system', content: request.systemPrompt },
|
||||
...(request.context ? [{ role: 'user', content: request.context }] : []),
|
||||
...(request.messages || [])
|
||||
],
|
||||
temperature: request.temperature,
|
||||
max_tokens: request.maxTokens,
|
||||
...(functions && {
|
||||
functions,
|
||||
function_call: 'auto'
|
||||
})
|
||||
}
|
||||
},
|
||||
|
||||
transformResponse: (response: any) => {
|
||||
return {
|
||||
content: response.choices?.[0]?.message?.content || '',
|
||||
tokens: response.usage && {
|
||||
prompt: response.usage.prompt_tokens,
|
||||
completion: response.usage.completion_tokens,
|
||||
total: response.usage.total_tokens
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
hasFunctionCall: (response: any) => {
|
||||
return !!response.choices?.[0]?.message?.function_call
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { ProviderConfig, ProviderRequest, ProviderResponse, Message } from './types'
|
||||
import { ProviderConfig, ProviderRequest, ProviderResponse, TokenInfo } from './types'
|
||||
import { openaiProvider } from './openai'
|
||||
import { anthropicProvider } from './anthropic'
|
||||
import { ToolConfig } from '@/tools/types'
|
||||
import { getTool, executeTool } from '@/tools'
|
||||
|
||||
// Register providers
|
||||
@@ -20,52 +19,131 @@ export async function executeProviderRequest(
|
||||
throw new Error(`Provider not found: ${providerId}`)
|
||||
}
|
||||
|
||||
// Only transform tools if they are provided and non-empty
|
||||
// Transform tools to provider-specific function format
|
||||
const functions = request.tools && request.tools.length > 0
|
||||
? provider.transformToolsToFunctions(request.tools)
|
||||
: undefined
|
||||
|
||||
// Base payload that's common across providers
|
||||
const basePayload = {
|
||||
model: request.model || provider.defaultModel,
|
||||
messages: [
|
||||
{ role: 'system' as const, content: request.systemPrompt },
|
||||
...(request.context ? [{ role: 'user' as const, content: request.context }] : [])
|
||||
] as Message[],
|
||||
temperature: request.temperature,
|
||||
max_tokens: request.maxTokens
|
||||
// Transform the request using provider-specific logic
|
||||
const payload = provider.transformRequest(request, functions)
|
||||
|
||||
// Make the initial API request through the proxy
|
||||
let currentResponse = await makeProxyRequest(providerId, payload, request.apiKey)
|
||||
let content = ''
|
||||
let tokens: TokenInfo | undefined = undefined
|
||||
let toolCalls = []
|
||||
let toolResults = []
|
||||
let currentMessages = [...(request.messages || [])]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10 // Prevent infinite loops
|
||||
|
||||
try {
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
console.log(`Processing iteration ${iterationCount + 1}`)
|
||||
|
||||
// Transform the response using provider-specific logic
|
||||
const transformedResponse = provider.transformResponse(currentResponse)
|
||||
content = transformedResponse.content
|
||||
|
||||
// Update tokens
|
||||
if (transformedResponse.tokens) {
|
||||
const newTokens: TokenInfo = {
|
||||
prompt: (tokens?.prompt ?? 0) + (transformedResponse.tokens?.prompt ?? 0),
|
||||
completion: (tokens?.completion ?? 0) + (transformedResponse.tokens?.completion ?? 0),
|
||||
total: (tokens?.total ?? 0) + (transformedResponse.tokens?.total ?? 0)
|
||||
}
|
||||
tokens = newTokens
|
||||
}
|
||||
|
||||
// Check for function calls using provider-specific logic
|
||||
const hasFunctionCall = provider.hasFunctionCall(currentResponse)
|
||||
console.log('Has function call:', hasFunctionCall)
|
||||
|
||||
if (!hasFunctionCall) {
|
||||
console.log('No function call detected, breaking loop')
|
||||
break
|
||||
}
|
||||
|
||||
// Transform function call using provider-specific logic
|
||||
let functionCall
|
||||
try {
|
||||
functionCall = provider.transformFunctionCallResponse(currentResponse, request.tools)
|
||||
} catch (error) {
|
||||
console.log('Error transforming function call:', error)
|
||||
break
|
||||
}
|
||||
|
||||
if (!functionCall) {
|
||||
console.log('No function call after transformation, breaking loop')
|
||||
break
|
||||
}
|
||||
|
||||
console.log('Function call:', functionCall.name)
|
||||
|
||||
// Execute the tool
|
||||
const tool = getTool(functionCall.name)
|
||||
if (!tool) {
|
||||
console.log(`Tool not found: ${functionCall.name}`)
|
||||
break
|
||||
}
|
||||
|
||||
const result = await executeTool(functionCall.name, functionCall.arguments)
|
||||
console.log('Tool execution result:', result.success)
|
||||
|
||||
if (!result.success) {
|
||||
console.log('Tool execution failed')
|
||||
break
|
||||
}
|
||||
|
||||
toolResults.push(result.output)
|
||||
toolCalls.push(functionCall)
|
||||
|
||||
// Add the function call and result to messages
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
function_call: {
|
||||
name: functionCall.name,
|
||||
arguments: JSON.stringify(functionCall.arguments)
|
||||
}
|
||||
})
|
||||
currentMessages.push({
|
||||
role: 'function',
|
||||
name: functionCall.name,
|
||||
content: JSON.stringify(result.output)
|
||||
})
|
||||
|
||||
// Prepare the next request
|
||||
const nextPayload = provider.transformRequest({
|
||||
...request,
|
||||
messages: currentMessages
|
||||
}, functions)
|
||||
|
||||
// Make the next request
|
||||
currentResponse = await makeProxyRequest(providerId, nextPayload, request.apiKey)
|
||||
iterationCount++
|
||||
}
|
||||
|
||||
if (iterationCount >= MAX_ITERATIONS) {
|
||||
console.log('Max iterations reached, breaking loop')
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error('Error executing tool:', error)
|
||||
throw error
|
||||
}
|
||||
|
||||
// Provider-specific payload adjustments
|
||||
let payload
|
||||
switch (providerId) {
|
||||
case 'openai':
|
||||
payload = {
|
||||
...basePayload,
|
||||
...(functions && {
|
||||
functions,
|
||||
function_call: 'auto'
|
||||
})
|
||||
}
|
||||
break
|
||||
case 'anthropic':
|
||||
payload = {
|
||||
...basePayload,
|
||||
system: request.systemPrompt,
|
||||
messages: request.context ? [{ role: 'user', content: request.context }] : [],
|
||||
...(functions && {
|
||||
tools: functions
|
||||
})
|
||||
}
|
||||
break
|
||||
default:
|
||||
payload = {
|
||||
...basePayload,
|
||||
...(functions && { functions })
|
||||
}
|
||||
return {
|
||||
content,
|
||||
model: currentResponse.model,
|
||||
tokens,
|
||||
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
|
||||
toolResults: toolResults.length > 0 ? toolResults : undefined
|
||||
}
|
||||
}
|
||||
|
||||
// Make the API request through the proxy
|
||||
async function makeProxyRequest(providerId: string, payload: any, apiKey: string) {
|
||||
console.log('Making proxy request for provider:', providerId)
|
||||
|
||||
const response = await fetch('/api/proxy', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -75,7 +153,7 @@ export async function executeProviderRequest(
|
||||
toolId: `${providerId}/chat`,
|
||||
params: {
|
||||
...payload,
|
||||
apiKey: request.apiKey
|
||||
apiKey
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -85,131 +163,8 @@ export async function executeProviderRequest(
|
||||
throw new Error(error.error || 'Provider API error')
|
||||
}
|
||||
|
||||
const { output: data } = await response.json()
|
||||
|
||||
// Extract content and tokens based on provider
|
||||
let content = ''
|
||||
let tokens = undefined
|
||||
|
||||
switch (providerId) {
|
||||
case 'anthropic':
|
||||
content = data.content?.[0]?.text || ''
|
||||
tokens = {
|
||||
prompt: data.usage?.input_tokens,
|
||||
completion: data.usage?.output_tokens,
|
||||
total: data.usage?.input_tokens + data.usage?.output_tokens
|
||||
}
|
||||
break
|
||||
default:
|
||||
content = data.choices?.[0]?.message?.content || ''
|
||||
tokens = data.usage && {
|
||||
prompt: data.usage.prompt_tokens,
|
||||
completion: data.usage.completion_tokens,
|
||||
total: data.usage.total_tokens
|
||||
}
|
||||
}
|
||||
|
||||
// Check for function calls
|
||||
let toolCalls = []
|
||||
let toolResults = []
|
||||
let currentMessages = [...basePayload.messages]
|
||||
|
||||
try {
|
||||
let currentResponse = data
|
||||
let hasMoreCalls = true
|
||||
|
||||
while (hasMoreCalls) {
|
||||
const hasFunctionCall =
|
||||
(providerId === 'openai' && currentResponse.choices?.[0]?.message?.function_call) ||
|
||||
(providerId === 'anthropic' && currentResponse.content?.some((item: any) => item.type === 'function_call'))
|
||||
|
||||
if (!hasFunctionCall) {
|
||||
// No more function calls, use the content from the current response
|
||||
content = currentResponse.choices?.[0]?.message?.content || ''
|
||||
hasMoreCalls = false
|
||||
continue
|
||||
}
|
||||
|
||||
const functionCall = provider.transformFunctionCallResponse(currentResponse, request.tools)
|
||||
if (!functionCall) {
|
||||
hasMoreCalls = false
|
||||
continue
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
const tool = getTool(functionCall.name)
|
||||
if (!tool) {
|
||||
throw new Error(`Tool not found: ${functionCall.name}`)
|
||||
}
|
||||
|
||||
const result = await executeTool(functionCall.name, functionCall.arguments)
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
toolCalls.push(functionCall)
|
||||
|
||||
// Add the assistant's function call and the function result to the message history
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
function_call: {
|
||||
name: functionCall.name,
|
||||
arguments: JSON.stringify(functionCall.arguments)
|
||||
}
|
||||
})
|
||||
currentMessages.push({
|
||||
role: 'function',
|
||||
name: functionCall.name,
|
||||
content: JSON.stringify(result.output)
|
||||
})
|
||||
|
||||
// Make the next call through the proxy
|
||||
const nextResponse = await fetch('/api/proxy', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
toolId: `${providerId}/chat`,
|
||||
params: {
|
||||
...basePayload,
|
||||
messages: currentMessages,
|
||||
...(functions && { functions, function_call: 'auto' }),
|
||||
apiKey: request.apiKey
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
if (!nextResponse.ok) {
|
||||
const error = await nextResponse.json()
|
||||
throw new Error(error.error || 'Provider API error')
|
||||
}
|
||||
|
||||
const { output: nextData } = await nextResponse.json()
|
||||
currentResponse = nextData
|
||||
|
||||
// Update tokens
|
||||
if (nextData.usage) {
|
||||
tokens = {
|
||||
prompt: (tokens?.prompt || 0) + nextData.usage.prompt_tokens,
|
||||
completion: (tokens?.completion || 0) + nextData.usage.completion_tokens,
|
||||
total: (tokens?.total || 0) + nextData.usage.total_tokens
|
||||
}
|
||||
}
|
||||
} else {
|
||||
hasMoreCalls = false
|
||||
}
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error('Error executing tool:', error)
|
||||
throw error
|
||||
}
|
||||
|
||||
return {
|
||||
content,
|
||||
model: data.model,
|
||||
tokens,
|
||||
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
|
||||
toolResults: toolResults.length > 0 ? toolResults : undefined
|
||||
}
|
||||
const { output } = await response.json()
|
||||
console.log('Proxy request completed')
|
||||
return output
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,16 @@
|
||||
import { ToolConfig } from '@/tools/types'
|
||||
|
||||
export interface TokenInfo {
|
||||
prompt?: number
|
||||
completion?: number
|
||||
total?: number
|
||||
}
|
||||
|
||||
export interface TransformedResponse {
|
||||
content: string
|
||||
tokens?: TokenInfo
|
||||
}
|
||||
|
||||
export interface ProviderConfig {
|
||||
id: string
|
||||
name: string
|
||||
@@ -15,6 +26,13 @@ export interface ProviderConfig {
|
||||
// Tool calling support
|
||||
transformToolsToFunctions: (tools: ProviderToolConfig[]) => any
|
||||
transformFunctionCallResponse: (response: any, tools?: ProviderToolConfig[]) => FunctionCallResponse
|
||||
|
||||
// Provider-specific request/response transformations
|
||||
transformRequest: (request: ProviderRequest, functions?: any) => any
|
||||
transformResponse: (response: any) => TransformedResponse
|
||||
|
||||
// Function to check if response contains a function call
|
||||
hasFunctionCall: (response: any) => boolean
|
||||
|
||||
// Internal state for tool name mapping
|
||||
_toolNameMapping?: Map<string, string>
|
||||
|
||||
Reference in New Issue
Block a user