Generalized provider implementation so adding other providers will be easier

Waleed Latif
2025-02-04 15:31:47 -08:00
parent 47b4984376
commit 832a35354b
3 changed files with 177 additions and 185 deletions
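In short: every provider now implements the same set of transformation hooks, and the request executor becomes provider-agnostic. Condensed from the type changes in the last file below (the ProviderHooks name is shorthand for this summary only; the real interface is ProviderConfig):

// Summary of the new per-provider contract; types come from './types'.
interface ProviderHooks {
  transformToolsToFunctions: (tools: ProviderToolConfig[]) => any
  transformFunctionCallResponse: (response: any, tools?: ProviderToolConfig[]) => FunctionCallResponse
  transformRequest: (request: ProviderRequest, functions?: any) => any
  transformResponse: (response: any) => TransformedResponse
  hasFunctionCall: (response: any) => boolean
}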


@@ -1,4 +1,4 @@
-import { ProviderConfig, FunctionCallResponse, ProviderToolConfig } from '../types'
+import { ProviderConfig, FunctionCallResponse, ProviderToolConfig, ProviderRequest } from '../types'
 import { ToolConfig } from '@/tools/types'
 export const openaiProvider: ProviderConfig = {
@@ -22,17 +22,8 @@ export const openaiProvider: ProviderConfig = {
     return tools.map(tool => ({
       name: tool.id,
-      description: tool.description || '',
-      parameters: {
-        ...tool.parameters,
-        properties: Object.entries(tool.parameters.properties).reduce((acc, [key, value]) => ({
-          ...acc,
-          [key]: {
-            ...value,
-            ...(key in tool.params && { default: tool.params[key] })
-          }
-        }), {})
-      }
+      description: tool.description,
+      parameters: tool.parameters
     }))
   },
@@ -42,19 +33,47 @@ export const openaiProvider: ProviderConfig = {
       throw new Error('No function call found in response')
     }
-    const args = typeof functionCall.arguments === 'string'
-      ? JSON.parse(functionCall.arguments)
-      : functionCall.arguments
     const tool = tools?.find(t => t.id === functionCall.name)
     const toolParams = tool?.params || {}
     return {
       name: functionCall.name,
       arguments: {
-        ...toolParams, // First spread the stored params to ensure they're used as defaults
-        ...args // Then spread any overrides from the function call
+        ...toolParams,
+        ...JSON.parse(functionCall.arguments)
       }
     }
   },
+  transformRequest: (request: ProviderRequest, functions?: any) => {
+    return {
+      model: request.model || 'gpt-4o',
+      messages: [
+        { role: 'system', content: request.systemPrompt },
+        ...(request.context ? [{ role: 'user', content: request.context }] : []),
+        ...(request.messages || [])
+      ],
+      temperature: request.temperature,
+      max_tokens: request.maxTokens,
+      ...(functions && {
+        functions,
+        function_call: 'auto'
+      })
+    }
+  },
+  transformResponse: (response: any) => {
+    return {
+      content: response.choices?.[0]?.message?.content || '',
+      tokens: response.usage && {
+        prompt: response.usage.prompt_tokens,
+        completion: response.usage.completion_tokens,
+        total: response.usage.total_tokens
+      }
+    }
+  },
+  hasFunctionCall: (response: any) => {
+    return !!response.choices?.[0]?.message?.function_call
+  }
 }
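The payoff is that another provider only has to describe its own wire format. As a rough illustration (not part of this commit), the Anthropic-specific extraction the old executor did inline — content?.[0]?.text and input_tokens/output_tokens, visible in the code removed from the executor diff that follows — could move into that provider's own hooks:

// Hypothetical sketch: the same two hooks for Anthropic's response shape,
// mirroring the per-provider extraction removed from the executor below.
const anthropicResponseHooks = {
  transformResponse: (response: any) => ({
    content: response.content?.[0]?.text || '',
    tokens: response.usage && {
      prompt: response.usage.input_tokens,
      completion: response.usage.output_tokens,
      total: response.usage.input_tokens + response.usage.output_tokens
    }
  }),
  hasFunctionCall: (response: any) =>
    !!response.content?.some((item: any) => item.type === 'function_call')
}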


@@ -1,7 +1,6 @@
-import { ProviderConfig, ProviderRequest, ProviderResponse, Message } from './types'
+import { ProviderConfig, ProviderRequest, ProviderResponse, TokenInfo } from './types'
 import { openaiProvider } from './openai'
 import { anthropicProvider } from './anthropic'
 import { ToolConfig } from '@/tools/types'
 import { getTool, executeTool } from '@/tools'
 // Register providers
@@ -20,52 +19,131 @@ export async function executeProviderRequest(
     throw new Error(`Provider not found: ${providerId}`)
   }
-  // Only transform tools if they are provided and non-empty
+  // Transform tools to provider-specific function format
   const functions = request.tools && request.tools.length > 0
     ? provider.transformToolsToFunctions(request.tools)
     : undefined
-  // Base payload that's common across providers
-  const basePayload = {
-    model: request.model || provider.defaultModel,
-    messages: [
-      { role: 'system' as const, content: request.systemPrompt },
-      ...(request.context ? [{ role: 'user' as const, content: request.context }] : [])
-    ] as Message[],
-    temperature: request.temperature,
-    max_tokens: request.maxTokens
-  }
+  // Transform the request using provider-specific logic
+  const payload = provider.transformRequest(request, functions)
+  // Make the initial API request through the proxy
+  let currentResponse = await makeProxyRequest(providerId, payload, request.apiKey)
+  let content = ''
+  let tokens: TokenInfo | undefined = undefined
+  let toolCalls = []
+  let toolResults = []
+  let currentMessages = [...(request.messages || [])]
+  let iterationCount = 0
+  const MAX_ITERATIONS = 10 // Prevent infinite loops
+  try {
+    while (iterationCount < MAX_ITERATIONS) {
+      console.log(`Processing iteration ${iterationCount + 1}`)
+      // Transform the response using provider-specific logic
+      const transformedResponse = provider.transformResponse(currentResponse)
+      content = transformedResponse.content
+      // Update tokens
+      if (transformedResponse.tokens) {
+        const newTokens: TokenInfo = {
+          prompt: (tokens?.prompt ?? 0) + (transformedResponse.tokens?.prompt ?? 0),
+          completion: (tokens?.completion ?? 0) + (transformedResponse.tokens?.completion ?? 0),
+          total: (tokens?.total ?? 0) + (transformedResponse.tokens?.total ?? 0)
+        }
+        tokens = newTokens
+      }
+      // Check for function calls using provider-specific logic
+      const hasFunctionCall = provider.hasFunctionCall(currentResponse)
+      console.log('Has function call:', hasFunctionCall)
+      if (!hasFunctionCall) {
+        console.log('No function call detected, breaking loop')
+        break
+      }
+      // Transform function call using provider-specific logic
+      let functionCall
+      try {
+        functionCall = provider.transformFunctionCallResponse(currentResponse, request.tools)
+      } catch (error) {
+        console.log('Error transforming function call:', error)
+        break
+      }
+      if (!functionCall) {
+        console.log('No function call after transformation, breaking loop')
+        break
+      }
+      console.log('Function call:', functionCall.name)
+      // Execute the tool
+      const tool = getTool(functionCall.name)
+      if (!tool) {
+        console.log(`Tool not found: ${functionCall.name}`)
+        break
+      }
+      const result = await executeTool(functionCall.name, functionCall.arguments)
+      console.log('Tool execution result:', result.success)
+      if (!result.success) {
+        console.log('Tool execution failed')
+        break
+      }
+      toolResults.push(result.output)
+      toolCalls.push(functionCall)
+      // Add the function call and result to messages
+      currentMessages.push({
+        role: 'assistant',
+        content: null,
+        function_call: {
+          name: functionCall.name,
+          arguments: JSON.stringify(functionCall.arguments)
+        }
+      })
+      currentMessages.push({
+        role: 'function',
+        name: functionCall.name,
+        content: JSON.stringify(result.output)
+      })
+      // Prepare the next request
+      const nextPayload = provider.transformRequest({
+        ...request,
+        messages: currentMessages
+      }, functions)
+      // Make the next request
+      currentResponse = await makeProxyRequest(providerId, nextPayload, request.apiKey)
+      iterationCount++
+    }
+    if (iterationCount >= MAX_ITERATIONS) {
+      console.log('Max iterations reached, breaking loop')
+    }
+  } catch (error: any) {
+    console.error('Error executing tool:', error)
+    throw error
+  }
-  // Provider-specific payload adjustments
-  let payload
-  switch (providerId) {
-    case 'openai':
-      payload = {
-        ...basePayload,
-        ...(functions && {
-          functions,
-          function_call: 'auto'
-        })
-      }
-      break
-    case 'anthropic':
-      payload = {
-        ...basePayload,
-        system: request.systemPrompt,
-        messages: request.context ? [{ role: 'user', content: request.context }] : [],
-        ...(functions && {
-          tools: functions
-        })
-      }
-      break
-    default:
-      payload = {
-        ...basePayload,
-        ...(functions && { functions })
-      }
-  }
+  return {
+    content,
+    model: currentResponse.model,
+    tokens,
+    toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
+    toolResults: toolResults.length > 0 ? toolResults : undefined
+  }
 }
+// Make the API request through the proxy
+async function makeProxyRequest(providerId: string, payload: any, apiKey: string) {
+  console.log('Making proxy request for provider:', providerId)
   const response = await fetch('/api/proxy', {
     method: 'POST',
     headers: {
@@ -75,7 +153,7 @@ export async function executeProviderRequest(
       toolId: `${providerId}/chat`,
       params: {
         ...payload,
-        apiKey: request.apiKey
+        apiKey
       }
     })
   })
@@ -85,131 +163,8 @@ export async function executeProviderRequest(
     throw new Error(error.error || 'Provider API error')
   }
-  const { output: data } = await response.json()
-  // Extract content and tokens based on provider
-  let content = ''
-  let tokens = undefined
-  switch (providerId) {
-    case 'anthropic':
-      content = data.content?.[0]?.text || ''
-      tokens = {
-        prompt: data.usage?.input_tokens,
-        completion: data.usage?.output_tokens,
-        total: data.usage?.input_tokens + data.usage?.output_tokens
-      }
-      break
-    default:
-      content = data.choices?.[0]?.message?.content || ''
-      tokens = data.usage && {
-        prompt: data.usage.prompt_tokens,
-        completion: data.usage.completion_tokens,
-        total: data.usage.total_tokens
-      }
-  }
-  // Check for function calls
-  let toolCalls = []
-  let toolResults = []
-  let currentMessages = [...basePayload.messages]
-  try {
-    let currentResponse = data
-    let hasMoreCalls = true
-    while (hasMoreCalls) {
-      const hasFunctionCall =
-        (providerId === 'openai' && currentResponse.choices?.[0]?.message?.function_call) ||
-        (providerId === 'anthropic' && currentResponse.content?.some((item: any) => item.type === 'function_call'))
-      if (!hasFunctionCall) {
-        // No more function calls, use the content from the current response
-        content = currentResponse.choices?.[0]?.message?.content || ''
-        hasMoreCalls = false
-        continue
-      }
-      const functionCall = provider.transformFunctionCallResponse(currentResponse, request.tools)
-      if (!functionCall) {
-        hasMoreCalls = false
-        continue
-      }
-      // Execute the tool
-      const tool = getTool(functionCall.name)
-      if (!tool) {
-        throw new Error(`Tool not found: ${functionCall.name}`)
-      }
-      const result = await executeTool(functionCall.name, functionCall.arguments)
-      if (result.success) {
-        toolResults.push(result.output)
-        toolCalls.push(functionCall)
-        // Add the assistant's function call and the function result to the message history
-        currentMessages.push({
-          role: 'assistant',
-          content: null,
-          function_call: {
-            name: functionCall.name,
-            arguments: JSON.stringify(functionCall.arguments)
-          }
-        })
-        currentMessages.push({
-          role: 'function',
-          name: functionCall.name,
-          content: JSON.stringify(result.output)
-        })
-        // Make the next call through the proxy
-        const nextResponse = await fetch('/api/proxy', {
-          method: 'POST',
-          headers: {
-            'Content-Type': 'application/json'
-          },
-          body: JSON.stringify({
-            toolId: `${providerId}/chat`,
-            params: {
-              ...basePayload,
-              messages: currentMessages,
-              ...(functions && { functions, function_call: 'auto' }),
-              apiKey: request.apiKey
-            }
-          })
-        })
-        if (!nextResponse.ok) {
-          const error = await nextResponse.json()
-          throw new Error(error.error || 'Provider API error')
-        }
-        const { output: nextData } = await nextResponse.json()
-        currentResponse = nextData
-        // Update tokens
-        if (nextData.usage) {
-          tokens = {
-            prompt: (tokens?.prompt || 0) + nextData.usage.prompt_tokens,
-            completion: (tokens?.completion || 0) + nextData.usage.completion_tokens,
-            total: (tokens?.total || 0) + nextData.usage.total_tokens
-          }
-        }
-      } else {
-        hasMoreCalls = false
-      }
-    }
-  } catch (error: any) {
-    console.error('Error executing tool:', error)
-    throw error
-  }
-  return {
-    content,
-    model: data.model,
-    tokens,
-    toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
-    toolResults: toolResults.length > 0 ? toolResults : undefined
-  }
+  const { output } = await response.json()
+  console.log('Proxy request completed')
+  return output
 }
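From the caller's side, nothing provider-specific leaks out. A usage sketch (assuming the function takes (providerId, request), as its body suggests; selectedTools is a placeholder for real ProviderToolConfig entries):

// Hypothetical usage of the generalized executor; the tool-calling loop
// above runs transparently, and executed calls come back on the result.
const result = await executeProviderRequest('openai', {
  model: 'gpt-4o',
  systemPrompt: 'You are a helpful assistant.',
  context: 'Summarize the latest build failure.',
  temperature: 0.2,
  maxTokens: 512,
  tools: selectedTools, // placeholder: ProviderToolConfig[]
  apiKey: process.env.OPENAI_API_KEY!
})
console.log(result.content, result.tokens, result.toolCalls)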


@@ -1,5 +1,16 @@
 import { ToolConfig } from '@/tools/types'
+export interface TokenInfo {
+  prompt?: number
+  completion?: number
+  total?: number
+}
+export interface TransformedResponse {
+  content: string
+  tokens?: TokenInfo
+}
 export interface ProviderConfig {
   id: string
   name: string
@@ -15,6 +26,13 @@ export interface ProviderConfig {
   // Tool calling support
   transformToolsToFunctions: (tools: ProviderToolConfig[]) => any
   transformFunctionCallResponse: (response: any, tools?: ProviderToolConfig[]) => FunctionCallResponse
+  // Provider-specific request/response transformations
+  transformRequest: (request: ProviderRequest, functions?: any) => any
+  transformResponse: (response: any) => TransformedResponse
+  // Function to check if response contains a function call
+  hasFunctionCall: (response: any) => boolean
   // Internal state for tool name mapping
   _toolNameMapping?: Map<string, string>
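Taken together, adding a provider now means writing one object that satisfies this interface. A minimal skeleton (hypothetical: the 'example' naming, wire format, and hook bodies are illustrative, and this abridged field list may omit members ProviderConfig requires):

// Hypothetical skeleton for a future provider; only members visible in
// this diff are shown, and the response/payload shapes are invented.
export const exampleProvider: ProviderConfig = {
  id: 'example',
  name: 'Example',
  defaultModel: 'example-chat-v1',
  transformToolsToFunctions: (tools) =>
    tools.map(tool => ({ name: tool.id, description: tool.description, parameters: tool.parameters })),
  transformFunctionCallResponse: (response, tools) => {
    const call = response.tool_call // assumed response shape
    const toolParams = tools?.find(t => t.id === call.name)?.params || {}
    return { name: call.name, arguments: { ...toolParams, ...JSON.parse(call.arguments) } }
  },
  transformRequest: (request, functions) => ({
    model: request.model || 'example-chat-v1',
    system: request.systemPrompt,
    messages: request.messages || [],
    ...(functions && { tools: functions })
  }),
  transformResponse: (response) => ({
    content: response.output_text || '',
    tokens: response.usage && {
      prompt: response.usage.prompt,
      completion: response.usage.completion,
      total: response.usage.total
    }
  }),
  hasFunctionCall: (response) => !!response.tool_call
}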