feat(providers): add OpenAI SDK in place of custom business logic

Waleed Latif
2025-02-26 19:18:23 -08:00
parent fa67494fe7
commit a08efa3d91
7 changed files with 443 additions and 99 deletions

package-lock.json generated
View File

@@ -34,7 +34,7 @@
"lodash.debounce": "^4.0.8",
"lucide-react": "^0.469.0",
"next": "15.1.3",
"openai": "^4.83.0",
"openai": "^4.85.4",
"postgres": "^3.4.5",
"prismjs": "^1.29.0",
"react": "^18.2.0",
@@ -8653,9 +8653,9 @@
}
},
"node_modules/openai": {
"version": "4.83.0",
"resolved": "https://registry.npmjs.org/openai/-/openai-4.83.0.tgz",
"integrity": "sha512-fmTsqud0uTtRKsPC7L8Lu55dkaTwYucqncDHzVvO64DKOpNTuiYwjbR/nVgpapXuYy8xSnhQQPUm+3jQaxICgw==",
"version": "4.85.4",
"resolved": "https://registry.npmjs.org/openai/-/openai-4.85.4.tgz",
"integrity": "sha512-Nki51PBSu+Aryo7WKbdXvfm0X/iKkQS2fq3O0Uqb/O3b4exOZFid2te1BZ52bbO5UwxQZ5eeHJDCTqtrJLPw0w==",
"license": "Apache-2.0",
"dependencies": {
"@types/node": "^18.11.18",

View File

@@ -43,7 +43,7 @@
"lodash.debounce": "^4.0.8",
"lucide-react": "^0.469.0",
"next": "15.1.3",
"openai": "^4.83.0",
"openai": "^4.85.4",
"postgres": "^3.4.5",
"prismjs": "^1.29.0",
"react": "^18.2.0",

View File

@@ -1,18 +1,131 @@
import { ToolConfig } from '@/tools/types'
import { FunctionCallResponse, ProviderConfig, ProviderRequest, ProviderToolConfig } from '../types'
import { Cerebras } from '@cerebras/cerebras_cloud_sdk'
import {
FunctionCallResponse,
ProviderConfig,
ProviderRequest,
ProviderResponse,
ProviderToolConfig,
} from '../types'
export const cerebrasProvider: ProviderConfig = {
id: 'cerebras',
name: 'Cerebras',
description: "Cerebras' Llama models",
description: 'Cerebras Cloud LLMs',
version: '1.0.0',
models: ['llama-3.3-70b'],
defaultModel: 'llama-3.3-70b',
implementationType: 'sdk',
// Since we're using the SDK directly, we'll set these to empty values
// They won't be used since we'll handle the execution locally
baseUrl: '',
headers: (apiKey: string) => ({}),
// SDK-based implementation
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
try {
const client = new Cerebras({
apiKey: request.apiKey,
})
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Transform tools to Cerebras format if provided
const functions = request.tools?.length
? request.tools.map((tool) => ({
name: tool.id,
description: tool.description,
parameters: tool.parameters,
}))
: undefined
// Build the request payload
const payload: any = {
model: request.model || 'llama-3.3-70b',
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add functions if provided
if (functions?.length) {
payload.functions = functions
payload.function_call = 'auto'
}
// Add local execution flag if specified
if (request.local_execution) {
payload.local_execution = true
}
// Execute the request using the SDK
const response = (await client.chat.completions.create(payload)) as any
// Extract content and token info
const content = response.choices?.[0]?.message?.content || ''
const tokens = {
prompt: response.usage?.prompt_tokens || 0,
completion: response.usage?.completion_tokens || 0,
total: response.usage?.total_tokens || 0,
}
// Check for function calls
const functionCall = response.choices?.[0]?.message?.function_call
let toolCalls = undefined
if (functionCall) {
const tool = request.tools?.find((t) => t.id === functionCall.name)
const toolParams = tool?.params || {}
toolCalls = [
{
name: functionCall.name,
arguments: {
...toolParams,
...JSON.parse(functionCall.arguments),
},
},
]
}
// Return the response in the expected format
return {
content,
model: request.model,
tokens,
toolCalls,
}
} catch (error: any) {
console.error('Error executing Cerebras request:', error)
throw new Error(`Cerebras API error: ${error.message}`)
}
},
// These are still needed for backward compatibility
baseUrl: 'https://api.cerebras.cloud/v1/chat/completions',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
}),
transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
@@ -76,10 +189,9 @@ export const cerebrasProvider: ProviderConfig = {
const payload: any = {
model: request.model || 'llama-3.3-70b',
messages: allMessages,
local_execution: true, // Enable local execution with the SDK
}
// Add standard parameters
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
@@ -89,6 +201,11 @@ export const cerebrasProvider: ProviderConfig = {
payload.function_call = 'auto'
}
// Add local execution flag if specified
if (request.local_execution) {
payload.local_execution = true
}
return payload
},
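
As a usage reference, a minimal sketch of calling the new SDK-backed executeRequest directly. The '@/providers/cerebras' import path is an assumption about this repo's layout; the request fields mirror ProviderRequest as it appears in this diff.

// Hypothetical call site for the SDK-based Cerebras provider.
import { cerebrasProvider } from '@/providers/cerebras' // path assumed

async function runCerebrasExample() {
  const response = await cerebrasProvider.executeRequest!({
    apiKey: process.env.CEREBRAS_API_KEY!, // per-request key, as the provider expects
    model: 'llama-3.3-70b', // matches the provider's defaultModel
    systemPrompt: 'You are a concise assistant.',
    messages: [{ role: 'user', content: 'Summarize this commit in one line.' }],
    temperature: 0.3,
    maxTokens: 256,
  })
  console.log(response.content, response.tokens)
}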

View File

@@ -1,56 +0,0 @@
import { Cerebras } from '@cerebras/cerebras_cloud_sdk'
import { ProviderRequest, ProviderResponse } from '../types'
import { cerebrasProvider } from './index'
// This function will be used to execute Cerebras requests locally using their SDK
export async function executeCerebrasRequest(request: ProviderRequest): Promise<ProviderResponse> {
try {
const client = new Cerebras({
apiKey: request.apiKey,
})
// Transform the request using the provider's transformRequest method
const payload = cerebrasProvider.transformRequest(request)
// Prepare the messages for the SDK
const messages = payload.messages
// Prepare the options for the SDK
const options = {
temperature: payload.temperature,
max_tokens: payload.max_tokens,
functions: payload.functions,
function_call: payload.function_call,
}
// Execute the request using the SDK
const response = await client.chat.completions.create({
model: payload.model,
messages,
...options,
})
// Transform the response using the provider's transformResponse method
const transformedResponse = cerebrasProvider.transformResponse(response)
// Check for function calls
const hasFunctionCall = cerebrasProvider.hasFunctionCall(response)
let toolCalls = undefined
if (hasFunctionCall) {
const functionCall = cerebrasProvider.transformFunctionCallResponse(response, request.tools)
toolCalls = [functionCall]
}
// Return the response in the expected format
return {
content: transformedResponse.content,
model: request.model,
tokens: transformedResponse.tokens,
toolCalls,
}
} catch (error: any) {
console.error('Error executing Cerebras request:', error)
throw new Error(`Cerebras API error: ${error.message}`)
}
}

View File

@@ -1,5 +1,11 @@
import { ToolConfig } from '@/tools/types'
import { FunctionCallResponse, ProviderConfig, ProviderRequest, ProviderToolConfig } from '../types'
import OpenAI from 'openai'
import {
FunctionCallResponse,
ProviderConfig,
ProviderRequest,
ProviderResponse,
ProviderToolConfig,
} from '../types'
export const openaiProvider: ProviderConfig = {
id: 'openai',
@@ -8,7 +14,184 @@ export const openaiProvider: ProviderConfig = {
version: '1.0.0',
models: ['gpt-4o', 'o1', 'o3-mini'],
defaultModel: 'gpt-4o',
implementationType: 'sdk',
// SDK-based implementation
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
if (!request.apiKey) {
throw new Error('API key is required for OpenAI')
}
const openai = new OpenAI({
apiKey: request.apiKey,
dangerouslyAllowBrowser: true,
})
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Transform tools to OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
: undefined
// Build the request payload
const payload: any = {
model: request.model || 'gpt-4o',
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add response format for structured output if specified
if (request.responseFormat) {
payload.response_format = { type: 'json_object' }
}
// Add tools if provided
if (tools?.length) {
payload.tools = tools
payload.tool_choice = 'auto'
}
// Make the initial API request
let currentResponse = await openai.chat.completions.create(payload)
let content = currentResponse.choices[0]?.message?.content || ''
let tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
total: currentResponse.usage?.total_tokens || 0,
}
let toolCalls = []
let toolResults = []
let currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
try {
while (iterationCount < MAX_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
// Process each tool call
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
// Execute the tool (this will need to be imported from your tools system)
const { executeTool } = await import('@/tools')
const result = await executeTool(toolName, toolArgs)
if (!result.success) continue
toolResults.push(result.output)
toolCalls.push({
name: toolName,
arguments: toolArgs,
})
// Add the tool call and result to messages
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(result.output),
})
} catch (error) {
console.error('Error processing tool call:', error)
}
}
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
// Make the next request
currentResponse = await openai.chat.completions.create(nextPayload)
// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
tokens.total += currentResponse.usage.total_tokens || 0
}
iterationCount++
}
} catch (error) {
console.error('Error in OpenAI request:', error)
throw error
}
return {
content,
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},
// These are still needed for backward compatibility
baseUrl: 'https://api.openai.com/v1/chat/completions',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
@@ -21,9 +204,12 @@ export const openaiProvider: ProviderConfig = {
}
return tools.map((tool) => ({
name: tool.id,
description: tool.description,
parameters: tool.parameters,
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
},
@@ -31,6 +217,26 @@ export const openaiProvider: ProviderConfig = {
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
// Handle SDK response format
if (response.choices?.[0]?.message?.tool_calls) {
const toolCall = response.choices[0].message.tool_calls[0]
if (!toolCall) {
throw new Error('No tool call found in response')
}
const tool = tools?.find((t) => t.id === toolCall.function.name)
const toolParams = tool?.params || {}
return {
name: toolCall.function.name,
arguments: {
...toolParams,
...JSON.parse(toolCall.function.arguments),
},
}
}
// Handle legacy function_call format for backward compatibility
const functionCall = response.choices?.[0]?.message?.function_call
if (!functionCall) {
throw new Error('No function call found in response')
@@ -134,9 +340,10 @@ export const openaiProvider: ProviderConfig = {
// Add function calling support (supported by all models)
if (functions) {
payload.functions = functions
payload.function_call = 'auto'
payload.tools = functions
payload.tool_choice = 'auto'
}
return payload
},
@@ -163,6 +370,10 @@ export const openaiProvider: ProviderConfig = {
},
hasFunctionCall: (response: any) => {
return !!response.choices?.[0]?.message?.function_call
return (
!!response.choices?.[0]?.message?.function_call ||
(response.choices?.[0]?.message?.tool_calls &&
response.choices[0].message.tool_calls.length > 0)
)
},
}
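
Because the new loop resolves each tool call by tool.id and executes it through executeTool from @/tools, a caller only supplies tool configs on the request. A hedged sketch follows; the web_search tool is hypothetical and would have to be registered in the tools system for executeTool to find it.

import { openaiProvider } from '@/providers/openai' // path assumed

// Hypothetical tool config; field names follow ProviderToolConfig as used above.
const searchTool = {
  id: 'web_search', // must match a tool registered in the @/tools registry
  description: 'Search the web and return the top results',
  parameters: {
    type: 'object',
    properties: { query: { type: 'string', description: 'Search terms' } },
    required: ['query'],
  },
  params: {}, // static params field from ProviderToolConfig
}

async function runOpenAIExample() {
  const result = await openaiProvider.executeRequest!({
    apiKey: process.env.OPENAI_API_KEY!,
    model: 'gpt-4o',
    systemPrompt: 'Use web_search when you need fresh information.',
    messages: [{ role: 'user', content: 'What changed in the latest openai release?' }],
    tools: [searchTool],
  })
  // Executed calls and their outputs surface on the response; the loop stops
  // after MAX_ITERATIONS (10) rounds even if the model keeps requesting tools.
  console.log(result.content, result.toolCalls, result.toolResults)
}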

View File

@@ -1,5 +1,5 @@
import OpenAI from 'openai'
import { executeTool, getTool } from '@/tools'
import { executeCerebrasRequest } from './cerebras/service'
import { getProvider } from './registry'
import { ProviderRequest, ProviderResponse, TokenInfo } from './types'
import { extractAndParseJSON } from './utils'
@@ -65,11 +65,20 @@ export async function executeProviderRequest(
throw new Error(`Provider not found: ${providerId}`)
}
// Special handling for Cerebras provider which uses SDK directly
if (providerId === 'cerebras') {
return executeCerebrasRequest(request)
// Use SDK-based implementation if available
if (provider.implementationType === 'sdk' && provider.executeRequest) {
return provider.executeRequest(request)
}
// Legacy HTTP-based implementation for other providers
return executeHttpProviderRequest(provider, request)
}
// Legacy HTTP-based implementation
async function executeHttpProviderRequest(
provider: any,
request: ProviderRequest
): Promise<ProviderResponse> {
// If responseFormat is provided, modify the system prompt to enforce structured output
if (request.responseFormat) {
const structuredOutputInstructions = generateStructuredOutputInstructions(
@@ -88,7 +97,7 @@ export async function executeProviderRequest(
const payload = provider.transformRequest(request, functions)
// Make the initial API request through the proxy
let currentResponse = await makeProxyRequest(providerId, payload, request.apiKey)
let currentResponse = await makeProxyRequest(provider.id, payload, request.apiKey)
let content = ''
let tokens: TokenInfo | undefined = undefined
let toolCalls = []
@@ -209,19 +218,66 @@ export async function executeProviderRequest(
toolCalls.push(functionCall)
// Add the function call and result to messages
currentMessages.push({
  role: 'assistant',
  content: null,
  function_call: {
    name: functionCall.name,
    arguments: JSON.stringify(functionCall.arguments),
  },
})
currentMessages.push({
  role: 'function',
  name: functionCall.name,
  content: JSON.stringify(result.output),
})
// Check if we're dealing with the new tool_calls format or the legacy function_call format
const hasToolCalls = currentResponse.choices?.[0]?.message?.tool_calls
if (hasToolCalls) {
  const toolCall = currentResponse.choices[0].message.tool_calls[0]
  if (toolCall && toolCall.id) {
    currentMessages.push({
      role: 'assistant',
      content: null,
      tool_calls: [
        {
          id: toolCall.id,
          type: 'function',
          function: {
            name: functionCall.name,
            arguments: JSON.stringify(functionCall.arguments),
          },
        },
      ],
    })
    currentMessages.push({
      role: 'tool',
      tool_call_id: toolCall.id,
      content: JSON.stringify(result.output),
    })
  } else {
    // Fallback to legacy format if id is missing
    currentMessages.push({
      role: 'assistant',
      content: null,
      function_call: {
        name: functionCall.name,
        arguments: JSON.stringify(functionCall.arguments),
      },
    })
    currentMessages.push({
      role: 'function',
      name: functionCall.name,
      content: JSON.stringify(result.output),
    })
  }
} else {
  // Legacy format
  currentMessages.push({
    role: 'assistant',
    content: null,
    function_call: {
      name: functionCall.name,
      arguments: JSON.stringify(functionCall.arguments),
    },
  })
  currentMessages.push({
    role: 'function',
    name: functionCall.name,
    content: JSON.stringify(result.output),
  })
}
// Prepare the next request
const nextPayload = provider.transformRequest(
@@ -233,7 +289,7 @@ export async function executeProviderRequest(
)
// Make the next request
currentResponse = await makeProxyRequest(providerId, nextPayload, request.apiKey)
currentResponse = await makeProxyRequest(provider.id, nextPayload, request.apiKey)
iterationCount++
}
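
The practical consequence of this change is that adding another SDK-backed provider no longer requires touching the dispatch. A hedged sketch of what such a provider would declare (a stub with only the two fields the dispatcher now checks; Partial<ProviderConfig> keeps the sketch short):

import { ProviderConfig, ProviderRequest, ProviderResponse } from './types' // same types as above

// Hypothetical provider stub — not part of this commit.
const stubProvider: Partial<ProviderConfig> = {
  id: 'stub',
  implementationType: 'sdk', // routes past the HTTP proxy path
  executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => ({
    content: 'stubbed response',
    model: request.model,
    tokens: { prompt: 0, completion: 0, total: 0 },
  }),
}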

View File

@@ -19,9 +19,15 @@ export interface ProviderConfig {
models: string[]
defaultModel: string
// Provider-specific configuration
baseUrl: string
headers: (apiKey: string) => Record<string, string>
// Provider implementation type
implementationType: 'sdk' | 'http'
// For HTTP-based providers
baseUrl?: string
headers?: (apiKey: string) => Record<string, string>
// For SDK-based providers
executeRequest?: (request: ProviderRequest) => Promise<ProviderResponse>
// Tool calling support
transformToolsToFunctions: (tools: ProviderToolConfig[]) => any
@@ -78,6 +84,15 @@ export interface Message {
name: string
arguments: string
}
tool_calls?: Array<{
id: string
type: 'function'
function: {
name: string
arguments: string
}
}>
tool_call_id?: string
}
export interface ProviderRequest {
@@ -96,6 +111,7 @@ export interface ProviderRequest {
description?: string
}>
}
local_execution?: boolean
}
// Map of provider IDs to their configurations
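
Since baseUrl and headers are now optional, downstream code has to narrow a ProviderConfig before taking either path. A small sketch of type guards that would encode the new contract (these helpers are illustrative, not part of the commit; the import path is assumed):

import { ProviderConfig, ProviderRequest, ProviderResponse } from './types' // path assumed

// Narrow to a provider that can take the SDK path.
function isSdkProvider(p: ProviderConfig): p is ProviderConfig & {
  executeRequest: (request: ProviderRequest) => Promise<ProviderResponse>
} {
  return p.implementationType === 'sdk' && typeof p.executeRequest === 'function'
}

// Narrow to a provider that still needs the HTTP proxy path.
function isHttpProvider(p: ProviderConfig): p is ProviderConfig & {
  baseUrl: string
  headers: (apiKey: string) => Record<string, string>
} {
  return p.implementationType === 'http' && typeof p.baseUrl === 'string'
}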