mirror of https://github.com/simstudioai/sim.git (synced 2026-01-09 15:07:55 -05:00)
feat(providers): modified all providers to use SDK rather than creating an HTTP request
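The heart of the change: each provider's declarative HTTP description (baseUrl, headers, transformRequest, transformResponse) is replaced by a single executeRequest method that drives the vendor SDK directly. A condensed sketch of the pattern, with field lists abridged from the hunks below (exampleProvider is illustrative, not a file in this commit):

import OpenAI from 'openai'
import { ProviderConfig, ProviderRequest, ProviderResponse } from './types'

export const exampleProvider: ProviderConfig = {
  id: 'example',
  name: 'Example',
  description: 'Illustrative provider, not part of this commit',
  version: '1.0.0',
  models: ['example-model'],
  defaultModel: 'example-model',

  // One method now owns the whole round trip: build messages, call the SDK,
  // run any requested tools, and normalize the result.
  executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
    if (!request.apiKey) {
      throw new Error('API key is required')
    }

    const client = new OpenAI({ apiKey: request.apiKey, dangerouslyAllowBrowser: true })
    const response = await client.chat.completions.create({
      model: request.model || 'example-model',
      messages: [{ role: 'user', content: request.systemPrompt || '' }],
    })

    return {
      content: response.choices[0]?.message?.content || '',
      model: request.model,
      tokens: {
        prompt: response.usage?.prompt_tokens || 0,
        completion: response.usage?.completion_tokens || 0,
        total: response.usage?.total_tokens || 0,
      },
    }
  },
}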
@@ -1,118 +1,23 @@
import { NextResponse } from 'next/server'
import { MODEL_PROVIDERS } from '@/providers/consts'
import { getProvider } from '@/providers/registry'
import { getTool } from '@/tools'
import { executeTool, getTool } from '@/tools'
import { validateToolRequest } from '@/tools/utils'

export async function POST(request: Request) {
try {
const { toolId, params } = await request.json()

// Check if this is a provider chat request
const provider = getProvider(toolId)
if (provider) {
const { apiKey, ...restParams } = params
if (!apiKey) {
throw new Error('API key is required')
}

const response = await fetch(provider.baseUrl, {
method: 'POST',
headers: provider.headers(apiKey),
body: JSON.stringify(restParams),
})

if (!response.ok) {
const error = await response.json()
throw new Error(error.error?.message || `${toolId} API error`)
}

return NextResponse.json({
success: true,
output: await response.json(),
})
}

// Check if this is an LLM provider tool (e.g., openai_chat, anthropic_chat)
const providerPrefix = toolId.split('_')[0]
if (Object.values(MODEL_PROVIDERS).includes(providerPrefix)) {
// Redirect to the provider system
const providerInstance = getProvider(providerPrefix)
if (!providerInstance) {
throw new Error(`Provider not found for tool: ${toolId}`)
}

const { apiKey, ...restParams } = params
if (!apiKey) {
throw new Error('API key is required')
}

const response = await fetch(providerInstance.baseUrl, {
method: 'POST',
headers: providerInstance.headers(apiKey),
body: JSON.stringify(restParams),
})

if (!response.ok) {
const error = await response.json()
throw new Error(error.error?.message || `${toolId} API error`)
}

return NextResponse.json({
success: true,
output: await response.json(),
})
}

// Handle regular tool requests
const tool = getTool(toolId)
if (!tool) {
throw new Error(`Tool not found: ${toolId}`)
}

if (tool.params?.apiKey?.required && !params.apiKey) {
throw new Error(`API key is required for ${toolId}`)
}

const { url: urlOrFn, method: defaultMethod, headers: headersFn, body: bodyFn } = tool.request
// Validate the tool and its parameters
validateToolRequest(toolId, tool, params)

try {
const url = typeof urlOrFn === 'function' ? urlOrFn(params) : urlOrFn
const method = params.method || defaultMethod || 'GET'
const headers = headersFn ? headersFn(params) : {}
const hasBody = method !== 'GET' && method !== 'HEAD' && !!bodyFn

const bodyResult = bodyFn ? bodyFn(params) : undefined

// Special handling for NDJSON content type
const isNDJSON = headers['Content-Type'] === 'application/x-ndjson'
const body = hasBody
? isNDJSON && bodyResult
? bodyResult.body
: JSON.stringify(bodyResult)
: undefined

const externalResponse = await fetch(url, { method, headers, body })

if (!externalResponse.ok) {
const errorContent = await externalResponse.json().catch(() => ({
message: externalResponse.statusText,
}))

// Use the tool's error transformer or a default message
const error = tool.transformError
? tool.transformError(errorContent)
: errorContent.message || `${toolId} API error: ${externalResponse.statusText}`

throw new Error(error)
if (!tool) {
throw new Error(`Tool not found: ${toolId}`)
}

const transformResponse =
tool.transformResponse ||
(async (resp: Response) => ({
success: true,
output: await resp.json(),
}))
const result = await transformResponse(externalResponse)
// Use executeTool with skipProxy=true to prevent recursive proxy calls
const result = await executeTool(toolId, params, true)

if (!result.success) {
throw new Error(
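The third argument to executeTool is new in this commit. Its name is not visible in this diff, only the comment above; judging from the call sites it behaves like a skipProxy flag:

// Client side (unchanged): the call is routed through this proxy endpoint.
const viaProxy = await executeTool(toolId, params)

// Inside the proxy route and the providers: pass true so the tool executes
// directly and the route never calls itself recursively.
const direct = await executeTool(toolId, params, true)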
@@ -1,5 +1,5 @@
import { AgentIcon } from '@/components/icons'
import { MODEL_PROVIDERS } from '@/providers/consts'
import { MODEL_PROVIDERS } from '@/providers/utils'
import { ToolResponse } from '@/tools/types'
import { BlockConfig } from '../types'

@@ -1,6 +1,6 @@
import { ChartBarIcon } from '@/components/icons'
import { MODEL_PROVIDERS } from '@/providers/consts'
import { ProviderId } from '@/providers/registry'
import { ProviderId } from '@/providers/types'
import { MODEL_PROVIDERS } from '@/providers/utils'
import { ToolResponse } from '@/tools/types'
import { BlockConfig, ParamType } from '../types'

@@ -1,6 +1,6 @@
import { ConnectIcon } from '@/components/icons'
import { MODEL_PROVIDERS } from '@/providers/consts'
import { ProviderId } from '@/providers/registry'
import { ProviderId } from '@/providers/types'
import { MODEL_PROVIDERS } from '@/providers/utils'
import { ToolResponse } from '@/tools/types'
import { BlockConfig } from '../types'

@@ -1,6 +1,6 @@
import { TranslateIcon } from '@/components/icons'
import { MODEL_PROVIDERS } from '@/providers/consts'
import { ProviderId } from '@/providers/registry'
import { ProviderId } from '@/providers/types'
import { MODEL_PROVIDERS } from '@/providers/utils'
import { BlockConfig } from '../types'

const getTranslationPrompt = (

@@ -1,14 +0,0 @@
// export const MODEL_TOOLS = {
// 'gpt-4o': 'openai_chat',
// o1: 'openai_chat',
// 'o3-mini': 'openai_chat',
// 'deepseek-v3': 'deepseek_chat',
// 'deepseek-r1': 'deepseek_reasoner',
// 'claude-3-7-sonnet-20250219': 'anthropic_chat',
// 'gemini-2.0-flash': 'google_chat',
// 'grok-2-latest': 'xai_chat',
// 'llama-3.3-70b': 'cerebras_chat',
// } as const

// export type ModelType = keyof typeof MODEL_TOOLS
// export type ToolType = (typeof MODEL_TOOLS)[ModelType]

@@ -1,4 +1,4 @@
import { BlockState, SubBlockState } from '@/stores/workflow/types'
import { SubBlockState } from '@/stores/workflow/types'
import { BlockOutput, OutputConfig } from '@/blocks/types'

interface CodeLine {
@@ -1,7 +1,7 @@
import { getAllBlocks } from '@/blocks'
import { generateRouterPrompt } from '@/blocks/blocks/router'
import { BlockOutput } from '@/blocks/types'
import { executeProviderRequest } from '@/providers/service'
import { executeProviderRequest } from '@/providers'
import { getProviderFromModel } from '@/providers/utils'
import { SerializedBlock } from '@/serializer/types'
import { executeTool, getTool } from '@/tools'
@@ -103,9 +103,6 @@ export class AgentBlockHandler implements BlockHandler {
.filter((t): t is NonNullable<typeof t> => t !== null)
: []

// Add local_execution: true for Cerebras provider
const additionalParams = providerId === 'cerebras' ? { local_execution: true } : {}

const response = await executeProviderRequest(providerId, {
model,
systemPrompt: inputs.systemPrompt,
@@ -117,7 +114,6 @@ export class AgentBlockHandler implements BlockHandler {
maxTokens: inputs.maxTokens,
apiKey: inputs.apiKey,
responseFormat,
...additionalParams,
})

// Return structured or standard response based on responseFormat
@@ -447,7 +443,7 @@ export class ApiBlockHandler implements BlockHandler {
throw new Error(`Tool not found: ${block.config.tool}`)
}

const result = await executeTool(block.config.tool, inputs)
const result = await executeTool(block.config.tool, inputs, true)
if (!result.success) {
throw new Error(result.error || `API request failed with no error message`)
}
@@ -474,7 +470,7 @@ export class FunctionBlockHandler implements BlockHandler {
throw new Error(`Tool not found: ${block.config.tool}`)
}

const result = await executeTool(block.config.tool, inputs)
const result = await executeTool(block.config.tool, inputs, true)
if (!result.success) {
throw new Error(result.error || `Function execution failed with no error message`)
}
@@ -502,7 +498,7 @@ export class GenericBlockHandler implements BlockHandler {
throw new Error(`Tool not found: ${block.config.tool}`)
}

const result = await executeTool(block.config.tool, inputs)
const result = await executeTool(block.config.tool, inputs, true)
if (!result.success) {
throw new Error(result.error || `Block execution failed with no error message`)
}
package-lock.json (generated, 62 lines changed)
@@ -8,7 +8,7 @@
"name": "sim",
"version": "0.1.0",
"dependencies": {
"@cerebras/cerebras_cloud_sdk": "^1.23.0",
"@anthropic-ai/sdk": "^0.38.0",
"@radix-ui/react-alert-dialog": "^1.1.5",
"@radix-ui/react-checkbox": "^1.1.3",
"@radix-ui/react-dialog": "^1.1.5",
@@ -105,6 +105,36 @@
"node": ">=6.0.0"
}
},
"node_modules/@anthropic-ai/sdk": {
"version": "0.38.0",
"resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.38.0.tgz",
"integrity": "sha512-ZUYjadEEeb1wwKd9MEM+plSfl7zQjYixhjHRtyPsjO7MtzRmbZvBb1n1ofo2kHwb0aUiIdb3aAEbAwS9Bcbm/A==",
"license": "MIT",
"dependencies": {
"@types/node": "^18.11.18",
"@types/node-fetch": "^2.6.4",
"abort-controller": "^3.0.0",
"agentkeepalive": "^4.2.1",
"form-data-encoder": "1.7.2",
"formdata-node": "^4.3.2",
"node-fetch": "^2.6.7"
}
},
"node_modules/@anthropic-ai/sdk/node_modules/@types/node": {
"version": "18.19.76",
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.76.tgz",
"integrity": "sha512-yvR7Q9LdPz2vGpmpJX5LolrgRdWvB67MJKDPSgIIzpFbaf9a1j/f5DnLp5VDyHGMR0QZHlTr1afsD87QCXFHKw==",
"license": "MIT",
"dependencies": {
"undici-types": "~5.26.4"
}
},
"node_modules/@anthropic-ai/sdk/node_modules/undici-types": {
"version": "5.26.5",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
"license": "MIT"
},
"node_modules/@babel/code-frame": {
"version": "7.26.2",
"resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz",
@@ -647,36 +677,6 @@
"resolved": "https://registry.npmjs.org/@better-fetch/fetch/-/fetch-1.1.12.tgz",
"integrity": "sha512-B3bfloI/2UBQWIATRN6qmlORrvx3Mp0kkNjmXLv0b+DtbtR+pP4/I5kQA/rDUv+OReLywCCldf6co4LdDmh8JA=="
},
"node_modules/@cerebras/cerebras_cloud_sdk": {
"version": "1.23.0",
"resolved": "https://registry.npmjs.org/@cerebras/cerebras_cloud_sdk/-/cerebras_cloud_sdk-1.23.0.tgz",
"integrity": "sha512-1krbmU4nTbJICUbcJGQGGo+MtB0nzHx/jwW24ZhoBzuC5QT8H/WzNjLdKtvdf3TB8GS1AtdWUkUHNJf1EZfvJA==",
"license": "Apache-2.0",
"dependencies": {
"@types/node": "^18.11.18",
"@types/node-fetch": "^2.6.4",
"abort-controller": "^3.0.0",
"agentkeepalive": "^4.2.1",
"form-data-encoder": "1.7.2",
"formdata-node": "^4.3.2",
"node-fetch": "^2.6.7"
}
},
"node_modules/@cerebras/cerebras_cloud_sdk/node_modules/@types/node": {
"version": "18.19.76",
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.76.tgz",
"integrity": "sha512-yvR7Q9LdPz2vGpmpJX5LolrgRdWvB67MJKDPSgIIzpFbaf9a1j/f5DnLp5VDyHGMR0QZHlTr1afsD87QCXFHKw==",
"license": "MIT",
"dependencies": {
"undici-types": "~5.26.4"
}
},
"node_modules/@cerebras/cerebras_cloud_sdk/node_modules/undici-types": {
"version": "5.26.5",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
"license": "MIT"
},
"node_modules/@drizzle-team/brocli": {
"version": "0.10.2",
"resolved": "https://registry.npmjs.org/@drizzle-team/brocli/-/brocli-0.10.2.tgz",
@@ -17,7 +17,7 @@
"test:coverage": "jest --coverage"
},
"dependencies": {
"@cerebras/cerebras_cloud_sdk": "^1.23.0",
"@anthropic-ai/sdk": "^0.38.0",
"@radix-ui/react-alert-dialog": "^1.1.5",
"@radix-ui/react-checkbox": "^1.1.3",
"@radix-ui/react-dialog": "^1.1.5",
@@ -1,10 +1,6 @@
import {
FunctionCallResponse,
Message,
ProviderConfig,
ProviderRequest,
ProviderToolConfig,
} from '../types'
import Anthropic from '@anthropic-ai/sdk'
import { executeTool } from '@/tools'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'

export const anthropicProvider: ProviderConfig = {
id: 'anthropic',
@@ -14,73 +10,40 @@ export const anthropicProvider: ProviderConfig = {
models: ['claude-3-7-sonnet-20250219'],
defaultModel: 'claude-3-7-sonnet-20250219',

baseUrl: 'https://api.anthropic.com/v1/messages',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
'x-api-key': apiKey,
'anthropic-version': '2023-06-01',
}),

transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
return undefined
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
if (!request.apiKey) {
throw new Error('API key is required for Anthropic')
}

return tools.map((tool) => ({
name: tool.id,
description: tool.description,
input_schema: {
type: 'object',
properties: tool.parameters.properties,
required: tool.parameters.required,
},
}))
},
const anthropic = new Anthropic({
apiKey: request.apiKey,
dangerouslyAllowBrowser: true,
})

transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
const rawResponse = response?.output || response
if (!rawResponse?.content) {
throw new Error('No content found in response')
// Helper function to generate a simple unique ID for tool uses
const generateToolUseId = (toolName: string) => {
return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}`
}

const toolUse = rawResponse.content.find((item: any) => item.type === 'tool_use')
if (!toolUse) {
throw new Error('No tool use found in response')
}

const tool = tools?.find((t) => t.id === toolUse.name)
if (!tool) {
throw new Error(`Tool not found: ${toolUse.name}`)
}

let input = toolUse.input
if (typeof input === 'string') {
try {
input = JSON.parse(input)
} catch (e) {
console.error('Failed to parse tool input:', e)
input = {}
}
}

return {
name: toolUse.name,
arguments: {
...tool.params,
...input,
},
}
},

transformRequest: (request: ProviderRequest, functions?: any) => {
// Transform messages to Anthropic format
const messages =
request.messages?.map((msg) => {
const messages = []

// Add system prompt if present
let systemPrompt = request.systemPrompt || ''

// Add context if present
if (request.context) {
messages.push({
role: 'user',
content: request.context,
})
}

// Add remaining messages
if (request.messages) {
request.messages.forEach((msg) => {
if (msg.role === 'function') {
return {
messages.push({
role: 'user',
content: [
{
@@ -89,58 +52,55 @@ export const anthropicProvider: ProviderConfig = {
content: msg.content,
},
],
}
}

if (msg.function_call) {
return {
})
} else if (msg.function_call) {
const toolUseId = msg.function_call.name + '-' + Date.now()
messages.push({
role: 'assistant',
content: [
{
type: 'tool_use',
id: msg.function_call.name,
id: toolUseId,
name: msg.function_call.name,
input: JSON.parse(msg.function_call.arguments),
},
],
}
})
} else {
messages.push({
role: msg.role === 'assistant' ? 'assistant' : 'user',
content: msg.content ? [{ type: 'text', text: msg.content }] : [],
})
}

return {
role: msg.role === 'assistant' ? 'assistant' : 'user',
content: msg.content ? [{ type: 'text', text: msg.content }] : [],
}
}) || []

// Add context if provided
if (request.context) {
messages.unshift({
role: 'user',
content: [{ type: 'text', text: request.context }],
})
}

// Ensure there's at least one message by adding the system prompt as a user message if no messages exist
// Ensure there's at least one message
if (messages.length === 0) {
messages.push({
role: 'user',
content: [{ type: 'text', text: request.systemPrompt || '' }],
content: [{ type: 'text', text: systemPrompt || 'Hello' }],
})
// Clear system prompt since we've used it as a user message
systemPrompt = ''
}

// Build the request payload
const payload = {
model: request.model || 'claude-3-7-sonnet-20250219',
messages,
system: request.systemPrompt || '',
max_tokens: parseInt(String(request.maxTokens)) || 1024,
temperature: parseFloat(String(request.temperature ?? 0.7)),
...(functions && { tools: functions }),
}
// Transform tools to Anthropic format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
name: tool.id,
description: tool.description,
input_schema: {
type: 'object',
properties: tool.parameters.properties,
required: tool.parameters.required,
},
}))
: undefined

// If response format is specified, add strict formatting instructions
if (request.responseFormat) {
payload.system = `${payload.system}\n\nIMPORTANT RESPONSE FORMAT INSTRUCTIONS:
systemPrompt = `${systemPrompt}\n\nIMPORTANT RESPONSE FORMAT INSTRUCTIONS:
1. Your response must be EXACTLY in this format, with no additional fields:
{
${request.responseFormat.fields.map((field) => `  "${field.name}": ${field.type === 'string' ? '"value"' : field.type === 'array' ? '[]' : field.type === 'object' ? '{}' : field.type === 'number' ? '0' : 'true/false'}`).join(',\n')}
@@ -155,67 +115,163 @@ ${request.responseFormat.fields.map((field) => `${field.name} (${field.type})${f
5. Your response MUST be valid JSON and include all the specified fields with their correct types`
}

return payload
},
// Build the request payload
const payload: any = {
model: request.model || 'claude-3-7-sonnet-20250219',
messages,
system: systemPrompt,
max_tokens: parseInt(String(request.maxTokens)) || 1024,
temperature: parseFloat(String(request.temperature ?? 0.7)),
}

// Add tools if provided
if (tools?.length) {
payload.tools = tools
}

// Make the initial API request
let currentResponse = await anthropic.messages.create(payload)
let content = ''

// Extract text content from the message
if (Array.isArray(currentResponse.content)) {
content = currentResponse.content
.filter((item) => item.type === 'text')
.map((item) => item.text)
.join('\n')
}

let tokens = {
prompt: currentResponse.usage?.input_tokens || 0,
completion: currentResponse.usage?.output_tokens || 0,
total:
(currentResponse.usage?.input_tokens || 0) + (currentResponse.usage?.output_tokens || 0),
}

let toolCalls = []
let toolResults = []
let currentMessages = [...messages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops

transformResponse: (response: any) => {
try {
if (!response) {
console.warn('Received undefined response from Anthropic API')
return { content: '' }
}
while (iterationCount < MAX_ITERATIONS) {
// Check for tool calls
const toolUses = currentResponse.content.filter((item) => item.type === 'tool_use')
if (!toolUses || toolUses.length === 0) {
break
}

// Get the actual response content
const rawResponse = response.output || response
// Process each tool call
for (const toolUse of toolUses) {
try {
const toolName = toolUse.name
const toolArgs = toolUse.input as Record<string, any>

// Extract text content from the message
let content = ''
const messageContent = rawResponse?.content || rawResponse?.message?.content
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue

if (Array.isArray(messageContent)) {
content = messageContent
// Execute the tool
const mergedArgs = { ...tool.params, ...toolArgs }
const result = await executeTool(toolName, mergedArgs, true)

if (!result.success) continue

toolResults.push(result.output)
toolCalls.push({
name: toolName,
arguments: toolArgs,
})

// Add the tool call and result to messages
const toolUseId = generateToolUseId(toolName)

currentMessages.push({
role: 'assistant',
content: [
{
type: 'tool_use',
id: toolUseId,
name: toolName,
input: toolArgs,
} as any,
],
})

currentMessages.push({
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: toolUseId,
content: JSON.stringify(result.output),
} as any,
],
})
} catch (error) {
console.error('Error processing tool call:', error)
}
}

// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}

// Make the next request
currentResponse = await anthropic.messages.create(nextPayload)

// Update content if we have a text response
const textContent = currentResponse.content
.filter((item) => item.type === 'text')
.map((item) => item.text)
.join('\n')
} else if (typeof messageContent === 'string') {
content = messageContent
}

// If the content looks like it contains JSON, extract just the JSON part
if (content.includes('{') && content.includes('}')) {
try {
const jsonMatch = content.match(/\{[\s\S]*\}/m)
if (jsonMatch) {
content = jsonMatch[0]
}
} catch (e) {
console.error('Error extracting JSON from response:', e)
if (textContent) {
content = textContent
}
}

return {
content,
model: rawResponse?.model || response?.model || 'claude-3-7-sonnet-20250219',
tokens: rawResponse?.usage && {
prompt: rawResponse.usage.input_tokens,
completion: rawResponse.usage.output_tokens,
total: rawResponse.usage.input_tokens + rawResponse.usage.output_tokens,
},
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.input_tokens || 0
tokens.completion += currentResponse.usage.output_tokens || 0
tokens.total +=
(currentResponse.usage.input_tokens || 0) + (currentResponse.usage.output_tokens || 0)
}

iterationCount++
}
} catch (error) {
console.error('Error in transformResponse:', error)
return { content: '' }
console.error('Error in Anthropic request:', error)
throw error
}
},

hasFunctionCall: (response: any) => {
try {
if (!response) return false
const rawResponse = response.output || response
return rawResponse?.content?.some((item: any) => item.type === 'tool_use') || false
} catch (error) {
console.error('Error in hasFunctionCall:', error)
return false
// If the content looks like it contains JSON, extract just the JSON part
if (content.includes('{') && content.includes('}')) {
try {
const jsonMatch = content.match(/\{[\s\S]*\}/m)
if (jsonMatch) {
content = jsonMatch[0]
}
} catch (e) {
console.error('Error extracting JSON from response:', e)
}
}

return {
content,
model: request.model || 'claude-3-7-sonnet-20250219',
tokens,
toolCalls:
toolCalls.length > 0
? toolCalls.map((tc) => ({
name: tc.name,
arguments: tc.arguments as Record<string, any>,
}))
: undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},
}
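Every SDK provider in this commit implements the same agentic loop; stripped of the error handling and token accounting shown above, the Anthropic version reduces to:

// Condensed from executeRequest above; messages and payload as built there.
let response = await anthropic.messages.create(payload)
for (let i = 0; i < MAX_ITERATIONS; i++) {
  const toolUses = response.content.filter((item: any) => item.type === 'tool_use')
  if (toolUses.length === 0) break

  for (const use of toolUses) {
    // Run the requested tool locally, then feed the result back to the model.
    const result = await executeTool(use.name, use.input, true)
    messages.push({ role: 'assistant', content: [use] })
    messages.push({
      role: 'user',
      content: [{ type: 'tool_result', tool_use_id: use.id, content: JSON.stringify(result.output) }],
    })
  }
  response = await anthropic.messages.create({ ...payload, messages })
}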
@@ -1,11 +1,6 @@
import { Cerebras } from '@cerebras/cerebras_cloud_sdk'
import {
FunctionCallResponse,
ProviderConfig,
ProviderRequest,
ProviderResponse,
ProviderToolConfig,
} from '../types'
import OpenAI from 'openai'
import { executeTool } from '@/tools'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'

export const cerebrasProvider: ProviderConfig = {
id: 'cerebras',
@@ -14,153 +9,18 @@ export const cerebrasProvider: ProviderConfig = {
version: '1.0.0',
models: ['llama-3.3-70b'],
defaultModel: 'llama-3.3-70b',
implementationType: 'sdk',

// SDK-based implementation
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
try {
const client = new Cerebras({
apiKey: request.apiKey,
})

// Start with an empty array for all messages
const allMessages = []

// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}

// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}

// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}

// Transform tools to Cerebras format if provided
const functions = request.tools?.length
? request.tools.map((tool) => ({
name: tool.id,
description: tool.description,
parameters: tool.parameters,
}))
: undefined

// Build the request payload
const payload: any = {
model: request.model || 'llama-3.3-70b',
messages: allMessages,
}

// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens

// Add functions if provided
if (functions?.length) {
payload.functions = functions
payload.function_call = 'auto'
}

// Add local execution flag if specified
if (request.local_execution) {
payload.local_execution = true
}

// Execute the request using the SDK
const response = (await client.chat.completions.create(payload)) as any

// Extract content and token info
const content = response.choices?.[0]?.message?.content || ''
const tokens = {
prompt: response.usage?.prompt_tokens || 0,
completion: response.usage?.completion_tokens || 0,
total: response.usage?.total_tokens || 0,
}

// Check for function calls
const functionCall = response.choices?.[0]?.message?.function_call
let toolCalls = undefined

if (functionCall) {
const tool = request.tools?.find((t) => t.id === functionCall.name)
const toolParams = tool?.params || {}

toolCalls = [
{
name: functionCall.name,
arguments: {
...toolParams,
...JSON.parse(functionCall.arguments),
},
},
]
}

// Return the response in the expected format
return {
content,
model: request.model,
tokens,
toolCalls,
}
} catch (error: any) {
console.error('Error executing Cerebras request:', error)
throw new Error(`Cerebras API error: ${error.message}`)
}
},

// These are still needed for backward compatibility
baseUrl: 'https://api.cerebras.cloud/v1/chat/completions',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
}),

transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
return undefined
if (!request.apiKey) {
throw new Error('API key is required for Cerebras')
}

return tools.map((tool) => ({
name: tool.id,
description: tool.description,
parameters: tool.parameters,
}))
},
const openai = new OpenAI({
apiKey: request.apiKey,
baseURL: 'https://api.cerebras.ai/v1',
dangerouslyAllowBrowser: true,
})

transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
const functionCall = response.choices?.[0]?.message?.function_call
if (!functionCall) {
throw new Error('No function call found in response')
}

const tool = tools?.find((t) => t.id === functionCall.name)
const toolParams = tool?.params || {}

return {
name: functionCall.name,
arguments: {
...toolParams,
...JSON.parse(functionCall.arguments),
},
}
},

transformRequest: (request: ProviderRequest, functions?: any) => {
// Start with an empty array for all messages
const allMessages = []

@@ -185,6 +45,18 @@ export const cerebrasProvider: ProviderConfig = {
allMessages.push(...request.messages)
}

// Transform tools to OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
: undefined

// Build the request payload
const payload: any = {
model: request.model || 'llama-3.3-70b',
@@ -195,38 +67,128 @@ export const cerebrasProvider: ProviderConfig = {
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens

// Add function calling support
if (functions) {
payload.functions = functions
payload.function_call = 'auto'
// Add response format for structured output if specified
if (request.responseFormat) {
payload.response_format = { type: 'json_object' }
}

// Add local execution flag if specified
// Add tools if provided
if (tools?.length) {
payload.tools = tools
payload.tool_choice = 'auto'
}

// Add local execution flag if specified by Cerebras
if (request.local_execution) {
payload.local_execution = true
}

return payload
},

transformResponse: (response: any) => {
const output = {
content: response.choices?.[0]?.message?.content || '',
tokens: undefined as any,
// Make the initial API request
let currentResponse = await openai.chat.completions.create(payload)
let content = currentResponse.choices[0]?.message?.content || ''
let tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
total: currentResponse.usage?.total_tokens || 0,
}
let toolCalls = []
let toolResults = []
let currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops

if (response.usage) {
output.tokens = {
prompt: response.usage.prompt_tokens,
completion: response.usage.completion_tokens,
total: response.usage.total_tokens,
try {
while (iterationCount < MAX_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}

// Process each tool call
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)

// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue

// Execute the tool
const mergedArgs = { ...tool.params, ...toolArgs }
const result = await executeTool(toolName, mergedArgs, true)

if (!result.success) continue

toolResults.push(result.output)
toolCalls.push({
name: toolName,
arguments: toolArgs,
})

// Add the tool call and result to messages
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})

const toolResultContent = JSON.stringify(result.output)

currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: toolResultContent,
})
} catch (error) {
console.error('Error processing tool call:', error)
}
}

// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}

// Make the next request
currentResponse = await openai.chat.completions.create(nextPayload)

// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}

// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
tokens.total += currentResponse.usage.total_tokens || 0
}

iterationCount++
}
} catch (error) {
console.error('Error in Cerebras request:', error)
throw error
}

return output
},

hasFunctionCall: (response: any) => {
return !!response.choices?.[0]?.message?.function_call
return {
content,
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},
}
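Note the recurring trick: the Cerebras SDK import is dropped in favor of the OpenAI client pointed at an OpenAI-compatible endpoint, and the Deepseek and Google providers below do the same with different base URLs (all three URLs as they appear in this diff):

const cerebras = new OpenAI({ apiKey, baseURL: 'https://api.cerebras.ai/v1', dangerouslyAllowBrowser: true })
const deepseek = new OpenAI({ apiKey, baseURL: 'https://api.deepseek.com/v1', dangerouslyAllowBrowser: true })
const gemini = new OpenAI({
  apiKey,
  baseURL: 'https://generativelanguage.googleapis.com/v1beta/openai/',
  dangerouslyAllowBrowser: true,
})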
@@ -1,17 +0,0 @@
import { ProviderId } from './registry'

/**
* Direct mapping from model names to provider IDs
* This replaces the need for the MODEL_TOOLS mapping in blocks/consts.ts
*/
export const MODEL_PROVIDERS: Record<string, ProviderId> = {
'gpt-4o': 'openai',
o1: 'openai',
'o3-mini': 'openai',
'claude-3-7-sonnet-20250219': 'anthropic',
'gemini-2.0-flash': 'google',
'grok-2-latest': 'xai',
'deepseek-v3': 'deepseek',
'deepseek-r1': 'deepseek',
'llama-3.3-70b': 'cerebras',
}
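providers/consts.ts is deleted outright; per the import changes above, MODEL_PROVIDERS now lives in @/providers/utils alongside a lookup helper. The helper's body is not shown in this diff, so this is only a plausible reconstruction inferred from its call sites:

// Hypothetical sketch; only the name getProviderFromModel is attested above.
export function getProviderFromModel(model: string): ProviderId {
  const providerId = MODEL_PROVIDERS[model]
  if (!providerId) {
    throw new Error(`No provider found for model: ${model}`)
  }
  return providerId
}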
@@ -1,4 +1,6 @@
import { FunctionCallResponse, ProviderConfig, ProviderRequest, ProviderToolConfig } from '../types'
import OpenAI from 'openai'
import { executeTool } from '@/tools'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'

export const deepseekProvider: ProviderConfig = {
id: 'deepseek',
@@ -8,116 +10,74 @@ export const deepseekProvider: ProviderConfig = {
models: ['deepseek-chat'],
defaultModel: 'deepseek-chat',

baseUrl: 'https://api.deepseek.com/v1/chat/completions',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
}),

transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
return undefined
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
if (!request.apiKey) {
throw new Error('API key is required for Deepseek')
}

return tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
},

transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
const toolCall = response.choices?.[0]?.message?.tool_calls?.[0]
if (!toolCall || !toolCall.function) {
throw new Error('No valid tool call found in response')
}

const tool = tools?.find((t) => t.id === toolCall.function.name)
if (!tool) {
throw new Error(`Tool not found: ${toolCall.function.name}`)
}

let args = toolCall.function.arguments
if (typeof args === 'string') {
try {
args = JSON.parse(args)
} catch (e) {
console.error('Failed to parse tool arguments:', e)
args = {}
}
}

return {
name: toolCall.function.name,
arguments: {
...tool.params,
...args,
},
}
},

transformRequest: (request: ProviderRequest, functions?: any) => {
// Transform messages from internal format to Deepseek format
const messages = (request.messages || []).map((msg) => {
if (msg.role === 'function') {
return {
role: 'tool',
content: msg.content,
tool_call_id: msg.name,
}
}

if (msg.function_call) {
return {
role: 'assistant',
content: null,
tool_calls: [
{
id: msg.function_call.name,
type: 'function',
function: {
name: msg.function_call.name,
arguments: msg.function_call.arguments,
},
},
],
}
}

return msg
// Deepseek uses the OpenAI SDK with a custom baseURL
const deepseek = new OpenAI({
apiKey: request.apiKey,
baseURL: 'https://api.deepseek.com/v1',
dangerouslyAllowBrowser: true,
})

const payload = {
model: 'deepseek-chat',
messages: [
{ role: 'system', content: request.systemPrompt },
...(request.context ? [{ role: 'user', content: request.context }] : []),
...messages,
],
temperature: request.temperature || 0.7,
max_tokens: request.maxTokens || 1024,
...(functions && { tools: functions }),
// Start with an empty array for all messages
const allMessages = []

// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}

return payload
},

transformResponse: (response: any) => {
if (!response) {
console.warn('Received undefined response from Deepseek API')
return { content: '' }
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}

const output = response.choices?.[0]?.message
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}

// Try to clean up the response content if it exists
let content = output?.content || ''
// Transform tools to OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
: undefined

const payload: any = {
model: 'deepseek-chat', // Hardcode to deepseek-chat regardless of what's selected in the UI
messages: allMessages,
}

// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens

// Add tools if provided
if (tools?.length) {
payload.tools = tools
payload.tool_choice = 'auto'
}

// Make the initial API request
let currentResponse = await deepseek.chat.completions.create(payload)
let content = currentResponse.choices[0]?.message?.content || ''

// Clean up the response content if it exists
if (content) {
// Remove any markdown code block markers
content = content.replace(/```json\n?|\n?```/g, '')
@@ -125,18 +85,110 @@ export const deepseekProvider: ProviderConfig = {
content = content.trim()
}

let tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
total: currentResponse.usage?.total_tokens || 0,
}
let toolCalls = []
let toolResults = []
let currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops

try {
while (iterationCount < MAX_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}

// Process each tool call
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)

// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue

// Execute the tool
const mergedArgs = { ...tool.params, ...toolArgs }
const result = await executeTool(toolName, mergedArgs, true)

if (!result.success) continue

toolResults.push(result.output)
toolCalls.push({
name: toolName,
arguments: toolArgs,
})

// Add the tool call and result to messages
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})

currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(result.output),
})
} catch (error) {
console.error('Error processing tool call:', error)
}
}

// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}

// Make the next request
currentResponse = await deepseek.chat.completions.create(nextPayload)

// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
// Clean up the response content
content = content.replace(/```json\n?|\n?```/g, '')
content = content.trim()
}

// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
tokens.total += currentResponse.usage.total_tokens || 0
}

iterationCount++
}
} catch (error) {
console.error('Error in Deepseek request:', error)
throw error
}

return {
content,
tokens: response.usage && {
prompt: response.usage.prompt_tokens,
completion: response.usage.completion_tokens,
total: response.usage.total_tokens,
},
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},

hasFunctionCall: (response: any) => {
if (!response) return false
return !!response.choices?.[0]?.message?.tool_calls?.[0]
},
}
@@ -1,205 +1,189 @@
import { ToolConfig } from '@/tools/types'
import { FunctionCallResponse, ProviderConfig, ProviderRequest, ProviderToolConfig } from '../types'
import OpenAI from 'openai'
import { executeTool } from '@/tools'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'

export const googleProvider: ProviderConfig = {
id: 'google',
name: 'Google',
description: "Google's Gemini models",
version: '1.0.0',
models: ['gemini-2.0-flash-001'],
defaultModel: 'gemini-2.0-flash-001',
models: ['gemini-2.0-flash'],
defaultModel: 'gemini-2.0-flash',

baseUrl:
'https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-001:generateContent',
headers: (apiKey: string) => ({
'Content-Type': 'application/json',
'x-goog-api-key': apiKey,
}),

transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
if (!tools || tools.length === 0) {
return undefined
executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
if (!request.apiKey) {
throw new Error('API key is required for Google Gemini')
}

const transformProperties = (properties: Record<string, any>): Record<string, any> => {
return Object.entries(properties).reduce((acc, [key, value]) => {
// Skip complex JSON/object parameters for Gemini
if (value.type === 'json' || (value.type === 'object' && !value.properties)) {
return acc
}
const openai = new OpenAI({
apiKey: request.apiKey,
baseURL: 'https://generativelanguage.googleapis.com/v1beta/openai/',
dangerouslyAllowBrowser: true,
})

// For object types with defined properties
if (value.type === 'object' && value.properties) {
return {
...acc,
[key]: {
type: 'OBJECT',
description: value.description || '',
properties: transformProperties(value.properties),
},
}
}
// Start with an empty array for all messages
const allMessages = []

// For simple types
return {
...acc,
[key]: {
type: (value.type || 'string').toUpperCase(),
description: value.description || '',
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}

// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}

// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}

// Transform tools to OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: tool.parameters,
},
}))
: undefined

// Build the request payload
const payload: any = {
model: request.model || 'gemini-2.0-flash',
messages: allMessages,
}

// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens

// Add response format for structured output if specified
if (request.responseFormat) {
payload.response_format = { type: 'json_object' }
}

// Add tools if provided
if (tools?.length) {
payload.tools = tools
payload.tool_choice = 'auto'
}

// Make the initial API request
let currentResponse = await openai.chat.completions.create(payload)
let content = currentResponse.choices[0]?.message?.content || ''
let tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
total: currentResponse.usage?.total_tokens || 0,
}
let toolCalls = []
let toolResults = []
let currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops

try {
while (iterationCount < MAX_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
}, {})
}

return {
functionDeclarations: tools.map((tool) => {
// Get properties excluding complex types
const properties = transformProperties(tool.parameters.properties)
// Process each tool call
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)

// Filter required fields to only include ones that exist in properties
const required = (tool.parameters.required || []).filter((field) => field in properties)
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue

return {
name: tool.id,
description: tool.description || '',
parameters: {
type: 'OBJECT',
properties,
required,
},
}
}),
}
},
// Execute the tool
const mergedArgs = { ...tool.params, ...toolArgs }
const result = await executeTool(toolName, mergedArgs, true)

transformFunctionCallResponse: (
response: any,
tools?: ProviderToolConfig[]
): FunctionCallResponse => {
// Extract function call from Gemini response
const functionCall = response.candidates?.[0]?.content?.parts?.[0]?.functionCall
if (!functionCall) {
throw new Error('No function call found in response')
}
if (!result.success) continue

// Log the function call for debugging
console.log('Raw function call from Gemini:', JSON.stringify(functionCall, null, 2))
toolResults.push(result.output)
toolCalls.push({
name: toolName,
arguments: toolArgs,
})

const tool = tools?.find((t) => t.id === functionCall.name)
if (!tool) {
throw new Error(`Tool not found: ${functionCall.name}`)
}

// Ensure args is an object
let args = functionCall.args
if (typeof args === 'string') {
try {
args = JSON.parse(args)
} catch (e) {
console.error('Failed to parse function call args:', e)
args = {}
}
}

// Get arguments from function call, but NEVER override apiKey
const { apiKey: _, ...functionArgs } = args

return {
name: functionCall.name,
arguments: {
...functionArgs,
apiKey: tool.params.apiKey, // Always use the apiKey from tool params
},
}
},

transformRequest: (request: ProviderRequest, functions?: any) => {
// Combine system prompt and context into a single message if both exist
const initialMessage = request.systemPrompt + (request.context ? `\n\n${request.context}` : '')

const messages = [
{ role: 'user', parts: [{ text: initialMessage }] },
...(request.messages || []).map((msg) => {
if (msg.role === 'function') {
return {
role: 'user',
parts: [
{
functionResponse: {
name: msg.name,
response: {
name: msg.name,
content: JSON.parse(msg.content || '{}'),
// Add the tool call and result to messages
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
},
],
],
})

const toolResultContent = JSON.stringify(result.output)

currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: toolResultContent,
})
} catch (error) {
console.error('Error processing tool call:', error)
}
}

if (msg.function_call) {
return {
role: 'model',
parts: [
{
functionCall: {
name: msg.function_call.name,
args: JSON.parse(msg.function_call.arguments),
},
},
],
}
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}

return {
role: msg.role === 'assistant' ? 'model' : 'user',
parts: [{ text: msg.content || '' }],
// Make the next request
currentResponse = await openai.chat.completions.create(nextPayload)

// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
}),
]

// Log the request for debugging
console.log(
'Gemini request:',
JSON.stringify(
{
messages,
tools: functions ? [{ functionDeclarations: functions.functionDeclarations }] : undefined,
},
null,
2
)
)
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
tokens.total += currentResponse.usage.total_tokens || 0
}

return {
contents: messages,
tools: functions ? [{ functionDeclarations: functions.functionDeclarations }] : undefined,
generationConfig: {
temperature: request.temperature || 0.7,
maxOutputTokens: request.maxTokens || 1024,
},
}
},

transformResponse: (response: any) => {
let content = response.candidates?.[0]?.content?.parts?.[0]?.text || ''

if (content) {
content = content.replace(/```json\n?|\n?```/g, '')
content = content.trim()
iterationCount++
}
} catch (error) {
console.error('Error in Google Gemini request:', error)
throw error
}

return {
content,
tokens: response.usageMetadata && {
prompt: response.usageMetadata.promptTokenCount,
completion: response.usageMetadata.candidatesTokenCount,
total: response.usageMetadata.totalTokenCount,
},
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
}
},

hasFunctionCall: (response: any) => {
return !!response.candidates?.[0]?.content?.parts?.[0]?.functionCall
},
}
27  providers/index.ts  Normal file
@@ -0,0 +1,27 @@
import { ProviderRequest, ProviderResponse } from './types'
import { generateStructuredOutputInstructions, getProvider } from './utils'

export async function executeProviderRequest(
  providerId: string,
  request: ProviderRequest
): Promise<ProviderResponse> {
  const provider = getProvider(providerId)
  if (!provider) {
    throw new Error(`Provider not found: ${providerId}`)
  }

  if (!provider.executeRequest) {
    throw new Error(`Provider ${providerId} does not implement executeRequest`)
  }

  // If responseFormat is provided, modify the system prompt to enforce structured output
  if (request.responseFormat) {
    const structuredOutputInstructions = generateStructuredOutputInstructions(
      request.responseFormat
    )
    request.systemPrompt = `${request.systemPrompt}\n\n${structuredOutputInstructions}`
  }

  // Execute the request using the provider's implementation
  return await provider.executeRequest(request)
}
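For illustration, a minimal sketch of how a caller might invoke this new entry point. The field values below are hypothetical, and the '@/providers' import path assumes the repo's '@/' alias resolves to this directory; the request shape (model, systemPrompt, apiKey, temperature) follows the ProviderRequest type used throughout this commit:

import { executeProviderRequest } from '@/providers'

// Hypothetical call: resolves the 'openai' provider and runs its SDK implementation.
const response = await executeProviderRequest('openai', {
  model: 'gpt-4o',
  systemPrompt: 'You are a helpful assistant.',
  apiKey: process.env.OPENAI_API_KEY!, // assumed to be configured
  temperature: 0.2,
})
console.log(response.content, response.tokens)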
providers/openai.ts

@@ -1,11 +1,6 @@
import OpenAI from 'openai'
import {
  FunctionCallResponse,
  ProviderConfig,
  ProviderRequest,
  ProviderResponse,
  ProviderToolConfig,
} from '../types'
import { executeTool } from '@/tools'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'

export const openaiProvider: ProviderConfig = {
  id: 'openai',
@@ -14,9 +9,7 @@ export const openaiProvider: ProviderConfig = {
  version: '1.0.0',
  models: ['gpt-4o', 'o1', 'o3-mini'],
  defaultModel: 'gpt-4o',
  implementationType: 'sdk',

  // SDK-based implementation
  executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
    if (!request.apiKey) {
      throw new Error('API key is required for OpenAI')
@@ -116,9 +109,9 @@ export const openaiProvider: ProviderConfig = {
      const tool = request.tools?.find((t) => t.id === toolName)
      if (!tool) continue

      // Execute the tool (this will need to be imported from your tools system)
      const { executeTool } = await import('@/tools')
      const result = await executeTool(toolName, toolArgs)
      // Execute the tool
      const mergedArgs = { ...tool.params, ...toolArgs }
      const result = await executeTool(toolName, mergedArgs, true)

      if (!result.success) continue

@@ -190,190 +183,4 @@ export const openaiProvider: ProviderConfig = {
      toolResults: toolResults.length > 0 ? toolResults : undefined,
    }
  },

  // These are still needed for backward compatibility
  baseUrl: 'https://api.openai.com/v1/chat/completions',
  headers: (apiKey: string) => ({
    'Content-Type': 'application/json',
    Authorization: `Bearer ${apiKey}`,
  }),

  transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
    if (!tools || tools.length === 0) {
      return undefined
    }

    return tools.map((tool) => ({
      type: 'function',
      function: {
        name: tool.id,
        description: tool.description,
        parameters: tool.parameters,
      },
    }))
  },

  transformFunctionCallResponse: (
    response: any,
    tools?: ProviderToolConfig[]
  ): FunctionCallResponse => {
    // Handle SDK response format
    if (response.choices?.[0]?.message?.tool_calls) {
      const toolCall = response.choices[0].message.tool_calls[0]
      if (!toolCall) {
        throw new Error('No tool call found in response')
      }

      const tool = tools?.find((t) => t.id === toolCall.function.name)
      const toolParams = tool?.params || {}

      return {
        name: toolCall.function.name,
        arguments: {
          ...toolParams,
          ...JSON.parse(toolCall.function.arguments),
        },
      }
    }

    // Handle legacy function_call format for backward compatibility
    const functionCall = response.choices?.[0]?.message?.function_call
    if (!functionCall) {
      throw new Error('No function call found in response')
    }

    const tool = tools?.find((t) => t.id === functionCall.name)
    const toolParams = tool?.params || {}

    return {
      name: functionCall.name,
      arguments: {
        ...toolParams,
        ...JSON.parse(functionCall.arguments),
      },
    }
  },

  transformRequest: (request: ProviderRequest, functions?: any) => {
    const isO1Model = request.model?.startsWith('o1')
    const isO1Mini = request.model === 'o1-mini'

    // Helper function to transform message role
    const transformMessageRole = (message: any) => {
      if (isO1Mini && message.role === 'system') {
        return { ...message, role: 'user' }
      }
      return message
    }

    // Start with an empty array for all messages
    const allMessages = []

    // Add system prompt if present
    if (request.systemPrompt) {
      allMessages.push(
        transformMessageRole({
          role: 'system',
          content: request.systemPrompt,
        })
      )
    }

    // Add context if present
    if (request.context) {
      allMessages.push({
        role: 'user',
        content: request.context,
      })
    }

    // Add remaining messages, transforming roles as needed
    if (request.messages) {
      allMessages.push(...request.messages.map(transformMessageRole))
    }

    // Build the request payload
    const payload: any = {
      model: request.model || 'gpt-4o',
      messages: allMessages,
    }

    // Only add parameters supported by the model type
    if (!isO1Model) {
      // gpt-4o supports standard parameters
      if (request.temperature !== undefined) payload.temperature = request.temperature
      if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens

      // Add response format for structured output if specified
      if (request.responseFormat) {
        // Use OpenAI's simpler response format
        payload.response_format = { type: 'json_object' }

        // If we have both function calls and response format, we need to guide the model
        if (functions) {
          payload.messages[0].content = `${payload.messages[0].content}\n\nProcess:\n1. First, use the provided functions to gather the necessary data\n2. Then, format your final response as a SINGLE JSON object with these exact fields and types:\n${request.responseFormat.fields
            .map(
              (field) =>
                `- "${field.name}" (${field.type})${field.description ? `: ${field.description}` : ''}`
            )
            .join(
              '\n'
            )}\n\nYour final response after function calls must be a SINGLE valid JSON object with all required fields and correct types. Do not return multiple objects or include any text outside the JSON.`
        } else {
          // If no functions, just format as JSON directly
          payload.messages[0].content = `${payload.messages[0].content}\n\nYou MUST return a SINGLE JSON object with exactly these fields and types:\n${request.responseFormat.fields
            .map(
              (field) =>
                `- "${field.name}" (${field.type})${field.description ? `: ${field.description}` : ''}`
            )
            .join(
              '\n'
            )}\n\nThe response must:\n1. Be a single valid JSON object\n2. Include all the specified fields\n3. Use the correct type for each field\n4. Not include any additional fields\n5. Not include any explanatory text outside the JSON\n6. Not return multiple objects`
        }
      }
    } else {
      // o1 models use max_completion_tokens
      if (request.maxTokens !== undefined) {
        payload.max_completion_tokens = request.maxTokens
      }
    }

    // Add function calling support (supported by all models)
    if (functions) {
      payload.tools = functions
      payload.tool_choice = 'auto'
    }

    return payload
  },

  transformResponse: (response: any) => {
    const output = {
      content: response.choices?.[0]?.message?.content || '',
      tokens: undefined as any,
    }

    if (response.usage) {
      output.tokens = {
        prompt: response.usage.prompt_tokens,
        completion: response.usage.completion_tokens,
        total: response.usage.total_tokens,
      }

      // Add reasoning_tokens for o1 models if available
      if (response.usage.completion_tokens_details?.reasoning_tokens) {
        output.tokens.reasoning = response.usage.completion_tokens_details.reasoning_tokens
      }
    }

    return output
  },

  hasFunctionCall: (response: any) => {
    return (
      !!response.choices?.[0]?.message?.function_call ||
      (response.choices?.[0]?.message?.tool_calls &&
        response.choices[0].message.tool_calls.length > 0)
    )
  },
}
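The merge above is the key behavioral change in this hunk: spread order means preconfigured tool.params supply defaults and the model's arguments win on conflict. A small sketch of that precedence with invented values:

// Sketch of the { ...tool.params, ...toolArgs } merge rule used above.
const toolParams = { apiKey: 'sk-configured', channel: '#general' } // preconfigured on the tool
const toolArgs = { channel: '#alerts', text: 'hi' } // produced by the LLM
const mergedArgs = { ...toolParams, ...toolArgs }
// => { apiKey: 'sk-configured', channel: '#alerts', text: 'hi' }
// apiKey survives because the LLM never supplies one; channel is overridden.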
providers/registry.ts

@@ -1,28 +0,0 @@
import { anthropicProvider } from './anthropic'
import { cerebrasProvider } from './cerebras'
import { deepseekProvider } from './deepseek'
import { googleProvider } from './google'
import { openaiProvider } from './openai'
import { ProviderConfig } from './types'
import { xAIProvider } from './xai'

export type ProviderId = 'openai' | 'anthropic' | 'google' | 'deepseek' | 'xai' | 'cerebras'

export const providers: Record<ProviderId, ProviderConfig> = {
  openai: openaiProvider,
  anthropic: anthropicProvider,
  google: googleProvider,
  deepseek: deepseekProvider,
  xai: xAIProvider,
  cerebras: cerebrasProvider,
}

export function getProvider(id: string): ProviderConfig | undefined {
  // Handle both formats: 'openai' and 'openai/chat'
  const providerId = id.split('/')[0] as ProviderId
  return providers[providerId]
}

export function getProviderChatId(providerId: ProviderId): string {
  return `${providerId}/chat`
}
@@ -1,341 +0,0 @@
import OpenAI from 'openai'
import { executeTool, getTool } from '@/tools'
import { getProvider } from './registry'
import { ProviderRequest, ProviderResponse, TokenInfo } from './types'
import { extractAndParseJSON } from './utils'

// Helper function to generate provider-specific structured output instructions
function generateStructuredOutputInstructions(responseFormat: any): string {
  if (!responseFormat?.fields) return ''

  function generateFieldStructure(field: any): string {
    if (field.type === 'object' && field.properties) {
      return `{
    ${Object.entries(field.properties)
      .map(([key, prop]: [string, any]) => `"${key}": ${prop.type === 'number' ? '0' : '"value"'}`)
      .join(',\n    ')}
  }`
    }
    return field.type === 'string'
      ? '"value"'
      : field.type === 'number'
        ? '0'
        : field.type === 'boolean'
          ? 'true/false'
          : '[]'
  }

  const exampleFormat = responseFormat.fields
    .map((field: any) => `  "${field.name}": ${generateFieldStructure(field)}`)
    .join(',\n')

  const fieldDescriptions = responseFormat.fields
    .map((field: any) => {
      let desc = `${field.name} (${field.type})`
      if (field.description) desc += `: ${field.description}`
      if (field.type === 'object' && field.properties) {
        desc += '\nProperties:'
        Object.entries(field.properties).forEach(([key, prop]: [string, any]) => {
          desc += `\n  - ${key} (${(prop as any).type}): ${(prop as any).description || ''}`
        })
      }
      return desc
    })
    .join('\n')

  return `
Please provide your response in the following JSON format:
{
${exampleFormat}
}

Field descriptions:
${fieldDescriptions}

Your response MUST be valid JSON and include all the specified fields with their correct types.
Each metric should be an object containing 'score' (number) and 'reasoning' (string).`
}

export async function executeProviderRequest(
  providerId: string,
  request: ProviderRequest
): Promise<ProviderResponse> {
  const provider = getProvider(providerId)
  if (!provider) {
    throw new Error(`Provider not found: ${providerId}`)
  }

  // Use SDK-based implementation if available
  if (provider.implementationType === 'sdk' && provider.executeRequest) {
    return provider.executeRequest(request)
  }

  // Legacy HTTP-based implementation for other providers
  return executeHttpProviderRequest(provider, request)
}

// Legacy HTTP-based implementation
async function executeHttpProviderRequest(
  provider: any,
  request: ProviderRequest
): Promise<ProviderResponse> {
  // If responseFormat is provided, modify the system prompt to enforce structured output
  if (request.responseFormat) {
    const structuredOutputInstructions = generateStructuredOutputInstructions(
      request.responseFormat
    )
    request.systemPrompt = `${request.systemPrompt}\n\n${structuredOutputInstructions}`
  }

  // Transform tools to provider-specific function format
  const functions =
    request.tools && request.tools.length > 0
      ? provider.transformToolsToFunctions(request.tools)
      : undefined

  // Transform the request using provider-specific logic
  const payload = provider.transformRequest(request, functions)

  // Make the initial API request through the proxy
  let currentResponse = await makeProxyRequest(provider.id, payload, request.apiKey)
  let content = ''
  let tokens: TokenInfo | undefined = undefined
  let toolCalls = []
  let toolResults = []
  let currentMessages = [...(request.messages || [])]
  let iterationCount = 0
  const MAX_ITERATIONS = 10 // Prevent infinite loops

  try {
    while (iterationCount < MAX_ITERATIONS) {
      // Transform the response using provider-specific logic
      const transformedResponse = provider.transformResponse(currentResponse)
      content = transformedResponse.content

      // If responseFormat is specified and we have content (not a function call), validate and parse the response
      if (request.responseFormat && content && !provider.hasFunctionCall(currentResponse)) {
        try {
          // Extract and parse the JSON content
          const parsedContent = extractAndParseJSON(content)

          // Validate that all required fields are present and have correct types
          const validationErrors = request.responseFormat.fields
            .map((field: any) => {
              if (!(field.name in parsedContent)) {
                return `Missing field: ${field.name}`
              }
              const value = parsedContent[field.name]
              const type = typeof value
              if (field.type === 'string' && type !== 'string') {
                return `Invalid type for ${field.name}: expected string, got ${type}`
              }
              if (field.type === 'number' && type !== 'number') {
                return `Invalid type for ${field.name}: expected number, got ${type}`
              }
              if (field.type === 'boolean' && type !== 'boolean') {
                return `Invalid type for ${field.name}: expected boolean, got ${type}`
              }
              if (field.type === 'array' && !Array.isArray(value)) {
                return `Invalid type for ${field.name}: expected array, got ${type}`
              }
              if (field.type === 'object' && (type !== 'object' || Array.isArray(value))) {
                return `Invalid type for ${field.name}: expected object, got ${type}`
              }
              return null
            })
            .filter(Boolean)

          if (validationErrors.length > 0) {
            throw new Error(`Response format validation failed:\n${validationErrors.join('\n')}`)
          }

          // Store the validated JSON response
          content = JSON.stringify(parsedContent)
        } catch (error: any) {
          console.error('Raw content:', content)
          throw new Error(`Failed to parse structured response: ${error.message}`)
        }
      }

      // Update tokens
      if (transformedResponse.tokens) {
        const newTokens: TokenInfo = {
          prompt: (tokens?.prompt ?? 0) + (transformedResponse.tokens?.prompt ?? 0),
          completion: (tokens?.completion ?? 0) + (transformedResponse.tokens?.completion ?? 0),
          total: (tokens?.total ?? 0) + (transformedResponse.tokens?.total ?? 0),
        }
        tokens = newTokens
      }

      // Check for function calls using provider-specific logic
      const hasFunctionCall = provider.hasFunctionCall(currentResponse)

      // Break if we have content and no function call
      if (!hasFunctionCall) {
        break
      }

      // Safety check: if we have the same function call multiple times in a row
      // with the same arguments, break to prevent infinite loops
      let functionCall
      try {
        functionCall = provider.transformFunctionCallResponse(currentResponse, request.tools)

        // Check if this is a duplicate call
        const lastCall = toolCalls[toolCalls.length - 1]
        if (
          lastCall &&
          lastCall.name === functionCall.name &&
          JSON.stringify(lastCall.arguments) === JSON.stringify(functionCall.arguments)
        ) {
          console.log(
            'Detected duplicate function call, breaking loop to prevent infinite recursion'
          )
          break
        }
      } catch (error) {
        console.log('Error transforming function call:', error)
        break
      }

      if (!functionCall) {
        break
      }

      // Execute the tool
      const tool = getTool(functionCall.name)
      if (!tool) {
        break
      }

      const result = await executeTool(functionCall.name, functionCall.arguments)

      if (!result.success) {
        break
      }

      toolResults.push(result.output)
      toolCalls.push(functionCall)

      // Add the function call and result to messages
      // Check if we're dealing with the new tool_calls format or the legacy function_call format
      const hasToolCalls = currentResponse.choices?.[0]?.message?.tool_calls

      if (hasToolCalls) {
        const toolCall = currentResponse.choices[0].message.tool_calls[0]
        if (toolCall && toolCall.id) {
          currentMessages.push({
            role: 'assistant',
            content: null,
            tool_calls: [
              {
                id: toolCall.id,
                type: 'function',
                function: {
                  name: functionCall.name,
                  arguments: JSON.stringify(functionCall.arguments),
                },
              },
            ],
          })

          currentMessages.push({
            role: 'tool',
            tool_call_id: toolCall.id,
            content: JSON.stringify(result.output),
          })
        } else {
          // Fallback to legacy format if id is missing
          currentMessages.push({
            role: 'assistant',
            content: null,
            function_call: {
              name: functionCall.name,
              arguments: JSON.stringify(functionCall.arguments),
            },
          })

          currentMessages.push({
            role: 'function',
            name: functionCall.name,
            content: JSON.stringify(result.output),
          })
        }
      } else {
        // Legacy format
        currentMessages.push({
          role: 'assistant',
          content: null,
          function_call: {
            name: functionCall.name,
            arguments: JSON.stringify(functionCall.arguments),
          },
        })

        currentMessages.push({
          role: 'function',
          name: functionCall.name,
          content: JSON.stringify(result.output),
        })
      }

      // Prepare the next request
      const nextPayload = provider.transformRequest(
        {
          ...request,
          messages: currentMessages,
        },
        functions
      )

      // Make the next request
      currentResponse = await makeProxyRequest(provider.id, nextPayload, request.apiKey)
      iterationCount++
    }

    if (iterationCount >= MAX_ITERATIONS) {
      console.log('Max iterations of tool calls reached, breaking loop')
    }
  } catch (error) {
    console.error('Error executing tool:', error)
    throw error
  }

  return {
    content,
    model: currentResponse.model,
    tokens,
    toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
    toolResults: toolResults.length > 0 ? toolResults : undefined,
  }
}

async function makeProxyRequest(providerId: string, payload: any, apiKey: string) {
  const baseUrl = process.env.NEXT_PUBLIC_APP_URL
  if (!baseUrl) {
    throw new Error('NEXT_PUBLIC_APP_URL environment variable is not set')
  }

  const proxyUrl = new URL('/api/proxy', baseUrl).toString()
  const response = await fetch(proxyUrl, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({
      toolId: `${providerId}/chat`,
      params: {
        ...payload,
        apiKey,
      },
    }),
  })

  const data = await response.json()

  if (!data.success) {
    throw new Error(data.error || 'Provider API error')
  }

  return data.output
}
providers/types.ts

@@ -1,4 +1,4 @@
import { ToolConfig } from '@/tools/types'
export type ProviderId = 'openai' | 'anthropic' | 'google' | 'deepseek' | 'xai' | 'cerebras'

export interface TokenInfo {
  prompt?: number
@@ -18,33 +18,7 @@ export interface ProviderConfig {
  version: string
  models: string[]
  defaultModel: string

  // Provider implementation type
  implementationType: 'sdk' | 'http'

  // For HTTP-based providers
  baseUrl?: string
  headers?: (apiKey: string) => Record<string, string>

  // For SDK-based providers
  executeRequest?: (request: ProviderRequest) => Promise<ProviderResponse>

  // Tool calling support
  transformToolsToFunctions: (tools: ProviderToolConfig[]) => any
  transformFunctionCallResponse: (
    response: any,
    tools?: ProviderToolConfig[]
  ) => FunctionCallResponse

  // Provider-specific request/response transformations
  transformRequest: (request: ProviderRequest, functions?: any) => any
  transformResponse: (response: any) => TransformedResponse

  // Function to check if response contains a function call
  hasFunctionCall: (response: any) => boolean

  // Internal state for tool name mapping
  _toolNameMapping?: Map<string, string>
}

export interface FunctionCallResponse {
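With implementationType, baseUrl/headers, and the transform* members removed from the interface, a provider reduces to identity metadata plus executeRequest. A minimal conforming sketch — the 'echo' provider is invented, only the fields visible in this diff are shown (the real interface may require more), and it assumes executeRequest remains on ProviderConfig (the added hunk lines are not displayed here):

import { ProviderConfig, ProviderRequest, ProviderResponse } from './types'

// Hypothetical minimal SDK-style provider under the slimmed interface.
const echoProvider: ProviderConfig = {
  id: 'echo',
  version: '1.0.0',
  models: ['echo-1'],
  defaultModel: 'echo-1',
  executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => ({
    // Trivially echoes the system prompt back; illustration only.
    content: request.systemPrompt || '',
    model: request.model || 'echo-1',
  }),
}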
providers/utils.ts

@@ -1,60 +1,161 @@
import { ProviderId } from './registry'
import { anthropicProvider } from './anthropic'
import { cerebrasProvider } from './cerebras'
import { deepseekProvider } from './deepseek'
import { googleProvider } from './google'
import { openaiProvider } from './openai'
import { ProviderConfig, ProviderId } from './types'
import { xAIProvider } from './xai'

/**
 * Direct mapping from model names to provider IDs
 */
export const MODEL_PROVIDERS: Record<string, ProviderId> = {
  'gpt-4o': 'openai',
  o1: 'openai',
  'o3-mini': 'openai',
  'claude-3-7-sonnet-20250219': 'anthropic',
  'gemini-2.0-flash': 'google',
  'grok-2-latest': 'xai',
  'deepseek-v3': 'deepseek',
  'deepseek-r1': 'deepseek',
  'llama-3.3-70b': 'cerebras',
}

/**
 * Provider configurations with associated model names/patterns
 */
export const providers: Record<
  ProviderId,
  ProviderConfig & {
    models: string[]
    modelPatterns?: RegExp[]
  }
> = {
  openai: {
    ...openaiProvider,
    models: ['gpt-4o', 'o1', 'o3-mini'],
    modelPatterns: [/^gpt/, /^o1/],
  },
  anthropic: {
    ...anthropicProvider,
    models: ['claude-3-7-sonnet-20250219'],
    modelPatterns: [/^claude/],
  },
  google: {
    ...googleProvider,
    models: ['gemini-2.0-flash'],
    modelPatterns: [/^gemini/],
  },
  deepseek: {
    ...deepseekProvider,
    models: ['deepseek-v3', 'deepseek-r1'],
    modelPatterns: [/^deepseek/],
  },
  xai: {
    ...xAIProvider,
    models: ['grok-2-latest'],
    modelPatterns: [/^grok/],
  },
  cerebras: {
    ...cerebrasProvider,
    models: ['llama-3.3-70b'],
    modelPatterns: [/^llama/],
  },
}

/**
 * Direct mapping from model names to provider IDs
 * Automatically generated from the providers configuration
 */
export const MODEL_PROVIDERS: Record<string, ProviderId> = Object.entries(providers).reduce(
  (map, [providerId, config]) => {
    config.models.forEach((model) => {
      map[model.toLowerCase()] = providerId as ProviderId
    })
    return map
  },
  {} as Record<string, ProviderId>
)

/**
 * Determines the provider ID based on the model name.
 * Uses the MODEL_PROVIDERS mapping and falls back to pattern matching if needed.
 *
 * @param model - The model name/identifier
 * @returns The corresponding provider ID
 */
export function getProviderFromModel(model: string): ProviderId {
  const normalizedModel = model.toLowerCase()

  // First try to match exactly from our MODEL_PROVIDERS mapping
  if (normalizedModel in MODEL_PROVIDERS) {
    return MODEL_PROVIDERS[normalizedModel]
  }

  // If no exact match, use pattern matching as fallback
  if (normalizedModel.startsWith('gpt') || normalizedModel.startsWith('o1')) {
    return 'openai'
  }

  if (normalizedModel.startsWith('claude')) {
    return 'anthropic'
  }

  if (normalizedModel.startsWith('gemini')) {
    return 'google'
  }

  if (normalizedModel.startsWith('grok')) {
    return 'xai'
  }

  if (normalizedModel.startsWith('llama')) {
    return 'cerebras'
  }

  for (const [providerId, config] of Object.entries(providers)) {
    if (config.modelPatterns) {
      for (const pattern of config.modelPatterns) {
        if (pattern.test(normalizedModel)) {
          return providerId as ProviderId
        }
      }
    }
  }

  // Default to deepseek for any other models
  return 'deepseek'
}

export function getProvider(id: string): ProviderConfig | undefined {
  // Handle both formats: 'openai' and 'openai/chat'
  const providerId = id.split('/')[0] as ProviderId
  return providers[providerId]
}

export function getProviderConfigFromModel(model: string): ProviderConfig | undefined {
  const providerId = getProviderFromModel(model)
  return providers[providerId]
}

export function getAllModels(): string[] {
  return Object.values(providers).flatMap((provider) => provider.models || [])
}

export function getAllProviderIds(): ProviderId[] {
  return Object.keys(providers) as ProviderId[]
}

export function getProviderModels(providerId: ProviderId): string[] {
  const provider = providers[providerId]
  return provider?.models || []
}

export function generateStructuredOutputInstructions(responseFormat: any): string {
  if (!responseFormat?.fields) return ''

  function generateFieldStructure(field: any): string {
    if (field.type === 'object' && field.properties) {
      return `{
    ${Object.entries(field.properties)
      .map(([key, prop]: [string, any]) => `"${key}": ${prop.type === 'number' ? '0' : '"value"'}`)
      .join(',\n    ')}
  }`
    }
    return field.type === 'string'
      ? '"value"'
      : field.type === 'number'
        ? '0'
        : field.type === 'boolean'
          ? 'true/false'
          : '[]'
  }

  const exampleFormat = responseFormat.fields
    .map((field: any) => `  "${field.name}": ${generateFieldStructure(field)}`)
    .join(',\n')

  const fieldDescriptions = responseFormat.fields
    .map((field: any) => {
      let desc = `${field.name} (${field.type})`
      if (field.description) desc += `: ${field.description}`
      if (field.type === 'object' && field.properties) {
        desc += '\nProperties:'
        Object.entries(field.properties).forEach(([key, prop]: [string, any]) => {
          desc += `\n  - ${key} (${(prop as any).type}): ${(prop as any).description || ''}`
        })
      }
      return desc
    })
    .join('\n')

  return `
Please provide your response in the following JSON format:
{
${exampleFormat}
}

Field descriptions:
${fieldDescriptions}

Your response MUST be valid JSON and include all the specified fields with their correct types.
Each metric should be an object containing 'score' (number) and 'reasoning' (string).`
}

export function extractAndParseJSON(content: string): any {
  // First clean up the string
  const trimmed = content.trim()
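A quick illustration of how model resolution now behaves — exact names hit the generated MODEL_PROVIDERS map, unknown variants fall through to the modelPatterns regexes, and anything else defaults to 'deepseek' (the unregistered model names below are invented):

import { getProviderFromModel } from '@/providers/utils'

getProviderFromModel('gpt-4o') // 'openai' — exact hit in MODEL_PROVIDERS
getProviderFromModel('claude-9-preview') // 'anthropic' — falls back to /^claude/
getProviderFromModel('totally-unknown') // 'deepseek' — the catch-all default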
providers/xai.ts

@@ -1,4 +1,5 @@
import { FunctionCallResponse, ProviderConfig, ProviderRequest, ProviderToolConfig } from '../types'
import OpenAI from 'openai'
import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'

export const xAIProvider: ProviderConfig = {
  id: 'xai',
@@ -8,210 +9,210 @@ export const xAIProvider: ProviderConfig = {
  models: ['grok-2-latest'],
  defaultModel: 'grok-2-latest',

  baseUrl: 'https://api.x.ai/v1/chat/completions',
  headers: (apiKey: string) => ({
    'Content-Type': 'application/json',
    Authorization: `Bearer ${apiKey}`,
  }),

  transformToolsToFunctions: (tools: ProviderToolConfig[]) => {
    if (!tools || tools.length === 0) {
      return undefined
    }

    return tools.map((tool) => ({
      type: 'function',
      function: {
        name: tool.id,
        description: tool.description,
        parameters: tool.parameters,
      },
    }))
  },

  transformFunctionCallResponse: (
    response: any,
    tools?: ProviderToolConfig[]
  ): FunctionCallResponse => {
    // xAI returns tool_calls array like OpenAI
    const toolCall = response.choices?.[0]?.message?.tool_calls?.[0]
    if (!toolCall || !toolCall.function) {
      throw new Error('No valid tool call found in response')
    }

    const tool = tools?.find((t) => t.id === toolCall.function.name)
    if (!tool) {
      throw new Error(`Tool not found: ${toolCall.function.name}`)
    }

    let args = toolCall.function.arguments
    if (typeof args === 'string') {
      try {
        args = JSON.parse(args)
      } catch (e) {
        console.error('Failed to parse tool arguments:', e)
        args = {}
      }
    }

    return {
      name: toolCall.function.name,
      arguments: {
        ...tool.params,
        ...args,
      },
    }
  },

  transformRequest: (request: ProviderRequest, functions?: any) => {
    // Convert function messages to tool messages
    const messages = (request.messages || []).map((msg) => {
      if (msg.role === 'function') {
        return {
          role: 'tool',
          content: msg.content,
          tool_call_id: msg.name, // xAI expects tool_call_id for tool results
        }
      }

      if (msg.function_call) {
        return {
          role: 'assistant',
          content: null,
          tool_calls: [
            {
              id: msg.function_call.name,
              type: 'function',
              function: {
                name: msg.function_call.name,
                arguments: msg.function_call.arguments,
              },
            },
          ],
        }
      }

      return msg
    })

    // Add response format for structured output if specified
    let systemPrompt = request.systemPrompt
    if (request.responseFormat) {
      systemPrompt += `\n\nYou MUST respond with a valid JSON object. DO NOT include any other text, explanations, or markdown formatting in your response - ONLY the JSON object.\n\nThe response MUST match this schema:\n${JSON.stringify(
        {
          type: 'object',
          properties: request.responseFormat.fields.reduce(
            (acc, field) => ({
              ...acc,
              [field.name]: {
                type:
                  field.type === 'array'
                    ? 'array'
                    : field.type === 'object'
                      ? 'object'
                      : field.type,
                description: field.description,
              },
            }),
            {}
          ),
          required: request.responseFormat.fields.map((f) => f.name),
        },
        null,
        2
      )}\n\nExample response format:\n{\n${request.responseFormat.fields
        .map(
          (f) =>
            `  "${f.name}": ${
              f.type === 'string'
                ? '"value"'
                : f.type === 'number'
                  ? '0'
                  : f.type === 'boolean'
                    ? 'true'
                    : f.type === 'array'
                      ? '[]'
                      : '{}'
            }`
        )
        .join(',\n')}\n}`
    }

    const payload = {
      model: request.model || 'grok-2-latest',
      messages: [
        { role: 'system', content: systemPrompt },
        ...(request.context ? [{ role: 'user', content: request.context }] : []),
        ...messages,
      ],
      temperature: request.temperature || 0.7,
      max_tokens: request.maxTokens || 1024,
      ...(functions && {
        tools: functions,
        tool_choice: 'auto', // xAI specific parameter
      }),
      ...(request.responseFormat && {
        response_format: {
          type: 'json_schema',
          json_schema: {
            name: 'structured_response',
            schema: {
              type: 'object',
              properties: request.responseFormat.fields.reduce(
                (acc, field) => ({
                  ...acc,
                  [field.name]: {
                    type:
                      field.type === 'array'
                        ? 'array'
                        : field.type === 'object'
                          ? 'object'
                          : field.type === 'number'
                            ? 'number'
                            : field.type === 'boolean'
                              ? 'boolean'
                              : 'string',
                    description: field.description || '',
                    ...(field.type === 'array' && {
                      items: { type: 'string' },
                    }),
                    ...(field.type === 'object' && {
                      additionalProperties: true,
                    }),
                  },
                }),
                {}
              ),
              required: request.responseFormat.fields.map((f) => f.name),
              additionalProperties: false,
            },
            strict: true,
          },
        },
      }),
    }

    return payload
  },

  transformResponse: (response: any) => {
    if (!response) {
      console.warn('Received undefined response from xAI API')
      return { content: '' }
    }

    return {
      content: response.choices?.[0]?.message?.content || '',
      tokens: response.usage && {
        prompt: response.usage.prompt_tokens,
        completion: response.usage.completion_tokens,
        total: response.usage.total_tokens,
      },
    }
  },

  hasFunctionCall: (response: any) => {
    if (!response) return false
    return !!response.choices?.[0]?.message?.tool_calls?.[0]
  },

  executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
    if (!request.apiKey) {
      throw new Error('API key is required for xAI')
    }

    const xai = new OpenAI({
      apiKey: request.apiKey,
      baseURL: 'https://api.x.ai/v1',
      dangerouslyAllowBrowser: true,
    })

    const allMessages = []

    if (request.systemPrompt) {
      allMessages.push({
        role: 'system',
        content: request.systemPrompt,
      })
    }

    if (request.context) {
      allMessages.push({
        role: 'user',
        content: request.context,
      })
    }

    if (request.messages) {
      allMessages.push(...request.messages)
    }

    const tools = request.tools?.length
      ? request.tools.map((tool) => ({
          type: 'function',
          function: {
            name: tool.id,
            description: tool.description,
            parameters: tool.parameters,
          },
        }))
      : undefined

    const payload: any = {
      model: request.model || 'grok-2-latest',
      messages: allMessages,
    }

    if (request.temperature !== undefined) payload.temperature = request.temperature
    if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens

    if (request.responseFormat) {
      payload.response_format = {
        type: 'json_schema',
        json_schema: {
          name: 'structured_response',
          schema: {
            type: 'object',
            properties: request.responseFormat.fields.reduce(
              (acc, field) => ({
                ...acc,
                [field.name]: {
                  type:
                    field.type === 'array'
                      ? 'array'
                      : field.type === 'object'
                        ? 'object'
                        : field.type === 'number'
                          ? 'number'
                          : field.type === 'boolean'
                            ? 'boolean'
                            : 'string',
                  description: field.description || '',
                  ...(field.type === 'array' && {
                    items: { type: 'string' },
                  }),
                  ...(field.type === 'object' && {
                    additionalProperties: true,
                  }),
                },
              }),
              {}
            ),
            required: request.responseFormat.fields.map((f) => f.name),
            additionalProperties: false,
          },
          strict: true,
        },
      }

      if (allMessages.length > 0 && allMessages[0].role === 'system') {
        allMessages[0].content = `${allMessages[0].content}\n\nYou MUST respond with a valid JSON object. DO NOT include any other text, explanations, or markdown formatting in your response - ONLY the JSON object.`
      } else {
        allMessages.unshift({
          role: 'system',
          content: `You MUST respond with a valid JSON object. DO NOT include any other text, explanations, or markdown formatting in your response - ONLY the JSON object.`,
        })
      }
    }

    if (tools?.length) {
      payload.tools = tools
      payload.tool_choice = 'auto'
    }

    let currentResponse = await xai.chat.completions.create(payload)
    let content = currentResponse.choices[0]?.message?.content || ''
    let tokens = {
      prompt: currentResponse.usage?.prompt_tokens || 0,
      completion: currentResponse.usage?.completion_tokens || 0,
      total: currentResponse.usage?.total_tokens || 0,
    }
    let toolCalls = []
    let toolResults = []
    let currentMessages = [...allMessages]
    let iterationCount = 0
    const MAX_ITERATIONS = 10

    try {
      while (iterationCount < MAX_ITERATIONS) {
        const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
        if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
          break
        }

        for (const toolCall of toolCallsInResponse) {
          try {
            const toolName = toolCall.function.name
            const toolArgs = JSON.parse(toolCall.function.arguments)

            const tool = request.tools?.find((t) => t.id === toolName)
            if (!tool) continue

            const { executeTool } = await import('@/tools')
            const mergedArgs = { ...tool.params, ...toolArgs }
            console.log(`Merged tool args for ${toolName}:`, {
              toolParams: tool.params,
              llmArgs: toolArgs,
              mergedArgs,
            })
            const result = await executeTool(toolName, mergedArgs, true)

            if (!result.success) continue

            toolResults.push(result.output)
            toolCalls.push({
              name: toolName,
              arguments: toolArgs,
            })

            currentMessages.push({
              role: 'assistant',
              content: null,
              tool_calls: [
                {
                  id: toolCall.id,
                  type: 'function',
                  function: {
                    name: toolName,
                    arguments: toolCall.function.arguments,
                  },
                },
              ],
            })

            currentMessages.push({
              role: 'tool',
              tool_call_id: toolCall.id,
              content: JSON.stringify(result.output),
            })
          } catch (error) {
            console.error('Error processing tool call:', error)
          }
        }

        const nextPayload = {
          ...payload,
          messages: currentMessages,
        }

        currentResponse = await xai.chat.completions.create(nextPayload)

        if (currentResponse.choices[0]?.message?.content) {
          content = currentResponse.choices[0].message.content
        }

        if (currentResponse.usage) {
          tokens.prompt += currentResponse.usage.prompt_tokens || 0
          tokens.completion += currentResponse.usage.completion_tokens || 0
          tokens.total += currentResponse.usage.total_tokens || 0
        }

        iterationCount++
      }
    } catch (error) {
      console.error('Error in xAI request:', error)
      throw error
    }

    return {
      content,
      model: request.model,
      tokens,
      toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
      toolResults: toolResults.length > 0 ? toolResults : undefined,
    }
  },
}
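For concreteness, this is roughly the response_format the reduce above yields for a hypothetical two-field responseFormat ({ name: 'score', type: 'number' } and { name: 'tags', type: 'array' }):

// Sketch of the generated payload.response_format for the two invented fields.
const exampleResponseFormat = {
  type: 'json_schema',
  json_schema: {
    name: 'structured_response',
    schema: {
      type: 'object',
      properties: {
        score: { type: 'number', description: '' },
        tags: { type: 'array', description: '', items: { type: 'string' } },
      },
      required: ['score', 'tags'],
      additionalProperties: false,
    },
    strict: true,
  },
}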
111  tools/index.ts
@@ -24,6 +24,7 @@ import { slackMessageTool } from './slack/message'
import { extractTool as tavilyExtract } from './tavily/extract'
import { searchTool as tavilySearch } from './tavily/search'
import { ToolConfig, ToolResponse } from './types'
import { executeRequest, formatRequestParams, validateToolRequest } from './utils'
import { readTool as xRead } from './x/read'
import { searchTool as xSearch } from './x/search'
import { userTool as xUser } from './x/user'
@@ -72,54 +73,41 @@ export function getTool(toolId: string): ToolConfig | undefined {

// Execute a tool by calling either the proxy for external APIs or directly for internal routes
export async function executeTool(
  toolId: string,
  params: Record<string, any>
  params: Record<string, any>,
  skipProxy = false
): Promise<ToolResponse> {
  try {
    const tool = getTool(toolId)
    console.log(`Tool being called: ${toolId}`, {
      params: { ...params, apiKey: params.apiKey ? '[REDACTED]' : undefined },
      skipProxy,
    })

    // Validate the tool and its parameters
    validateToolRequest(toolId, tool, params)

    // After validation, we know tool exists
    if (!tool) {
      throw new Error(`Tool not found: ${toolId}`)
    }

    // For internal routes, call the API directly
    if (tool.request.isInternalRoute) {
      const url =
        typeof tool.request.url === 'function' ? tool.request.url(params) : tool.request.url

      const response = await fetch(url, {
        method: tool.request.method,
        headers: tool.request.headers(params),
        body: JSON.stringify(tool.request.body ? tool.request.body(params) : params),
      })

      const result = await tool.transformResponse(response)
      return result
    }
    // For internal routes or when skipProxy is true, call the API directly
    if (tool.request.isInternalRoute || skipProxy) {
      console.log(`Calling internal request for ${toolId}`)
      const result = await handleInternalRequest(toolId, tool, params)
      console.log(`Tool ${toolId} execution result:`, {
        success: result.success,
        outputKeys: result.success ? Object.keys(result.output) : [],
        error: result.error,
      })
      return result
    }

    // For external APIs, use the proxy
    const baseUrl = process.env.NEXT_PUBLIC_APP_URL
    if (!baseUrl) {
      throw new Error('NEXT_PUBLIC_APP_URL environment variable is not set')
    }

    const proxyUrl = new URL('/api/proxy', baseUrl).toString()
    const response = await fetch(proxyUrl, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ toolId, params }),
    })

    const result = await response.json()

    if (!result.success) {
      return {
        success: false,
        output: {},
        error: result.error,
      }
    }

    return result
    console.log(`Calling proxy request for ${toolId}`)
    return await handleProxyRequest(toolId, params)
  } catch (error: any) {
    console.error(`Error executing tool ${toolId}:`, error)
    return {
      success: false,
      output: {},
@@ -127,3 +115,56 @@ export async function executeTool(
    }
  }
}

/**
 * Handle an internal/direct tool request
 */
async function handleInternalRequest(
  toolId: string,
  tool: ToolConfig,
  params: Record<string, any>
): Promise<ToolResponse> {
  // Log the request for debugging
  console.log(`Executing tool ${toolId} with params:`, {
    toolId,
    params: { ...params, apiKey: params.apiKey ? '[REDACTED]' : undefined },
  })

  // Format the request parameters
  const requestParams = formatRequestParams(tool, params)

  // Execute the request
  return await executeRequest(toolId, tool, requestParams)
}

/**
 * Handle a request via the proxy
 */
async function handleProxyRequest(
  toolId: string,
  params: Record<string, any>
): Promise<ToolResponse> {
  const baseUrl = process.env.NEXT_PUBLIC_APP_URL
  if (!baseUrl) {
    throw new Error('NEXT_PUBLIC_APP_URL environment variable is not set')
  }

  const proxyUrl = new URL('/api/proxy', baseUrl).toString()
  const response = await fetch(proxyUrl, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ toolId, params }),
  })

  const result = await response.json()

  if (!result.success) {
    return {
      success: false,
      output: {},
      error: result.error,
    }
  }

  return result
}
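A short sketch of the two call paths from a caller's point of view — the tool id and params are placeholders, not real registered ids:

import { executeTool } from '@/tools'

// Default path: external tools are routed through the /api/proxy endpoint.
const viaProxy = await executeTool('example_tool', { query: 'hello', apiKey: 'sk-...' })

// skipProxy = true: server-side callers (the SDK providers above) execute
// the HTTP request directly and skip the extra hop through the proxy.
const direct = await executeTool('example_tool', { query: 'hello', apiKey: 'sk-...' }, true)
if (!direct.success) console.error(direct.error)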
109  tools/utils.ts
@@ -1,4 +1,5 @@
import { TableRow } from './types'
import { ToolConfig, ToolResponse } from './types'

/**
 * Transforms a table from the store format to a key-value object
@@ -18,3 +19,111 @@ export const transformTable = (table: TableRow[] | null): Record<string, string>
    {} as Record<string, string>
  )
}

interface RequestParams {
  url: string
  method: string
  headers: Record<string, string>
  body?: string
}

/**
 * Format request parameters based on tool configuration and provided params
 */
export function formatRequestParams(tool: ToolConfig, params: Record<string, any>): RequestParams {
  // Process URL
  const url = typeof tool.request.url === 'function' ? tool.request.url(params) : tool.request.url

  // Process method
  const method = params.method || tool.request.method || 'GET'

  // Process headers
  const headers = tool.request.headers ? tool.request.headers(params) : {}

  // Process body
  const hasBody = method !== 'GET' && method !== 'HEAD' && !!tool.request.body
  const bodyResult = tool.request.body ? tool.request.body(params) : undefined

  // Special handling for NDJSON content type
  const isNDJSON = headers['Content-Type'] === 'application/x-ndjson'
  const body = hasBody
    ? isNDJSON && bodyResult
      ? bodyResult.body
      : JSON.stringify(bodyResult)
    : undefined

  return { url, method, headers, body }
}

/**
 * Execute the actual request and transform the response
 */
export async function executeRequest(
  toolId: string,
  tool: ToolConfig,
  requestParams: RequestParams
): Promise<ToolResponse> {
  try {
    const { url, method, headers, body } = requestParams

    // Log the request for debugging
    console.log(`Executing tool ${toolId}:`, { url, method })

    const externalResponse = await fetch(url, { method, headers, body })

    // Log response status
    console.log(`${toolId} response status:`, externalResponse.status, externalResponse.statusText)

    if (!externalResponse.ok) {
      let errorContent
      try {
        errorContent = await externalResponse.json()
      } catch (e) {
        errorContent = { message: externalResponse.statusText }
      }

      // Use the tool's error transformer or a default message
      const error = tool.transformError
        ? tool.transformError(errorContent)
        : errorContent.message || `${toolId} API error: ${externalResponse.statusText}`

      console.error(`${toolId} error:`, error)
      throw new Error(error)
    }

    const transformResponse =
      tool.transformResponse ||
      (async (resp: Response) => ({
        success: true,
        output: await resp.json(),
      }))

    return await transformResponse(externalResponse)
  } catch (error: any) {
    return {
      success: false,
      output: {},
      error: error.message || 'Unknown error',
    }
  }
}

/**
 * Validates the tool and its parameters
 */
export function validateToolRequest(
  toolId: string,
  tool: ToolConfig | undefined,
  params: Record<string, any>
): void {
  if (!tool) {
    throw new Error(`Tool not found: ${toolId}`)
  }

  // Ensure all required parameters for tool call are provided
  for (const [paramName, paramConfig] of Object.entries(tool.params)) {
    if (paramConfig.requiredForToolCall && !(paramName in params)) {
      throw new Error(`Parameter "${paramName}" is required for ${toolId} but was not provided`)
    }
  }
}
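Taken together, these three helpers are exactly what handleInternalRequest in tools/index.ts composes; a sketch with an invented tool id:

import { getTool } from '@/tools'
import { executeRequest, formatRequestParams, validateToolRequest } from '@/tools/utils'

const toolId = 'example_tool' // illustrative id
const params = { apiKey: 'sk-...', query: 'hello' }
const tool = getTool(toolId)

validateToolRequest(toolId, tool, params) // throws if the tool or a requiredForToolCall param is missing
const requestParams = formatRequestParams(tool!, params) // resolves url/method/headers/body
const result = await executeRequest(toolId, tool!, requestParams) // fetch + transformResponse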
Block a user