From 832a35354bd479776dd49f08575397ca58d2d550 Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Tue, 4 Feb 2025 15:31:47 -0800 Subject: [PATCH] Generalized provider implementation so adding other providers will be easier --- providers/openai/index.ts | 55 +++++--- providers/service.ts | 289 ++++++++++++++++---------------------- providers/types.ts | 18 +++ 3 files changed, 177 insertions(+), 185 deletions(-) diff --git a/providers/openai/index.ts b/providers/openai/index.ts index 4a1cdbdab1..baf76faadb 100644 --- a/providers/openai/index.ts +++ b/providers/openai/index.ts @@ -1,4 +1,4 @@ -import { ProviderConfig, FunctionCallResponse, ProviderToolConfig } from '../types' +import { ProviderConfig, FunctionCallResponse, ProviderToolConfig, ProviderRequest } from '../types' import { ToolConfig } from '@/tools/types' export const openaiProvider: ProviderConfig = { @@ -22,17 +22,8 @@ export const openaiProvider: ProviderConfig = { return tools.map(tool => ({ name: tool.id, - description: tool.description || '', - parameters: { - ...tool.parameters, - properties: Object.entries(tool.parameters.properties).reduce((acc, [key, value]) => ({ - ...acc, - [key]: { - ...value, - ...(key in tool.params && { default: tool.params[key] }) - } - }), {}) - } + description: tool.description, + parameters: tool.parameters })) }, @@ -42,19 +33,47 @@ export const openaiProvider: ProviderConfig = { throw new Error('No function call found in response') } - const args = typeof functionCall.arguments === 'string' - ? JSON.parse(functionCall.arguments) - : functionCall.arguments - const tool = tools?.find(t => t.id === functionCall.name) const toolParams = tool?.params || {} return { name: functionCall.name, arguments: { - ...toolParams, // First spread the stored params to ensure they're used as defaults - ...args // Then spread any overrides from the function call + ...toolParams, + ...JSON.parse(functionCall.arguments) } } + }, + + transformRequest: (request: ProviderRequest, functions?: any) => { + return { + model: request.model || 'gpt-4o', + messages: [ + { role: 'system', content: request.systemPrompt }, + ...(request.context ? [{ role: 'user', content: request.context }] : []), + ...(request.messages || []) + ], + temperature: request.temperature, + max_tokens: request.maxTokens, + ...(functions && { + functions, + function_call: 'auto' + }) + } + }, + + transformResponse: (response: any) => { + return { + content: response.choices?.[0]?.message?.content || '', + tokens: response.usage && { + prompt: response.usage.prompt_tokens, + completion: response.usage.completion_tokens, + total: response.usage.total_tokens + } + } + }, + + hasFunctionCall: (response: any) => { + return !!response.choices?.[0]?.message?.function_call } } diff --git a/providers/service.ts b/providers/service.ts index 5559502f73..678f0090c8 100644 --- a/providers/service.ts +++ b/providers/service.ts @@ -1,7 +1,6 @@ -import { ProviderConfig, ProviderRequest, ProviderResponse, Message } from './types' +import { ProviderConfig, ProviderRequest, ProviderResponse, TokenInfo } from './types' import { openaiProvider } from './openai' import { anthropicProvider } from './anthropic' -import { ToolConfig } from '@/tools/types' import { getTool, executeTool } from '@/tools' // Register providers @@ -20,52 +19,131 @@ export async function executeProviderRequest( throw new Error(`Provider not found: ${providerId}`) } - // Only transform tools if they are provided and non-empty + // Transform tools to provider-specific function format const functions = request.tools && request.tools.length > 0 ? provider.transformToolsToFunctions(request.tools) : undefined - // Base payload that's common across providers - const basePayload = { - model: request.model || provider.defaultModel, - messages: [ - { role: 'system' as const, content: request.systemPrompt }, - ...(request.context ? [{ role: 'user' as const, content: request.context }] : []) - ] as Message[], - temperature: request.temperature, - max_tokens: request.maxTokens + // Transform the request using provider-specific logic + const payload = provider.transformRequest(request, functions) + + // Make the initial API request through the proxy + let currentResponse = await makeProxyRequest(providerId, payload, request.apiKey) + let content = '' + let tokens: TokenInfo | undefined = undefined + let toolCalls = [] + let toolResults = [] + let currentMessages = [...(request.messages || [])] + let iterationCount = 0 + const MAX_ITERATIONS = 10 // Prevent infinite loops + + try { + while (iterationCount < MAX_ITERATIONS) { + console.log(`Processing iteration ${iterationCount + 1}`) + + // Transform the response using provider-specific logic + const transformedResponse = provider.transformResponse(currentResponse) + content = transformedResponse.content + + // Update tokens + if (transformedResponse.tokens) { + const newTokens: TokenInfo = { + prompt: (tokens?.prompt ?? 0) + (transformedResponse.tokens?.prompt ?? 0), + completion: (tokens?.completion ?? 0) + (transformedResponse.tokens?.completion ?? 0), + total: (tokens?.total ?? 0) + (transformedResponse.tokens?.total ?? 0) + } + tokens = newTokens + } + + // Check for function calls using provider-specific logic + const hasFunctionCall = provider.hasFunctionCall(currentResponse) + console.log('Has function call:', hasFunctionCall) + + if (!hasFunctionCall) { + console.log('No function call detected, breaking loop') + break + } + + // Transform function call using provider-specific logic + let functionCall + try { + functionCall = provider.transformFunctionCallResponse(currentResponse, request.tools) + } catch (error) { + console.log('Error transforming function call:', error) + break + } + + if (!functionCall) { + console.log('No function call after transformation, breaking loop') + break + } + + console.log('Function call:', functionCall.name) + + // Execute the tool + const tool = getTool(functionCall.name) + if (!tool) { + console.log(`Tool not found: ${functionCall.name}`) + break + } + + const result = await executeTool(functionCall.name, functionCall.arguments) + console.log('Tool execution result:', result.success) + + if (!result.success) { + console.log('Tool execution failed') + break + } + + toolResults.push(result.output) + toolCalls.push(functionCall) + + // Add the function call and result to messages + currentMessages.push({ + role: 'assistant', + content: null, + function_call: { + name: functionCall.name, + arguments: JSON.stringify(functionCall.arguments) + } + }) + currentMessages.push({ + role: 'function', + name: functionCall.name, + content: JSON.stringify(result.output) + }) + + // Prepare the next request + const nextPayload = provider.transformRequest({ + ...request, + messages: currentMessages + }, functions) + + // Make the next request + currentResponse = await makeProxyRequest(providerId, nextPayload, request.apiKey) + iterationCount++ + } + + if (iterationCount >= MAX_ITERATIONS) { + console.log('Max iterations reached, breaking loop') + } + } catch (error: any) { + console.error('Error executing tool:', error) + throw error } - // Provider-specific payload adjustments - let payload - switch (providerId) { - case 'openai': - payload = { - ...basePayload, - ...(functions && { - functions, - function_call: 'auto' - }) - } - break - case 'anthropic': - payload = { - ...basePayload, - system: request.systemPrompt, - messages: request.context ? [{ role: 'user', content: request.context }] : [], - ...(functions && { - tools: functions - }) - } - break - default: - payload = { - ...basePayload, - ...(functions && { functions }) - } + return { + content, + model: currentResponse.model, + tokens, + toolCalls: toolCalls.length > 0 ? toolCalls : undefined, + toolResults: toolResults.length > 0 ? toolResults : undefined } +} - // Make the API request through the proxy +async function makeProxyRequest(providerId: string, payload: any, apiKey: string) { + console.log('Making proxy request for provider:', providerId) + const response = await fetch('/api/proxy', { method: 'POST', headers: { @@ -75,7 +153,7 @@ export async function executeProviderRequest( toolId: `${providerId}/chat`, params: { ...payload, - apiKey: request.apiKey + apiKey } }) }) @@ -85,131 +163,8 @@ export async function executeProviderRequest( throw new Error(error.error || 'Provider API error') } - const { output: data } = await response.json() - - // Extract content and tokens based on provider - let content = '' - let tokens = undefined - - switch (providerId) { - case 'anthropic': - content = data.content?.[0]?.text || '' - tokens = { - prompt: data.usage?.input_tokens, - completion: data.usage?.output_tokens, - total: data.usage?.input_tokens + data.usage?.output_tokens - } - break - default: - content = data.choices?.[0]?.message?.content || '' - tokens = data.usage && { - prompt: data.usage.prompt_tokens, - completion: data.usage.completion_tokens, - total: data.usage.total_tokens - } - } - - // Check for function calls - let toolCalls = [] - let toolResults = [] - let currentMessages = [...basePayload.messages] - - try { - let currentResponse = data - let hasMoreCalls = true - - while (hasMoreCalls) { - const hasFunctionCall = - (providerId === 'openai' && currentResponse.choices?.[0]?.message?.function_call) || - (providerId === 'anthropic' && currentResponse.content?.some((item: any) => item.type === 'function_call')) - - if (!hasFunctionCall) { - // No more function calls, use the content from the current response - content = currentResponse.choices?.[0]?.message?.content || '' - hasMoreCalls = false - continue - } - - const functionCall = provider.transformFunctionCallResponse(currentResponse, request.tools) - if (!functionCall) { - hasMoreCalls = false - continue - } - - // Execute the tool - const tool = getTool(functionCall.name) - if (!tool) { - throw new Error(`Tool not found: ${functionCall.name}`) - } - - const result = await executeTool(functionCall.name, functionCall.arguments) - if (result.success) { - toolResults.push(result.output) - toolCalls.push(functionCall) - - // Add the assistant's function call and the function result to the message history - currentMessages.push({ - role: 'assistant', - content: null, - function_call: { - name: functionCall.name, - arguments: JSON.stringify(functionCall.arguments) - } - }) - currentMessages.push({ - role: 'function', - name: functionCall.name, - content: JSON.stringify(result.output) - }) - - // Make the next call through the proxy - const nextResponse = await fetch('/api/proxy', { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ - toolId: `${providerId}/chat`, - params: { - ...basePayload, - messages: currentMessages, - ...(functions && { functions, function_call: 'auto' }), - apiKey: request.apiKey - } - }) - }) - - if (!nextResponse.ok) { - const error = await nextResponse.json() - throw new Error(error.error || 'Provider API error') - } - - const { output: nextData } = await nextResponse.json() - currentResponse = nextData - - // Update tokens - if (nextData.usage) { - tokens = { - prompt: (tokens?.prompt || 0) + nextData.usage.prompt_tokens, - completion: (tokens?.completion || 0) + nextData.usage.completion_tokens, - total: (tokens?.total || 0) + nextData.usage.total_tokens - } - } - } else { - hasMoreCalls = false - } - } - } catch (error: any) { - console.error('Error executing tool:', error) - throw error - } - - return { - content, - model: data.model, - tokens, - toolCalls: toolCalls.length > 0 ? toolCalls : undefined, - toolResults: toolResults.length > 0 ? toolResults : undefined - } + const { output } = await response.json() + console.log('Proxy request completed') + return output } diff --git a/providers/types.ts b/providers/types.ts index 849b98a8b4..6483a53398 100644 --- a/providers/types.ts +++ b/providers/types.ts @@ -1,5 +1,16 @@ import { ToolConfig } from '@/tools/types' +export interface TokenInfo { + prompt?: number + completion?: number + total?: number +} + +export interface TransformedResponse { + content: string + tokens?: TokenInfo +} + export interface ProviderConfig { id: string name: string @@ -15,6 +26,13 @@ export interface ProviderConfig { // Tool calling support transformToolsToFunctions: (tools: ProviderToolConfig[]) => any transformFunctionCallResponse: (response: any, tools?: ProviderToolConfig[]) => FunctionCallResponse + + // Provider-specific request/response transformations + transformRequest: (request: ProviderRequest, functions?: any) => any + transformResponse: (response: any) => TransformedResponse + + // Function to check if response contains a function call + hasFunctionCall: (response: any) => boolean // Internal state for tool name mapping _toolNameMapping?: Map