feat(providers): modified all providers to use SDK rather than creating an HTTP request

This commit is contained in:
Waleed Latif
2025-02-27 13:39:31 -08:00
parent a08efa3d91
commit 5c11e9da16
24 changed files with 1254 additions and 1639 deletions

View File

@@ -1,118 +1,23 @@
import { NextResponse } from 'next/server'
import { MODEL_PROVIDERS } from '@/providers/consts'
import { getProvider } from '@/providers/registry'
import { getTool } from '@/tools'
import { executeTool, getTool } from '@/tools'
import { validateToolRequest } from '@/tools/utils'
export async function POST(request: Request) {
try {
const { toolId, params } = await request.json()
// Check if this is a provider chat request
const provider = getProvider(toolId)
if (provider) {
const { apiKey, ...restParams } = params
if (!apiKey) {
throw new Error('API key is required')
}
const response = await fetch(provider.baseUrl, {
method: 'POST',
headers: provider.headers(apiKey),
body: JSON.stringify(restParams),
})
if (!response.ok) {
const error = await response.json()
throw new Error(error.error?.message || `${toolId} API error`)
}
return NextResponse.json({
success: true,
output: await response.json(),
})
}
// Check if this is an LLM provider tool (e.g., openai_chat, anthropic_chat)
const providerPrefix = toolId.split('_')[0]
if (Object.values(MODEL_PROVIDERS).includes(providerPrefix)) {
// Redirect to the provider system
const providerInstance = getProvider(providerPrefix)
if (!providerInstance) {
throw new Error(`Provider not found for tool: ${toolId}`)
}
const { apiKey, ...restParams } = params
if (!apiKey) {
throw new Error('API key is required')
}
const response = await fetch(providerInstance.baseUrl, {
method: 'POST',
headers: providerInstance.headers(apiKey),
body: JSON.stringify(restParams),
})
if (!response.ok) {
const error = await response.json()
throw new Error(error.error?.message || `${toolId} API error`)
}
return NextResponse.json({
success: true,
output: await response.json(),
})
}
// Handle regular tool requests
const tool = getTool(toolId)
if (!tool) {
throw new Error(`Tool not found: ${toolId}`)
}
if (tool.params?.apiKey?.required && !params.apiKey) {
throw new Error(`API key is required for ${toolId}`)
}
const { url: urlOrFn, method: defaultMethod, headers: headersFn, body: bodyFn } = tool.request
// Validate the tool and its parameters
validateToolRequest(toolId, tool, params)
try {
const url = typeof urlOrFn === 'function' ? urlOrFn(params) : urlOrFn
const method = params.method || defaultMethod || 'GET'
const headers = headersFn ? headersFn(params) : {}
const hasBody = method !== 'GET' && method !== 'HEAD' && !!bodyFn
const bodyResult = bodyFn ? bodyFn(params) : undefined
// Special handling for NDJSON content type
const isNDJSON = headers['Content-Type'] === 'application/x-ndjson'
const body = hasBody
? isNDJSON && bodyResult
? bodyResult.body
: JSON.stringify(bodyResult)
: undefined
const externalResponse = await fetch(url, { method, headers, body })
if (!externalResponse.ok) {
const errorContent = await externalResponse.json().catch(() => ({
message: externalResponse.statusText,
}))
// Use the tool's error transformer or a default message
const error = tool.transformError
? tool.transformError(errorContent)
: errorContent.message || `${toolId} API error: ${externalResponse.statusText}`
throw new Error(error)
if (!tool) {
throw new Error(`Tool not found: ${toolId}`)
}
const transformResponse =
tool.transformResponse ||
(async (resp: Response) => ({
success: true,
output: await resp.json(),
}))
const result = await transformResponse(externalResponse)
// Use executeTool with skipProxy=true to prevent recursive proxy calls
const result = await executeTool(toolId, params, true)
if (!result.success) {
throw new Error(

View File

@@ -1,5 +1,5 @@
import { AgentIcon } from '@/components/icons'
import { MODEL_PROVIDERS } from '@/providers/consts'
import { MODEL_PROVIDERS } from '@/providers/utils'
import { ToolResponse } from '@/tools/types'
import { BlockConfig } from '../types'

View File

@@ -1,6 +1,6 @@
import { ChartBarIcon } from '@/components/icons'
import { MODEL_PROVIDERS } from '@/providers/consts'
import { ProviderId } from '@/providers/registry'
import { ProviderId } from '@/providers/types'
import { MODEL_PROVIDERS } from '@/providers/utils'
import { ToolResponse } from '@/tools/types'
import { BlockConfig, ParamType } from '../types'

View File

@@ -1,6 +1,6 @@
import { ConnectIcon } from '@/components/icons'
import { MODEL_PROVIDERS } from '@/providers/consts'
import { ProviderId } from '@/providers/registry'
import { ProviderId } from '@/providers/types'
import { MODEL_PROVIDERS } from '@/providers/utils'
import { ToolResponse } from '@/tools/types'
import { BlockConfig } from '../types'

View File

@@ -1,6 +1,6 @@
import { TranslateIcon } from '@/components/icons'
import { MODEL_PROVIDERS } from '@/providers/consts'
import { ProviderId } from '@/providers/registry'
import { ProviderId } from '@/providers/types'
import { MODEL_PROVIDERS } from '@/providers/utils'
import { BlockConfig } from '../types'
const getTranslationPrompt = (

View File

@@ -1,14 +0,0 @@
// export const MODEL_TOOLS = {
// 'gpt-4o': 'openai_chat',
// o1: 'openai_chat',
// 'o3-mini': 'openai_chat',
// 'deepseek-v3': 'deepseek_chat',
// 'deepseek-r1': 'deepseek_reasoner',
// 'claude-3-7-sonnet-20250219': 'anthropic_chat',
// 'gemini-2.0-flash': 'google_chat',
// 'grok-2-latest': 'xai_chat',
// 'llama-3.3-70b': 'cerebras_chat',
// } as const
// export type ModelType = keyof typeof MODEL_TOOLS
// export type ToolType = (typeof MODEL_TOOLS)[ModelType]

View File

@@ -1,4 +1,4 @@
import { BlockState, SubBlockState } from '@/stores/workflow/types'
import { SubBlockState } from '@/stores/workflow/types'
import { BlockOutput, OutputConfig } from '@/blocks/types'
interface CodeLine {

View File

@@ -1,7 +1,7 @@
import { getAllBlocks } from '@/blocks'
import { generateRouterPrompt } from '@/blocks/blocks/router'
import { BlockOutput } from '@/blocks/types'
import { executeProviderRequest } from '@/providers/service'
import { executeProviderRequest } from '@/providers'
import { getProviderFromModel } from '@/providers/utils'
import { SerializedBlock } from '@/serializer/types'
import { executeTool, getTool } from '@/tools'
@@ -103,9 +103,6 @@ export class AgentBlockHandler implements BlockHandler {
.filter((t): t is NonNullable<typeof t> => t !== null)
: []
// Add local_execution: true for Cerebras provider
const additionalParams = providerId === 'cerebras' ? { local_execution: true } : {}
const response = await executeProviderRequest(providerId, {
model,
systemPrompt: inputs.systemPrompt,
@@ -117,7 +114,6 @@ export class AgentBlockHandler implements BlockHandler {
maxTokens: inputs.maxTokens,
apiKey: inputs.apiKey,
responseFormat,
...additionalParams,
})
// Return structured or standard response based on responseFormat
@@ -447,7 +443,7 @@ export class ApiBlockHandler implements BlockHandler {
throw new Error(`Tool not found: ${block.config.tool}`)
}
const result = await executeTool(block.config.tool, inputs)
const result = await executeTool(block.config.tool, inputs, true)
if (!result.success) {
throw new Error(result.error || `API request failed with no error message`)
}
@@ -474,7 +470,7 @@ export class FunctionBlockHandler implements BlockHandler {
throw new Error(`Tool not found: ${block.config.tool}`)
}
const result = await executeTool(block.config.tool, inputs)
const result = await executeTool(block.config.tool, inputs, true)
if (!result.success) {
throw new Error(result.error || `Function execution failed with no error message`)
}
@@ -502,7 +498,7 @@ export class GenericBlockHandler implements BlockHandler {
throw new Error(`Tool not found: ${block.config.tool}`)
}
const result = await executeTool(block.config.tool, inputs)
const result = await executeTool(block.config.tool, inputs, true)
if (!result.success) {
throw new Error(result.error || `Block execution failed with no error message`)
}

62
package-lock.json generated
View File

@@ -8,7 +8,7 @@
"name": "sim",
"version": "0.1.0",
"dependencies": {
"@cerebras/cerebras_cloud_sdk": "^1.23.0",
"@anthropic-ai/sdk": "^0.38.0",
"@radix-ui/react-alert-dialog": "^1.1.5",
"@radix-ui/react-checkbox": "^1.1.3",
"@radix-ui/react-dialog": "^1.1.5",
@@ -105,6 +105,36 @@
"node": ">=6.0.0"
}
},
"node_modules/@anthropic-ai/sdk": {
"version": "0.38.0",
"resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.38.0.tgz",
"integrity": "sha512-ZUYjadEEeb1wwKd9MEM+plSfl7zQjYixhjHRtyPsjO7MtzRmbZvBb1n1ofo2kHwb0aUiIdb3aAEbAwS9Bcbm/A==",
"license": "MIT",
"dependencies": {
"@types/node": "^18.11.18",
"@types/node-fetch": "^2.6.4",
"abort-controller": "^3.0.0",
"agentkeepalive": "^4.2.1",
"form-data-encoder": "1.7.2",
"formdata-node": "^4.3.2",
"node-fetch": "^2.6.7"
}
},
"node_modules/@anthropic-ai/sdk/node_modules/@types/node": {
"version": "18.19.76",
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.76.tgz",
"integrity": "sha512-yvR7Q9LdPz2vGpmpJX5LolrgRdWvB67MJKDPSgIIzpFbaf9a1j/f5DnLp5VDyHGMR0QZHlTr1afsD87QCXFHKw==",
"license": "MIT",
"dependencies": {
"undici-types": "~5.26.4"
}
},
"node_modules/@anthropic-ai/sdk/node_modules/undici-types": {
"version": "5.26.5",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
"license": "MIT"
},
"node_modules/@babel/code-frame": {
"version": "7.26.2",
"resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz",
@@ -647,36 +677,6 @@
"resolved": "https://registry.npmjs.org/@better-fetch/fetch/-/fetch-1.1.12.tgz",
"integrity": "sha512-B3bfloI/2UBQWIATRN6qmlORrvx3Mp0kkNjmXLv0b+DtbtR+pP4/I5kQA/rDUv+OReLywCCldf6co4LdDmh8JA=="
},
"node_modules/@cerebras/cerebras_cloud_sdk": {
"version": "1.23.0",
"resolved": "https://registry.npmjs.org/@cerebras/cerebras_cloud_sdk/-/cerebras_cloud_sdk-1.23.0.tgz",
"integrity": "sha512-1krbmU4nTbJICUbcJGQGGo+MtB0nzHx/jwW24ZhoBzuC5QT8H/WzNjLdKtvdf3TB8GS1AtdWUkUHNJf1EZfvJA==",
"license": "Apache-2.0",
"dependencies": {
"@types/node": "^18.11.18",
"@types/node-fetch": "^2.6.4",
"abort-controller": "^3.0.0",
"agentkeepalive": "^4.2.1",
"form-data-encoder": "1.7.2",
"formdata-node": "^4.3.2",
"node-fetch": "^2.6.7"
}
},
"node_modules/@cerebras/cerebras_cloud_sdk/node_modules/@types/node": {
"version": "18.19.76",
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.76.tgz",
"integrity": "sha512-yvR7Q9LdPz2vGpmpJX5LolrgRdWvB67MJKDPSgIIzpFbaf9a1j/f5DnLp5VDyHGMR0QZHlTr1afsD87QCXFHKw==",
"license": "MIT",
"dependencies": {
"undici-types": "~5.26.4"
}
},
"node_modules/@cerebras/cerebras_cloud_sdk/node_modules/undici-types": {
"version": "5.26.5",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
"license": "MIT"
},
"node_modules/@drizzle-team/brocli": {
"version": "0.10.2",
"resolved": "https://registry.npmjs.org/@drizzle-team/brocli/-/brocli-0.10.2.tgz",

View File

@@ -17,7 +17,7 @@
"test:coverage": "jest --coverage"
},
"dependencies": {
"@cerebras/cerebras_cloud_sdk": "^1.23.0",
"@anthropic-ai/sdk": "^0.38.0",
"@radix-ui/react-alert-dialog": "^1.1.5",
"@radix-ui/react-checkbox": "^1.1.3",
"@radix-ui/react-dialog": "^1.1.5",

View File

@@ -1,10 +1,6 @@
import {
FunctionCallResponse,
Message,
ProviderConfig,
ProviderRequest,
ProviderToolConfig,
} from '../types'
import Anthropic from '@anthropic-ai/sdk'
import { executeTool } from '@/tools'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'
export const anthropicProvider: ProviderConfig = {
id: 'anthropic',
@@ -14,73 +10,40 @@ export const anthropicProvider: ProviderConfig = {
models: ['claude-3-7-sonnet-20250219'],
defaultModel: 'claude-3-7-sonnet-20250219',
baseUrl: 'https://api.anthropic.com/v1/messages',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
'x-api-key': apiKey,
'anthropic-version': '2023-06-01',
}),
transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
return undefined
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
if (!request.apiKey) {
throw new Error('API key is required for Anthropic')
}
return tools.map((tool) => ({
name: tool.id,
description: tool.description,
input_schema: {
type: 'object',
properties: tool.parameters.properties,
required: tool.parameters.required,
},
}))
},
const anthropic = new Anthropic({
apiKey: request.apiKey,
dangerouslyAllowBrowser: true,
})
transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
const rawResponse = response?.output || response
if (!rawResponse?.content) {
throw new Error('No content found in response')
// Helper function to generate a simple unique ID for tool uses
const generateToolUseId = (toolName: string) => {
return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}`
}
const toolUse = rawResponse.content.find((item: any) => item.type === 'tool_use')
if (!toolUse) {
throw new Error('No tool use found in response')
}
const tool = tools?.find((t) => t.id === toolUse.name)
if (!tool) {
throw new Error(`Tool not found: ${toolUse.name}`)
}
let input = toolUse.input
if (typeof input === 'string') {
try {
input = JSON.parse(input)
} catch (e) {
console.error('Failed to parse tool input:', e)
input = {}
}
}
return {
name: toolUse.name,
arguments: {
...tool.params,
...input,
},
}
},
transformRequest: (request: ProviderRequest, functions?: any) => {
// Transform messages to Anthropic format
const messages =
request.messages?.map((msg) => {
const messages = []
// Add system prompt if present
let systemPrompt = request.systemPrompt || ''
// Add context if present
if (request.context) {
messages.push({
role: 'user',
content: request.context,
})
}
// Add remaining messages
if (request.messages) {
request.messages.forEach((msg) => {
if (msg.role === 'function') {
return {
messages.push({
role: 'user',
content: [
{
@@ -89,58 +52,55 @@ export const anthropicProvider: ProviderConfig = {
content: msg.content,
},
],
}
}
if (msg.function_call) {
return {
})
} else if (msg.function_call) {
const toolUseId = msg.function_call.name + '-' + Date.now()
messages.push({
role: 'assistant',
content: [
{
type: 'tool_use',
id: msg.function_call.name,
id: toolUseId,
name: msg.function_call.name,
input: JSON.parse(msg.function_call.arguments),
},
],
}
})
} else {
messages.push({
role: msg.role === 'assistant' ? 'assistant' : 'user',
content: msg.content ? [{ type: 'text', text: msg.content }] : [],
})
}
return {
role: msg.role === 'assistant' ? 'assistant' : 'user',
content: msg.content ? [{ type: 'text', text: msg.content }] : [],
}
}) || []
// Add context if provided
if (request.context) {
messages.unshift({
role: 'user',
content: [{ type: 'text', text: request.context }],
})
}
// Ensure there's at least one message by adding the system prompt as a user message if no messages exist
// Ensure there's at least one message
if (messages.length === 0) {
messages.push({
role: 'user',
content: [{ type: 'text', text: request.systemPrompt || '' }],
content: [{ type: 'text', text: systemPrompt || 'Hello' }],
})
// Clear system prompt since we've used it as a user message
systemPrompt = ''
}
// Build the request payload
const payload = {
model: request.model || 'claude-3-7-sonnet-20250219',
messages,
system: request.systemPrompt || '',
max_tokens: parseInt(String(request.maxTokens)) || 1024,
temperature: parseFloat(String(request.temperature ?? 0.7)),
...(functions && { tools: functions }),
}
// Transform tools to Anthropic format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
name: tool.id,
description: tool.description,
input_schema: {
type: 'object',
properties: tool.parameters.properties,
required: tool.parameters.required,
},
}))
: undefined
// If response format is specified, add strict formatting instructions
if (request.responseFormat) {
payload.system = `${payload.system}\n\nIMPORTANT RESPONSE FORMAT INSTRUCTIONS:
systemPrompt = `${systemPrompt}\n\nIMPORTANT RESPONSE FORMAT INSTRUCTIONS:
1. Your response must be EXACTLY in this format, with no additional fields:
{
${request.responseFormat.fields.map((field) => ` "${field.name}": ${field.type === 'string' ? '"value"' : field.type === 'array' ? '[]' : field.type === 'object' ? '{}' : field.type === 'number' ? '0' : 'true/false'}`).join(',\n')}
@@ -155,67 +115,163 @@ ${request.responseFormat.fields.map((field) => `${field.name} (${field.type})${f
5. Your response MUST be valid JSON and include all the specified fields with their correct types`
}
return payload
},
// Build the request payload
const payload: any = {
model: request.model || 'claude-3-7-sonnet-20250219',
messages,
system: systemPrompt,
max_tokens: parseInt(String(request.maxTokens)) || 1024,
temperature: parseFloat(String(request.temperature ?? 0.7)),
}
// Add tools if provided
if (tools?.length) {
payload.tools = tools
}
// Make the initial API request
let currentResponse = await anthropic.messages.create(payload)
let content = ''
// Extract text content from the message
if (Array.isArray(currentResponse.content)) {
content = currentResponse.content
.filter((item) => item.type === 'text')
.map((item) => item.text)
.join('\n')
}
let tokens = {
prompt: currentResponse.usage?.input_tokens || 0,
completion: currentResponse.usage?.output_tokens || 0,
total:
(currentResponse.usage?.input_tokens || 0) + (currentResponse.usage?.output_tokens || 0),
}
let toolCalls = []
let toolResults = []
let currentMessages = [...messages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
transformResponse: (response: any) => {
try {
if (!response) {
console.warn('Received undefined response from Anthropic API')
return { content: '' }
}
while (iterationCount < MAX_ITERATIONS) {
// Check for tool calls
const toolUses = currentResponse.content.filter((item) => item.type === 'tool_use')
if (!toolUses || toolUses.length === 0) {
break
}
// Get the actual response content
const rawResponse = response.output || response
// Process each tool call
for (const toolUse of toolUses) {
try {
const toolName = toolUse.name
const toolArgs = toolUse.input as Record<string, any>
// Extract text content from the message
let content = ''
const messageContent = rawResponse?.content || rawResponse?.message?.content
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
if (Array.isArray(messageContent)) {
content = messageContent
// Execute the tool
const mergedArgs = { ...tool.params, ...toolArgs }
const result = await executeTool(toolName, mergedArgs, true)
if (!result.success) continue
toolResults.push(result.output)
toolCalls.push({
name: toolName,
arguments: toolArgs,
})
// Add the tool call and result to messages
const toolUseId = generateToolUseId(toolName)
currentMessages.push({
role: 'assistant',
content: [
{
type: 'tool_use',
id: toolUseId,
name: toolName,
input: toolArgs,
} as any,
],
})
currentMessages.push({
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: toolUseId,
content: JSON.stringify(result.output),
} as any,
],
})
} catch (error) {
console.error('Error processing tool call:', error)
}
}
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
// Make the next request
currentResponse = await anthropic.messages.create(nextPayload)
// Update content if we have a text response
const textContent = currentResponse.content
.filter((item) => item.type === 'text')
.map((item) => item.text)
.join('\n')
} else if (typeof messageContent === 'string') {
content = messageContent
}
// If the content looks like it contains JSON, extract just the JSON part
if (content.includes('{') && content.includes('}')) {
try {
const jsonMatch = content.match(/\{[\s\S]*\}/m)
if (jsonMatch) {
content = jsonMatch[0]
}
} catch (e) {
console.error('Error extracting JSON from response:', e)
if (textContent) {
content = textContent
}
}
return {
content,
model: rawResponse?.model || response?.model || 'claude-3-7-sonnet-20250219',
tokens: rawResponse?.usage && {
prompt: rawResponse.usage.input_tokens,
completion: rawResponse.usage.output_tokens,
total: rawResponse.usage.input_tokens + rawResponse.usage.output_tokens,
},
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.input_tokens || 0
tokens.completion += currentResponse.usage.output_tokens || 0
tokens.total +=
(currentResponse.usage.input_tokens || 0) + (currentResponse.usage.output_tokens || 0)
}
iterationCount++
}
} catch (error) {
console.error('Error in transformResponse:', error)
return { content: '' }
console.error('Error in Anthropic request:', error)
throw error
}
},
hasFunctionCall: (response: any) => {
try {
if (!response) return false
const rawResponse = response.output || response
return rawResponse?.content?.some((item: any) => item.type === 'tool_use') || false
} catch (error) {
console.error('Error in hasFunctionCall:', error)
return false
// If the content looks like it contains JSON, extract just the JSON part
if (content.includes('{') && content.includes('}')) {
try {
const jsonMatch = content.match(/\{[\s\S]*\}/m)
if (jsonMatch) {
content = jsonMatch[0]
}
} catch (e) {
console.error('Error extracting JSON from response:', e)
}
}
return {
content,
model: request.model || 'claude-3-7-sonnet-20250219',
tokens,
toolCalls:
toolCalls.length > 0
? toolCalls.map((tc) => ({
name: tc.name,
arguments: tc.arguments as Record<string, any>,
}))
: undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},
}

View File

@@ -1,11 +1,6 @@
import { Cerebras } from '@cerebras/cerebras_cloud_sdk'
import {
FunctionCallResponse,
ProviderConfig,
ProviderRequest,
ProviderResponse,
ProviderToolConfig,
} from '../types'
import OpenAI from 'openai'
import { executeTool } from '@/tools'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'
export const cerebrasProvider: ProviderConfig = {
id: 'cerebras',
@@ -14,153 +9,18 @@ export const cerebrasProvider: ProviderConfig = {
version: '1.0.0',
models: ['llama-3.3-70b'],
defaultModel: 'llama-3.3-70b',
implementationType: 'sdk',
// SDK-based implementation
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
try {
const client = new Cerebras({
apiKey: request.apiKey,
})
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Transform tools to Cerebras format if provided
const functions = request.tools?.length
? request.tools.map((tool) => ({
name: tool.id,
description: tool.description,
parameters: tool.parameters,
}))
: undefined
// Build the request payload
const payload: any = {
model: request.model || 'llama-3.3-70b',
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add functions if provided
if (functions?.length) {
payload.functions = functions
payload.function_call = 'auto'
}
// Add local execution flag if specified
if (request.local_execution) {
payload.local_execution = true
}
// Execute the request using the SDK
const response = (await client.chat.completions.create(payload)) as any
// Extract content and token info
const content = response.choices?.[0]?.message?.content || ''
const tokens = {
prompt: response.usage?.prompt_tokens || 0,
completion: response.usage?.completion_tokens || 0,
total: response.usage?.total_tokens || 0,
}
// Check for function calls
const functionCall = response.choices?.[0]?.message?.function_call
let toolCalls = undefined
if (functionCall) {
const tool = request.tools?.find((t) => t.id === functionCall.name)
const toolParams = tool?.params || {}
toolCalls = [
{
name: functionCall.name,
arguments: {
...toolParams,
...JSON.parse(functionCall.arguments),
},
},
]
}
// Return the response in the expected format
return {
content,
model: request.model,
tokens,
toolCalls,
}
} catch (error: any) {
console.error('Error executing Cerebras request:', error)
throw new Error(`Cerebras API error: ${error.message}`)
}
},
// These are still needed for backward compatibility
baseUrl: 'https://api.cerebras.cloud/v1/chat/completions',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
}),
transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
return undefined
if (!request.apiKey) {
throw new Error('API key is required for Cerebras')
}
return tools.map((tool) => ({
name: tool.id,
description: tool.description,
parameters: tool.parameters,
}))
},
const openai = new OpenAI({
apiKey: request.apiKey,
baseURL: 'https://api.cerebras.ai/v1',
dangerouslyAllowBrowser: true,
})
transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
const functionCall = response.choices?.[0]?.message?.function_call
if (!functionCall) {
throw new Error('No function call found in response')
}
const tool = tools?.find((t) => t.id === functionCall.name)
const toolParams = tool?.params || {}
return {
name: functionCall.name,
arguments: {
...toolParams,
...JSON.parse(functionCall.arguments),
},
}
},
transformRequest: (request: ProviderRequest, functions?: any) => {
// Start with an empty array for all messages
const allMessages = []
@@ -185,6 +45,18 @@ export const cerebrasProvider: ProviderConfig = {
allMessages.push(...request.messages)
}
// Transform tools to OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
: undefined
// Build the request payload
const payload: any = {
model: request.model || 'llama-3.3-70b',
@@ -195,38 +67,128 @@ export const cerebrasProvider: ProviderConfig = {
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add function calling support
if (functions) {
payload.functions = functions
payload.function_call = 'auto'
// Add response format for structured output if specified
if (request.responseFormat) {
payload.response_format = { type: 'json_object' }
}
// Add local execution flag if specified
// Add tools if provided
if (tools?.length) {
payload.tools = tools
payload.tool_choice = 'auto'
}
// Add local execution flag if specified by Cerebras
if (request.local_execution) {
payload.local_execution = true
}
return payload
},
transformResponse: (response: any) => {
const output = {
content: response.choices?.[0]?.message?.content || '',
tokens: undefined as any,
// Make the initial API request
let currentResponse = await openai.chat.completions.create(payload)
let content = currentResponse.choices[0]?.message?.content || ''
let tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
total: currentResponse.usage?.total_tokens || 0,
}
let toolCalls = []
let toolResults = []
let currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
if (response.usage) {
output.tokens = {
prompt: response.usage.prompt_tokens,
completion: response.usage.completion_tokens,
total: response.usage.total_tokens,
try {
while (iterationCount < MAX_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
// Process each tool call
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
// Execute the tool
const mergedArgs = { ...tool.params, ...toolArgs }
const result = await executeTool(toolName, mergedArgs, true)
if (!result.success) continue
toolResults.push(result.output)
toolCalls.push({
name: toolName,
arguments: toolArgs,
})
// Add the tool call and result to messages
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
const toolResultContent = JSON.stringify(result.output)
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: toolResultContent,
})
} catch (error) {
console.error('Error processing tool call:', error)
}
}
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
// Make the next request
currentResponse = await openai.chat.completions.create(nextPayload)
// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
tokens.total += currentResponse.usage.total_tokens || 0
}
iterationCount++
}
} catch (error) {
console.error('Error in Cerebras request:', error)
throw error
}
return output
},
hasFunctionCall: (response: any) => {
return !!response.choices?.[0]?.message?.function_call
return {
content,
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},
}

View File

@@ -1,17 +0,0 @@
import { ProviderId } from './registry'
/**
* Direct mapping from model names to provider IDs
* This replaces the need for the MODEL_TOOLS mapping in blocks/consts.ts
*/
export const MODEL_PROVIDERS: Record<string, ProviderId> = {
'gpt-4o': 'openai',
o1: 'openai',
'o3-mini': 'openai',
'claude-3-7-sonnet-20250219': 'anthropic',
'gemini-2.0-flash': 'google',
'grok-2-latest': 'xai',
'deepseek-v3': 'deepseek',
'deepseek-r1': 'deepseek',
'llama-3.3-70b': 'cerebras',
}

View File

@@ -1,4 +1,6 @@
import { FunctionCallResponse, ProviderConfig, ProviderRequest, ProviderToolConfig } from '../types'
import OpenAI from 'openai'
import { executeTool } from '@/tools'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'
export const deepseekProvider: ProviderConfig = {
id: 'deepseek',
@@ -8,116 +10,74 @@ export const deepseekProvider: ProviderConfig = {
models: ['deepseek-chat'],
defaultModel: 'deepseek-chat',
baseUrl: 'https://api.deepseek.com/v1/chat/completions',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
}),
transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
return undefined
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
if (!request.apiKey) {
throw new Error('API key is required for Deepseek')
}
return tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
},
transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
const toolCall = response.choices?.[0]?.message?.tool_calls?.[0]
if (!toolCall || !toolCall.function) {
throw new Error('No valid tool call found in response')
}
const tool = tools?.find((t) => t.id === toolCall.function.name)
if (!tool) {
throw new Error(`Tool not found: ${toolCall.function.name}`)
}
let args = toolCall.function.arguments
if (typeof args === 'string') {
try {
args = JSON.parse(args)
} catch (e) {
console.error('Failed to parse tool arguments:', e)
args = {}
}
}
return {
name: toolCall.function.name,
arguments: {
...tool.params,
...args,
},
}
},
transformRequest: (request: ProviderRequest, functions?: any) => {
// Transform messages from internal format to Deepseek format
const messages = (request.messages || []).map((msg) => {
if (msg.role === 'function') {
return {
role: 'tool',
content: msg.content,
tool_call_id: msg.name,
}
}
if (msg.function_call) {
return {
role: 'assistant',
content: null,
tool_calls: [
{
id: msg.function_call.name,
type: 'function',
function: {
name: msg.function_call.name,
arguments: msg.function_call.arguments,
},
},
],
}
}
return msg
// Deepseek uses the OpenAI SDK with a custom baseURL
const deepseek = new OpenAI({
apiKey: request.apiKey,
baseURL: 'https://api.deepseek.com/v1',
dangerouslyAllowBrowser: true,
})
const payload = {
model: 'deepseek-chat',
messages: [
{ role: 'system', content: request.systemPrompt },
...(request.context ? [{ role: 'user', content: request.context }] : []),
...messages,
],
temperature: request.temperature || 0.7,
max_tokens: request.maxTokens || 1024,
...(functions && { tools: functions }),
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}
return payload
},
transformResponse: (response: any) => {
if (!response) {
console.warn('Received undefined response from Deepseek API')
return { content: '' }
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}
const output = response.choices?.[0]?.message
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Try to clean up the response content if it exists
let content = output?.content || ''
// Transform tools to OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
: undefined
const payload: any = {
model: 'deepseek-chat', // Hardcode to deepseek-chat regardless of what's selected in the UI
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add tools if provided
if (tools?.length) {
payload.tools = tools
payload.tool_choice = 'auto'
}
// Make the initial API request
let currentResponse = await deepseek.chat.completions.create(payload)
let content = currentResponse.choices[0]?.message?.content || ''
// Clean up the response content if it exists
if (content) {
// Remove any markdown code block markers
content = content.replace(/```json\n?|\n?```/g, '')
@@ -125,18 +85,110 @@ export const deepseekProvider: ProviderConfig = {
content = content.trim()
}
let tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
total: currentResponse.usage?.total_tokens || 0,
}
let toolCalls = []
let toolResults = []
let currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
try {
while (iterationCount < MAX_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
// Process each tool call
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
// Execute the tool
const mergedArgs = { ...tool.params, ...toolArgs }
const result = await executeTool(toolName, mergedArgs, true)
if (!result.success) continue
toolResults.push(result.output)
toolCalls.push({
name: toolName,
arguments: toolArgs,
})
// Add the tool call and result to messages
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(result.output),
})
} catch (error) {
console.error('Error processing tool call:', error)
}
}
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
// Make the next request
currentResponse = await deepseek.chat.completions.create(nextPayload)
// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
// Clean up the response content
content = content.replace(/```json\n?|\n?```/g, '')
content = content.trim()
}
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
tokens.total += currentResponse.usage.total_tokens || 0
}
iterationCount++
}
} catch (error) {
console.error('Error in Deepseek request:', error)
throw error
}
return {
content,
tokens: response.usage && {
prompt: response.usage.prompt_tokens,
completion: response.usage.completion_tokens,
total: response.usage.total_tokens,
},
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},
hasFunctionCall: (response: any) => {
if (!response) return false
return !!response.choices?.[0]?.message?.tool_calls?.[0]
},
}

View File

@@ -1,205 +1,189 @@
import { ToolConfig } from '@/tools/types'
import { FunctionCallResponse, ProviderConfig, ProviderRequest, ProviderToolConfig } from '../types'
import OpenAI from 'openai'
import { executeTool } from '@/tools'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'
export const googleProvider: ProviderConfig = {
id: 'google',
name: 'Google',
description: "Google's Gemini models",
version: '1.0.0',
models: ['gemini-2.0-flash-001'],
defaultModel: 'gemini-2.0-flash-001',
models: ['gemini-2.0-flash'],
defaultModel: 'gemini-2.0-flash',
baseUrl:
'https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-001:generateContent',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
'x-goog-api-key': apiKey,
}),
transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
return undefined
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
if (!request.apiKey) {
throw new Error('API key is required for Google Gemini')
}
const transformProperties = (properties: Record<string, any>): Record<string, any> => {
return Object.entries(properties).reduce((acc, [key, value]) => {
// Skip complex JSON/object parameters for Gemini
if (value.type === 'json' || (value.type === 'object' && !value.properties)) {
return acc
}
const openai = new OpenAI({
apiKey: request.apiKey,
baseURL: 'https://generativelanguage.googleapis.com/v1beta/openai/',
dangerouslyAllowBrowser: true,
})
// For object types with defined properties
if (value.type === 'object' && value.properties) {
return {
...acc,
[key]: {
type: 'OBJECT',
description: value.description || '',
properties: transformProperties(value.properties),
},
}
}
// Start with an empty array for all messages
const allMessages = []
// For simple types
return {
...acc,
[key]: {
type: (value.type || 'string').toUpperCase(),
description: value.description || '',
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Transform tools to OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
: undefined
// Build the request payload
const payload: any = {
model: request.model || 'gemini-2.0-flash',
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add response format for structured output if specified
if (request.responseFormat) {
payload.response_format = { type: 'json_object' }
}
// Add tools if provided
if (tools?.length) {
payload.tools = tools
payload.tool_choice = 'auto'
}
// Make the initial API request
let currentResponse = await openai.chat.completions.create(payload)
let content = currentResponse.choices[0]?.message?.content || ''
let tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
total: currentResponse.usage?.total_tokens || 0,
}
let toolCalls = []
let toolResults = []
let currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
try {
while (iterationCount < MAX_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
}, {})
}
return {
functionDeclarations: tools.map((tool) => {
// Get properties excluding complex types
const properties = transformProperties(tool.parameters.properties)
// Process each tool call
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
// Filter required fields to only include ones that exist in properties
const required = (tool.parameters.required || []).filter((field) => field in properties)
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
return {
name: tool.id,
description: tool.description || '',
parameters: {
type: 'OBJECT',
properties,
required,
},
}
}),
}
},
// Execute the tool
const mergedArgs = { ...tool.params, ...toolArgs }
const result = await executeTool(toolName, mergedArgs, true)
transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
// Extract function call from Gemini response
const functionCall = response.candidates?.[0]?.content?.parts?.[0]?.functionCall
if (!functionCall) {
throw new Error('No function call found in response')
}
if (!result.success) continue
// Log the function call for debugging
console.log('Raw function call from Gemini:', JSON.stringify(functionCall, null, 2))
toolResults.push(result.output)
toolCalls.push({
name: toolName,
arguments: toolArgs,
})
const tool = tools?.find((t) => t.id === functionCall.name)
if (!tool) {
throw new Error(`Tool not found: ${functionCall.name}`)
}
// Ensure args is an object
let args = functionCall.args
if (typeof args === 'string') {
try {
args = JSON.parse(args)
} catch (e) {
console.error('Failed to parse function call args:', e)
args = {}
}
}
// Get arguments from function call, but NEVER override apiKey
const { apiKey: _, ...functionArgs } = args
return {
name: functionCall.name,
arguments: {
...functionArgs,
apiKey: tool.params.apiKey, // Always use the apiKey from tool params
},
}
},
transformRequest: (request: ProviderRequest, functions?: any) => {
// Combine system prompt and context into a single message if both exist
const initialMessage = request.systemPrompt + (request.context ? `\n\n${request.context}` : '')
const messages = [
{ role: 'user', parts: [{ text: initialMessage }] },
...(request.messages || []).map((msg) => {
if (msg.role === 'function') {
return {
role: 'user',
parts: [
{
functionResponse: {
name: msg.name,
response: {
name: msg.name,
content: JSON.parse(msg.content || '{}'),
// Add the tool call and result to messages
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
},
],
],
})
const toolResultContent = JSON.stringify(result.output)
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: toolResultContent,
})
} catch (error) {
console.error('Error processing tool call:', error)
}
}
if (msg.function_call) {
return {
role: 'model',
parts: [
{
functionCall: {
name: msg.function_call.name,
args: JSON.parse(msg.function_call.arguments),
},
},
],
}
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
return {
role: msg.role === 'assistant' ? 'model' : 'user',
parts: [{ text: msg.content || '' }],
// Make the next request
currentResponse = await openai.chat.completions.create(nextPayload)
// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
}),
]
// Log the request for debugging
console.log(
'Gemini request:',
JSON.stringify(
{
messages,
tools: functions ? [{ functionDeclarations: functions.functionDeclarations }] : undefined,
},
null,
2
)
)
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
tokens.total += currentResponse.usage.total_tokens || 0
}
return {
contents: messages,
tools: functions ? [{ functionDeclarations: functions.functionDeclarations }] : undefined,
generationConfig: {
temperature: request.temperature || 0.7,
maxOutputTokens: request.maxTokens || 1024,
},
}
},
transformResponse: (response: any) => {
let content = response.candidates?.[0]?.content?.parts?.[0]?.text || ''
if (content) {
content = content.replace(/```json\n?|\n?```/g, '')
content = content.trim()
iterationCount++
}
} catch (error) {
console.error('Error in Google Gemini request:', error)
throw error
}
return {
content,
tokens: response.usageMetadata && {
prompt: response.usageMetadata.promptTokenCount,
completion: response.usageMetadata.candidatesTokenCount,
total: response.usageMetadata.totalTokenCount,
},
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},
hasFunctionCall: (response: any) => {
return !!response.candidates?.[0]?.content?.parts?.[0]?.functionCall
},
}

27
providers/index.ts Normal file
View File

@@ -0,0 +1,27 @@
import { ProviderRequest, ProviderResponse } from './types'
import { generateStructuredOutputInstructions, getProvider } from './utils'
export async function executeProviderRequest(
providerId: string,
request: ProviderRequest
): Promise<ProviderResponse> {
const provider = getProvider(providerId)
if (!provider) {
throw new Error(`Provider not found: ${providerId}`)
}
if (!provider.executeRequest) {
throw new Error(`Provider ${providerId} does not implement executeRequest`)
}
// If responseFormat is provided, modify the system prompt to enforce structured output
if (request.responseFormat) {
const structuredOutputInstructions = generateStructuredOutputInstructions(
request.responseFormat
)
request.systemPrompt = `${request.systemPrompt}\n\n${structuredOutputInstructions}`
}
// Execute the request using the provider's implementation
return await provider.executeRequest(request)
}

View File

@@ -1,11 +1,6 @@
import OpenAI from 'openai'
import {
FunctionCallResponse,
ProviderConfig,
ProviderRequest,
ProviderResponse,
ProviderToolConfig,
} from '../types'
import { executeTool } from '@/tools'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'
export const openaiProvider: ProviderConfig = {
id: 'openai',
@@ -14,9 +9,7 @@ export const openaiProvider: ProviderConfig = {
version: '1.0.0',
models: ['gpt-4o', 'o1', 'o3-mini'],
defaultModel: 'gpt-4o',
implementationType: 'sdk',
// SDK-based implementation
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
if (!request.apiKey) {
throw new Error('API key is required for OpenAI')
@@ -116,9 +109,9 @@ export const openaiProvider: ProviderConfig = {
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
// Execute the tool (this will need to be imported from your tools system)
const { executeTool } = await import('@/tools')
const result = await executeTool(toolName, toolArgs)
// Execute the tool
const mergedArgs = { ...tool.params, ...toolArgs }
const result = await executeTool(toolName, mergedArgs, true)
if (!result.success) continue
@@ -190,190 +183,4 @@ export const openaiProvider: ProviderConfig = {
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},
// These are still needed for backward compatibility
baseUrl: 'https://api.openai.com/v1/chat/completions',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
}),
transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
return undefined
}
return tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
},
transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
// Handle SDK response format
if (response.choices?.[0]?.message?.tool_calls) {
const toolCall = response.choices[0].message.tool_calls[0]
if (!toolCall) {
throw new Error('No tool call found in response')
}
const tool = tools?.find((t) => t.id === toolCall.function.name)
const toolParams = tool?.params || {}
return {
name: toolCall.function.name,
arguments: {
...toolParams,
...JSON.parse(toolCall.function.arguments),
},
}
}
// Handle legacy function_call format for backward compatibility
const functionCall = response.choices?.[0]?.message?.function_call
if (!functionCall) {
throw new Error('No function call found in response')
}
const tool = tools?.find((t) => t.id === functionCall.name)
const toolParams = tool?.params || {}
return {
name: functionCall.name,
arguments: {
...toolParams,
...JSON.parse(functionCall.arguments),
},
}
},
transformRequest: (request: ProviderRequest, functions?: any) => {
const isO1Model = request.model?.startsWith('o1')
const isO1Mini = request.model === 'o1-mini'
// Helper function to transform message role
const transformMessageRole = (message: any) => {
if (isO1Mini && message.role === 'system') {
return { ...message, role: 'user' }
}
return message
}
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push(
transformMessageRole({
role: 'system',
content: request.systemPrompt,
})
)
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}
// Add remaining messages, transforming roles as needed
if (request.messages) {
allMessages.push(...request.messages.map(transformMessageRole))
}
// Build the request payload
const payload: any = {
model: request.model || 'gpt-4o',
messages: allMessages,
}
// Only add parameters supported by the model type
if (!isO1Model) {
// gpt-4o supports standard parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add response format for structured output if specified
if (request.responseFormat) {
// Use OpenAI's simpler response format
payload.response_format = { type: 'json_object' }
// If we have both function calls and response format, we need to guide the model
if (functions) {
payload.messages[0].content = `${payload.messages[0].content}\n\nProcess:\n1. First, use the provided functions to gather the necessary data\n2. Then, format your final response as a SINGLE JSON object with these exact fields and types:\n${request.responseFormat.fields
.map(
(field) =>
`- "${field.name}" (${field.type})${field.description ? `: ${field.description}` : ''}`
)
.join(
'\n'
)}\n\nYour final response after function calls must be a SINGLE valid JSON object with all required fields and correct types. Do not return multiple objects or include any text outside the JSON.`
} else {
// If no functions, just format as JSON directly
payload.messages[0].content = `${payload.messages[0].content}\n\nYou MUST return a SINGLE JSON object with exactly these fields and types:\n${request.responseFormat.fields
.map(
(field) =>
`- "${field.name}" (${field.type})${field.description ? `: ${field.description}` : ''}`
)
.join(
'\n'
)}\n\nThe response must:\n1. Be a single valid JSON object\n2. Include all the specified fields\n3. Use the correct type for each field\n4. Not include any additional fields\n5. Not include any explanatory text outside the JSON\n6. Not return multiple objects`
}
}
} else {
// o1 models use max_completion_tokens
if (request.maxTokens !== undefined) {
payload.max_completion_tokens = request.maxTokens
}
}
// Add function calling support (supported by all models)
if (functions) {
payload.tools = functions
payload.tool_choice = 'auto'
}
return payload
},
transformResponse: (response: any) => {
const output = {
content: response.choices?.[0]?.message?.content || '',
tokens: undefined as any,
}
if (response.usage) {
output.tokens = {
prompt: response.usage.prompt_tokens,
completion: response.usage.completion_tokens,
total: response.usage.total_tokens,
}
// Add reasoning_tokens for o1 models if available
if (response.usage.completion_tokens_details?.reasoning_tokens) {
output.tokens.reasoning = response.usage.completion_tokens_details.reasoning_tokens
}
}
return output
},
hasFunctionCall: (response: any) => {
return (
!!response.choices?.[0]?.message?.function_call ||
(response.choices?.[0]?.message?.tool_calls &&
response.choices[0].message.tool_calls.length > 0)
)
},
}

View File

@@ -1,28 +0,0 @@
import { anthropicProvider } from './anthropic'
import { cerebrasProvider } from './cerebras'
import { deepseekProvider } from './deepseek'
import { googleProvider } from './google'
import { openaiProvider } from './openai'
import { ProviderConfig } from './types'
import { xAIProvider } from './xai'
export type ProviderId = 'openai' | 'anthropic' | 'google' | 'deepseek' | 'xai' | 'cerebras'
export const providers: Record<ProviderId, ProviderConfig> = {
openai: openaiProvider,
anthropic: anthropicProvider,
google: googleProvider,
deepseek: deepseekProvider,
xai: xAIProvider,
cerebras: cerebrasProvider,
}
export function getProvider(id: string): ProviderConfig | undefined {
// Handle both formats: 'openai' and 'openai/chat'
const providerId = id.split('/')[0] as ProviderId
return providers[providerId]
}
export function getProviderChatId(providerId: ProviderId): string {
return `${providerId}/chat`
}

View File

@@ -1,341 +0,0 @@
import OpenAI from 'openai'
import { executeTool, getTool } from '@/tools'
import { getProvider } from './registry'
import { ProviderRequest, ProviderResponse, TokenInfo } from './types'
import { extractAndParseJSON } from './utils'
// Helper function to generate provider-specific structured output instructions
function generateStructuredOutputInstructions(responseFormat: any): string {
if (!responseFormat?.fields) return ''
function generateFieldStructure(field: any): string {
if (field.type === 'object' && field.properties) {
return `{
${Object.entries(field.properties)
.map(([key, prop]: [string, any]) => `"${key}": ${prop.type === 'number' ? '0' : '"value"'}`)
.join(',\n ')}
}`
}
return field.type === 'string'
? '"value"'
: field.type === 'number'
? '0'
: field.type === 'boolean'
? 'true/false'
: '[]'
}
const exampleFormat = responseFormat.fields
.map((field: any) => ` "${field.name}": ${generateFieldStructure(field)}`)
.join(',\n')
const fieldDescriptions = responseFormat.fields
.map((field: any) => {
let desc = `${field.name} (${field.type})`
if (field.description) desc += `: ${field.description}`
if (field.type === 'object' && field.properties) {
desc += '\nProperties:'
Object.entries(field.properties).forEach(([key, prop]: [string, any]) => {
desc += `\n - ${key} (${(prop as any).type}): ${(prop as any).description || ''}`
})
}
return desc
})
.join('\n')
return `
Please provide your response in the following JSON format:
{
${exampleFormat}
}
Field descriptions:
${fieldDescriptions}
Your response MUST be valid JSON and include all the specified fields with their correct types.
Each metric should be an object containing 'score' (number) and 'reasoning' (string).`
}
export async function executeProviderRequest(
providerId: string,
request: ProviderRequest
): Promise<ProviderResponse> {
const provider = getProvider(providerId)
if (!provider) {
throw new Error(`Provider not found: ${providerId}`)
}
// Use SDK-based implementation if available
if (provider.implementationType === 'sdk' && provider.executeRequest) {
return provider.executeRequest(request)
}
// Legacy HTTP-based implementation for other providers
return executeHttpProviderRequest(provider, request)
}
// Legacy HTTP-based implementation
async function executeHttpProviderRequest(
provider: any,
request: ProviderRequest
): Promise<ProviderResponse> {
// If responseFormat is provided, modify the system prompt to enforce structured output
if (request.responseFormat) {
const structuredOutputInstructions = generateStructuredOutputInstructions(
request.responseFormat
)
request.systemPrompt = `${request.systemPrompt}\n\n${structuredOutputInstructions}`
}
// Transform tools to provider-specific function format
const functions =
request.tools && request.tools.length > 0
? provider.transformToolsToFunctions(request.tools)
: undefined
// Transform the request using provider-specific logic
const payload = provider.transformRequest(request, functions)
// Make the initial API request through the proxy
let currentResponse = await makeProxyRequest(provider.id, payload, request.apiKey)
let content = ''
let tokens: TokenInfo | undefined = undefined
let toolCalls = []
let toolResults = []
let currentMessages = [...(request.messages || [])]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
try {
while (iterationCount < MAX_ITERATIONS) {
// Transform the response using provider-specific logic
const transformedResponse = provider.transformResponse(currentResponse)
content = transformedResponse.content
// If responseFormat is specified and we have content (not a function call), validate and parse the response
if (request.responseFormat && content && !provider.hasFunctionCall(currentResponse)) {
try {
// Extract and parse the JSON content
const parsedContent = extractAndParseJSON(content)
// Validate that all required fields are present and have correct types
const validationErrors = request.responseFormat.fields
.map((field: any) => {
if (!(field.name in parsedContent)) {
return `Missing field: ${field.name}`
}
const value = parsedContent[field.name]
const type = typeof value
if (field.type === 'string' && type !== 'string') {
return `Invalid type for ${field.name}: expected string, got ${type}`
}
if (field.type === 'number' && type !== 'number') {
return `Invalid type for ${field.name}: expected number, got ${type}`
}
if (field.type === 'boolean' && type !== 'boolean') {
return `Invalid type for ${field.name}: expected boolean, got ${type}`
}
if (field.type === 'array' && !Array.isArray(value)) {
return `Invalid type for ${field.name}: expected array, got ${type}`
}
if (field.type === 'object' && (type !== 'object' || Array.isArray(value))) {
return `Invalid type for ${field.name}: expected object, got ${type}`
}
return null
})
.filter(Boolean)
if (validationErrors.length > 0) {
throw new Error(`Response format validation failed:\n${validationErrors.join('\n')}`)
}
// Store the validated JSON response
content = JSON.stringify(parsedContent)
} catch (error: any) {
console.error('Raw content:', content)
throw new Error(`Failed to parse structured response: ${error.message}`)
}
}
// Update tokens
if (transformedResponse.tokens) {
const newTokens: TokenInfo = {
prompt: (tokens?.prompt ?? 0) + (transformedResponse.tokens?.prompt ?? 0),
completion: (tokens?.completion ?? 0) + (transformedResponse.tokens?.completion ?? 0),
total: (tokens?.total ?? 0) + (transformedResponse.tokens?.total ?? 0),
}
tokens = newTokens
}
// Check for function calls using provider-specific logic
const hasFunctionCall = provider.hasFunctionCall(currentResponse)
// Break if we have content and no function call
if (!hasFunctionCall) {
break
}
// Safety check: if we have the same function call multiple times in a row
// with the same arguments, break to prevent infinite loops
let functionCall
try {
functionCall = provider.transformFunctionCallResponse(currentResponse, request.tools)
// Check if this is a duplicate call
const lastCall = toolCalls[toolCalls.length - 1]
if (
lastCall &&
lastCall.name === functionCall.name &&
JSON.stringify(lastCall.arguments) === JSON.stringify(functionCall.arguments)
) {
console.log(
'Detected duplicate function call, breaking loop to prevent infinite recursion'
)
break
}
} catch (error) {
console.log('Error transforming function call:', error)
break
}
if (!functionCall) {
break
}
// Execute the tool
const tool = getTool(functionCall.name)
if (!tool) {
break
}
const result = await executeTool(functionCall.name, functionCall.arguments)
if (!result.success) {
break
}
toolResults.push(result.output)
toolCalls.push(functionCall)
// Add the function call and result to messages
// Check if we're dealing with the new tool_calls format or the legacy function_call format
const hasToolCalls = currentResponse.choices?.[0]?.message?.tool_calls
if (hasToolCalls) {
const toolCall = currentResponse.choices[0].message.tool_calls[0]
if (toolCall && toolCall.id) {
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: functionCall.name,
arguments: JSON.stringify(functionCall.arguments),
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(result.output),
})
} else {
// Fallback to legacy format if id is missing
currentMessages.push({
role: 'assistant',
content: null,
function_call: {
name: functionCall.name,
arguments: JSON.stringify(functionCall.arguments),
},
})
currentMessages.push({
role: 'function',
name: functionCall.name,
content: JSON.stringify(result.output),
})
}
} else {
// Legacy format
currentMessages.push({
role: 'assistant',
content: null,
function_call: {
name: functionCall.name,
arguments: JSON.stringify(functionCall.arguments),
},
})
currentMessages.push({
role: 'function',
name: functionCall.name,
content: JSON.stringify(result.output),
})
}
// Prepare the next request
const nextPayload = provider.transformRequest(
{
...request,
messages: currentMessages,
},
functions
)
// Make the next request
currentResponse = await makeProxyRequest(provider.id, nextPayload, request.apiKey)
iterationCount++
}
if (iterationCount >= MAX_ITERATIONS) {
console.log('Max iterations of tool calls reached, breaking loop')
}
} catch (error) {
console.error('Error executing tool:', error)
throw error
}
return {
content,
model: currentResponse.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
}
async function makeProxyRequest(providerId: string, payload: any, apiKey: string) {
const baseUrl = process.env.NEXT_PUBLIC_APP_URL
if (!baseUrl) {
throw new Error('NEXT_PUBLIC_APP_URL environment variable is not set')
}
const proxyUrl = new URL('/api/proxy', baseUrl).toString()
const response = await fetch(proxyUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
toolId: `${providerId}/chat`,
params: {
...payload,
apiKey,
},
}),
})
const data = await response.json()
if (!data.success) {
throw new Error(data.error || 'Provider API error')
}
return data.output
}

View File

@@ -1,4 +1,4 @@
import { ToolConfig } from '@/tools/types'
export type ProviderId = 'openai' | 'anthropic' | 'google' | 'deepseek' | 'xai' | 'cerebras'
export interface TokenInfo {
prompt?: number
@@ -18,33 +18,7 @@ export interface ProviderConfig {
version: string
models: string[]
defaultModel: string
// Provider implementation type
implementationType: 'sdk' | 'http'
// For HTTP-based providers
baseUrl?: string
headers?: (apiKey: string) => Record<string, string>
// For SDK-based providers
executeRequest?: (request: ProviderRequest) => Promise<ProviderResponse>
// Tool calling support
transformToolsToFunctions: (tools: ProviderToolConfig[]) => any
transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
) => FunctionCallResponse
// Provider-specific request/response transformations
transformRequest: (request: ProviderRequest, functions?: any) => any
transformResponse: (response: any) => TransformedResponse
// Function to check if response contains a function call
hasFunctionCall: (response: any) => boolean
// Internal state for tool name mapping
_toolNameMapping?: Map<string, string>
}
export interface FunctionCallResponse {

View File

@@ -1,60 +1,161 @@
import { ProviderId } from './registry'
import { anthropicProvider } from './anthropic'
import { cerebrasProvider } from './cerebras'
import { deepseekProvider } from './deepseek'
import { googleProvider } from './google'
import { openaiProvider } from './openai'
import { ProviderConfig, ProviderId } from './types'
import { xAIProvider } from './xai'
/**
* Direct mapping from model names to provider IDs
* Provider configurations with associated model names/patterns
*/
export const MODEL_PROVIDERS: Record<string, ProviderId> = {
'gpt-4o': 'openai',
o1: 'openai',
'o3-mini': 'openai',
'claude-3-7-sonnet-20250219': 'anthropic',
'gemini-2.0-flash': 'google',
'grok-2-latest': 'xai',
'deepseek-v3': 'deepseek',
'deepseek-r1': 'deepseek',
'llama-3.3-70b': 'cerebras',
export const providers: Record<
ProviderId,
ProviderConfig & {
models: string[]
modelPatterns?: RegExp[]
}
> = {
openai: {
...openaiProvider,
models: ['gpt-4o', 'o1', 'o3-mini'],
modelPatterns: [/^gpt/, /^o1/],
},
anthropic: {
...anthropicProvider,
models: ['claude-3-7-sonnet-20250219'],
modelPatterns: [/^claude/],
},
google: {
...googleProvider,
models: ['gemini-2.0-flash'],
modelPatterns: [/^gemini/],
},
deepseek: {
...deepseekProvider,
models: ['deepseek-v3', 'deepseek-r1'],
modelPatterns: [/^deepseek/],
},
xai: {
...xAIProvider,
models: ['grok-2-latest'],
modelPatterns: [/^grok/],
},
cerebras: {
...cerebrasProvider,
models: ['llama-3.3-70b'],
modelPatterns: [/^llama/],
},
}
/**
* Determines the provider ID based on the model name.
* Uses the MODEL_PROVIDERS mapping and falls back to pattern matching if needed.
*
* @param model - The model name/identifier
* @returns The corresponding provider ID
* Direct mapping from model names to provider IDs
* Automatically generated from the providers configuration
*/
export const MODEL_PROVIDERS: Record<string, ProviderId> = Object.entries(providers).reduce(
(map, [providerId, config]) => {
config.models.forEach((model) => {
map[model.toLowerCase()] = providerId as ProviderId
})
return map
},
{} as Record<string, ProviderId>
)
export function getProviderFromModel(model: string): ProviderId {
const normalizedModel = model.toLowerCase()
// First try to match exactly from our MODEL_PROVIDERS mapping
if (normalizedModel in MODEL_PROVIDERS) {
return MODEL_PROVIDERS[normalizedModel]
}
// If no exact match, use pattern matching as fallback
if (normalizedModel.startsWith('gpt') || normalizedModel.startsWith('o1')) {
return 'openai'
for (const [providerId, config] of Object.entries(providers)) {
if (config.modelPatterns) {
for (const pattern of config.modelPatterns) {
if (pattern.test(normalizedModel)) {
return providerId as ProviderId
}
}
}
}
if (normalizedModel.startsWith('claude')) {
return 'anthropic'
}
if (normalizedModel.startsWith('gemini')) {
return 'google'
}
if (normalizedModel.startsWith('grok')) {
return 'xai'
}
if (normalizedModel.startsWith('llama')) {
return 'cerebras'
}
// Default to deepseek for any other models
return 'deepseek'
}
export function getProvider(id: string): ProviderConfig | undefined {
// Handle both formats: 'openai' and 'openai/chat'
const providerId = id.split('/')[0] as ProviderId
return providers[providerId]
}
export function getProviderConfigFromModel(model: string): ProviderConfig | undefined {
const providerId = getProviderFromModel(model)
return providers[providerId]
}
export function getAllModels(): string[] {
return Object.values(providers).flatMap((provider) => provider.models || [])
}
export function getAllProviderIds(): ProviderId[] {
return Object.keys(providers) as ProviderId[]
}
export function getProviderModels(providerId: ProviderId): string[] {
const provider = providers[providerId]
return provider?.models || []
}
export function generateStructuredOutputInstructions(responseFormat: any): string {
if (!responseFormat?.fields) return ''
function generateFieldStructure(field: any): string {
if (field.type === 'object' && field.properties) {
return `{
${Object.entries(field.properties)
.map(([key, prop]: [string, any]) => `"${key}": ${prop.type === 'number' ? '0' : '"value"'}`)
.join(',\n ')}
}`
}
return field.type === 'string'
? '"value"'
: field.type === 'number'
? '0'
: field.type === 'boolean'
? 'true/false'
: '[]'
}
const exampleFormat = responseFormat.fields
.map((field: any) => ` "${field.name}": ${generateFieldStructure(field)}`)
.join(',\n')
const fieldDescriptions = responseFormat.fields
.map((field: any) => {
let desc = `${field.name} (${field.type})`
if (field.description) desc += `: ${field.description}`
if (field.type === 'object' && field.properties) {
desc += '\nProperties:'
Object.entries(field.properties).forEach(([key, prop]: [string, any]) => {
desc += `\n - ${key} (${(prop as any).type}): ${(prop as any).description || ''}`
})
}
return desc
})
.join('\n')
return `
Please provide your response in the following JSON format:
{
${exampleFormat}
}
Field descriptions:
${fieldDescriptions}
Your response MUST be valid JSON and include all the specified fields with their correct types.
Each metric should be an object containing 'score' (number) and 'reasoning' (string).`
}
export function extractAndParseJSON(content: string): any {
// First clean up the string
const trimmed = content.trim()

View File

@@ -1,4 +1,5 @@
import { FunctionCallResponse, ProviderConfig, ProviderRequest, ProviderToolConfig } from '../types'
import OpenAI from 'openai'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'
export const xAIProvider: ProviderConfig = {
id: 'xai',
@@ -8,210 +9,210 @@ export const xAIProvider: ProviderConfig = {
models: ['grok-2-latest'],
defaultModel: 'grok-2-latest',
baseUrl: 'https://api.x.ai/v1/chat/completions',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
}),
transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
return undefined
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
if (!request.apiKey) {
throw new Error('API key is required for xAI')
}
return tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
},
transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
// xAI returns tool_calls array like OpenAI
const toolCall = response.choices?.[0]?.message?.tool_calls?.[0]
if (!toolCall || !toolCall.function) {
throw new Error('No valid tool call found in response')
}
const tool = tools?.find((t) => t.id === toolCall.function.name)
if (!tool) {
throw new Error(`Tool not found: ${toolCall.function.name}`)
}
let args = toolCall.function.arguments
if (typeof args === 'string') {
try {
args = JSON.parse(args)
} catch (e) {
console.error('Failed to parse tool arguments:', e)
args = {}
}
}
return {
name: toolCall.function.name,
arguments: {
...tool.params,
...args,
},
}
},
transformRequest: (request: ProviderRequest, functions?: any) => {
// Convert function messages to tool messages
const messages = (request.messages || []).map((msg) => {
if (msg.role === 'function') {
return {
role: 'tool',
content: msg.content,
tool_call_id: msg.name, // xAI expects tool_call_id for tool results
}
}
if (msg.function_call) {
return {
role: 'assistant',
content: null,
tool_calls: [
{
id: msg.function_call.name,
type: 'function',
function: {
name: msg.function_call.name,
arguments: msg.function_call.arguments,
},
},
],
}
}
return msg
const xai = new OpenAI({
apiKey: request.apiKey,
baseURL: 'https://api.x.ai/v1',
dangerouslyAllowBrowser: true,
})
// Add response format for structured output if specified
let systemPrompt = request.systemPrompt
if (request.responseFormat) {
systemPrompt += `\n\nYou MUST respond with a valid JSON object. DO NOT include any other text, explanations, or markdown formatting in your response - ONLY the JSON object.\n\nThe response MUST match this schema:\n${JSON.stringify(
{
type: 'object',
properties: request.responseFormat.fields.reduce(
(acc, field) => ({
...acc,
[field.name]: {
type:
field.type === 'array'
? 'array'
: field.type === 'object'
? 'object'
: field.type,
description: field.description,
},
}),
{}
),
required: request.responseFormat.fields.map((f) => f.name),
},
null,
2
)}\n\nExample response format:\n{\n${request.responseFormat.fields
.map(
(f) =>
` "${f.name}": ${
f.type === 'string'
? '"value"'
: f.type === 'number'
? '0'
: f.type === 'boolean'
? 'true'
: f.type === 'array'
? '[]'
: '{}'
}`
)
.join(',\n')}\n}`
const allMessages = []
if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}
const payload = {
model: request.model || 'grok-2-latest',
messages: [
{ role: 'system', content: systemPrompt },
...(request.context ? [{ role: 'user', content: request.context }] : []),
...messages,
],
temperature: request.temperature || 0.7,
max_tokens: request.maxTokens || 1024,
...(functions && {
tools: functions,
tool_choice: 'auto', // xAI specific parameter
}),
...(request.responseFormat && {
response_format: {
type: 'json_schema',
json_schema: {
name: 'structured_response',
schema: {
type: 'object',
properties: request.responseFormat.fields.reduce(
(acc, field) => ({
...acc,
[field.name]: {
type:
field.type === 'array'
? 'array'
: field.type === 'object'
? 'object'
: field.type === 'number'
? 'number'
: field.type === 'boolean'
? 'boolean'
: 'string',
description: field.description || '',
...(field.type === 'array' && {
items: { type: 'string' },
}),
...(field.type === 'object' && {
additionalProperties: true,
}),
},
}),
{}
),
required: request.responseFormat.fields.map((f) => f.name),
additionalProperties: false,
},
strict: true,
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}
if (request.messages) {
allMessages.push(...request.messages)
}
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
},
}),
}))
: undefined
const payload: any = {
model: request.model || 'grok-2-latest',
messages: allMessages,
}
return payload
},
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
transformResponse: (response: any) => {
if (!response) {
console.warn('Received undefined response from xAI API')
return { content: '' }
if (request.responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: 'structured_response',
schema: {
type: 'object',
properties: request.responseFormat.fields.reduce(
(acc, field) => ({
...acc,
[field.name]: {
type:
field.type === 'array'
? 'array'
: field.type === 'object'
? 'object'
: field.type === 'number'
? 'number'
: field.type === 'boolean'
? 'boolean'
: 'string',
description: field.description || '',
...(field.type === 'array' && {
items: { type: 'string' },
}),
...(field.type === 'object' && {
additionalProperties: true,
}),
},
}),
{}
),
required: request.responseFormat.fields.map((f) => f.name),
additionalProperties: false,
},
strict: true,
},
}
if (allMessages.length > 0 && allMessages[0].role === 'system') {
allMessages[0].content = `${allMessages[0].content}\n\nYou MUST respond with a valid JSON object. DO NOT include any other text, explanations, or markdown formatting in your response - ONLY the JSON object.`
} else {
allMessages.unshift({
role: 'system',
content: `You MUST respond with a valid JSON object. DO NOT include any other text, explanations, or markdown formatting in your response - ONLY the JSON object.`,
})
}
}
if (tools?.length) {
payload.tools = tools
payload.tool_choice = 'auto'
}
let currentResponse = await xai.chat.completions.create(payload)
let content = currentResponse.choices[0]?.message?.content || ''
let tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
total: currentResponse.usage?.total_tokens || 0,
}
let toolCalls = []
let toolResults = []
let currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10
try {
while (iterationCount < MAX_ITERATIONS) {
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
const { executeTool } = await import('@/tools')
const mergedArgs = { ...tool.params, ...toolArgs }
console.log(`Merged tool args for ${toolName}:`, {
toolParams: tool.params,
llmArgs: toolArgs,
mergedArgs,
})
const result = await executeTool(toolName, mergedArgs, true)
if (!result.success) continue
toolResults.push(result.output)
toolCalls.push({
name: toolName,
arguments: toolArgs,
})
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(result.output),
})
} catch (error) {
console.error('Error processing tool call:', error)
}
}
const nextPayload = {
...payload,
messages: currentMessages,
}
currentResponse = await xai.chat.completions.create(nextPayload)
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
tokens.total += currentResponse.usage.total_tokens || 0
}
iterationCount++
}
} catch (error) {
console.error('Error in xAI request:', error)
throw error
}
return {
content: response.choices?.[0]?.message?.content || '',
tokens: response.usage && {
prompt: response.usage.prompt_tokens,
completion: response.usage.completion_tokens,
total: response.usage.total_tokens,
},
content,
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},
hasFunctionCall: (response: any) => {
if (!response) return false
return !!response.choices?.[0]?.message?.tool_calls?.[0]
},
}

View File

@@ -24,6 +24,7 @@ import { slackMessageTool } from './slack/message'
import { extractTool as tavilyExtract } from './tavily/extract'
import { searchTool as tavilySearch } from './tavily/search'
import { ToolConfig, ToolResponse } from './types'
import { executeRequest, formatRequestParams, validateToolRequest } from './utils'
import { readTool as xRead } from './x/read'
import { searchTool as xSearch } from './x/search'
import { userTool as xUser } from './x/user'
@@ -72,54 +73,41 @@ export function getTool(toolId: string): ToolConfig | undefined {
// Execute a tool by calling either the proxy for external APIs or directly for internal routes
export async function executeTool(
toolId: string,
params: Record<string, any>
params: Record<string, any>,
skipProxy = false
): Promise<ToolResponse> {
try {
const tool = getTool(toolId)
console.log(`Tool being called: ${toolId}`, {
params: { ...params, apiKey: params.apiKey ? '[REDACTED]' : undefined },
skipProxy,
})
// Validate the tool and its parameters
validateToolRequest(toolId, tool, params)
// After validation, we know tool exists
if (!tool) {
throw new Error(`Tool not found: ${toolId}`)
}
// For internal routes, call the API directly
if (tool.request.isInternalRoute) {
const url =
typeof tool.request.url === 'function' ? tool.request.url(params) : tool.request.url
const response = await fetch(url, {
method: tool.request.method,
headers: tool.request.headers(params),
body: JSON.stringify(tool.request.body ? tool.request.body(params) : params),
// For internal routes or when skipProxy is true, call the API directly
if (tool.request.isInternalRoute || skipProxy) {
console.log(`Calling internal request for ${toolId}`)
const result = await handleInternalRequest(toolId, tool, params)
console.log(`Tool ${toolId} execution result:`, {
success: result.success,
outputKeys: result.success ? Object.keys(result.output) : [],
error: result.error,
})
const result = await tool.transformResponse(response)
return result
}
// For external APIs, use the proxy
const baseUrl = process.env.NEXT_PUBLIC_APP_URL
if (!baseUrl) {
throw new Error('NEXT_PUBLIC_APP_URL environment variable is not set')
}
const proxyUrl = new URL('/api/proxy', baseUrl).toString()
const response = await fetch(proxyUrl, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ toolId, params }),
})
const result = await response.json()
if (!result.success) {
return {
success: false,
output: {},
error: result.error,
}
}
return result
console.log(`Calling proxy request for ${toolId}`)
return await handleProxyRequest(toolId, params)
} catch (error: any) {
console.error(`Error executing tool ${toolId}:`, error)
return {
success: false,
output: {},
@@ -127,3 +115,56 @@ export async function executeTool(
}
}
}
/**
* Handle an internal/direct tool request
*/
async function handleInternalRequest(
toolId: string,
tool: ToolConfig,
params: Record<string, any>
): Promise<ToolResponse> {
// Log the request for debugging
console.log(`Executing tool ${toolId} with params:`, {
toolId,
params: { ...params, apiKey: params.apiKey ? '[REDACTED]' : undefined },
})
// Format the request parameters
const requestParams = formatRequestParams(tool, params)
// Execute the request
return await executeRequest(toolId, tool, requestParams)
}
/**
* Handle a request via the proxy
*/
async function handleProxyRequest(
toolId: string,
params: Record<string, any>
): Promise<ToolResponse> {
const baseUrl = process.env.NEXT_PUBLIC_APP_URL
if (!baseUrl) {
throw new Error('NEXT_PUBLIC_APP_URL environment variable is not set')
}
const proxyUrl = new URL('/api/proxy', baseUrl).toString()
const response = await fetch(proxyUrl, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ toolId, params }),
})
const result = await response.json()
if (!result.success) {
return {
success: false,
output: {},
error: result.error,
}
}
return result
}

View File

@@ -1,4 +1,5 @@
import { TableRow } from './types'
import { ToolConfig, ToolResponse } from './types'
/**
* Transforms a table from the store format to a key-value object
@@ -18,3 +19,111 @@ export const transformTable = (table: TableRow[] | null): Record<string, string>
{} as Record<string, string>
)
}
interface RequestParams {
url: string
method: string
headers: Record<string, string>
body?: string
}
/**
* Format request parameters based on tool configuration and provided params
*/
export function formatRequestParams(tool: ToolConfig, params: Record<string, any>): RequestParams {
// Process URL
const url = typeof tool.request.url === 'function' ? tool.request.url(params) : tool.request.url
// Process method
const method = params.method || tool.request.method || 'GET'
// Process headers
const headers = tool.request.headers ? tool.request.headers(params) : {}
// Process body
const hasBody = method !== 'GET' && method !== 'HEAD' && !!tool.request.body
const bodyResult = tool.request.body ? tool.request.body(params) : undefined
// Special handling for NDJSON content type
const isNDJSON = headers['Content-Type'] === 'application/x-ndjson'
const body = hasBody
? isNDJSON && bodyResult
? bodyResult.body
: JSON.stringify(bodyResult)
: undefined
return { url, method, headers, body }
}
/**
* Execute the actual request and transform the response
*/
export async function executeRequest(
toolId: string,
tool: ToolConfig,
requestParams: RequestParams
): Promise<ToolResponse> {
try {
const { url, method, headers, body } = requestParams
// Log the request for debugging
console.log(`Executing tool ${toolId}:`, { url, method })
const externalResponse = await fetch(url, { method, headers, body })
// Log response status
console.log(`${toolId} response status:`, externalResponse.status, externalResponse.statusText)
if (!externalResponse.ok) {
let errorContent
try {
errorContent = await externalResponse.json()
} catch (e) {
errorContent = { message: externalResponse.statusText }
}
// Use the tool's error transformer or a default message
const error = tool.transformError
? tool.transformError(errorContent)
: errorContent.message || `${toolId} API error: ${externalResponse.statusText}`
console.error(`${toolId} error:`, error)
throw new Error(error)
}
const transformResponse =
tool.transformResponse ||
(async (resp: Response) => ({
success: true,
output: await resp.json(),
}))
return await transformResponse(externalResponse)
} catch (error: any) {
return {
success: false,
output: {},
error: error.message || 'Unknown error',
}
}
}
/**
* Validates the tool and its parameters
*/
export function validateToolRequest(
toolId: string,
tool: ToolConfig | undefined,
params: Record<string, any>
): void {
if (!tool) {
throw new Error(`Tool not found: ${toolId}`)
}
// Ensure all required parameters for tool call are provided
for (const [paramName, paramConfig] of Object.entries(tool.params)) {
if (paramConfig.requiredForToolCall && !(paramName in params)) {
throw new Error(`Parameter "${paramName}" is required for ${toolId} but was not provided`)
}
}
}