diff --git a/package-lock.json b/package-lock.json
index db911c92c8..a0ecb4b49b 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -9,6 +9,7 @@
       "version": "0.1.0",
       "dependencies": {
         "@anthropic-ai/sdk": "^0.38.0",
+        "@cerebras/cerebras_cloud_sdk": "^1.23.0",
         "@radix-ui/react-alert-dialog": "^1.1.5",
         "@radix-ui/react-checkbox": "^1.1.3",
         "@radix-ui/react-dialog": "^1.1.5",
@@ -677,6 +678,36 @@
       "resolved": "https://registry.npmjs.org/@better-fetch/fetch/-/fetch-1.1.12.tgz",
       "integrity": "sha512-B3bfloI/2UBQWIATRN6qmlORrvx3Mp0kkNjmXLv0b+DtbtR+pP4/I5kQA/rDUv+OReLywCCldf6co4LdDmh8JA=="
     },
+    "node_modules/@cerebras/cerebras_cloud_sdk": {
+      "version": "1.23.0",
+      "resolved": "https://registry.npmjs.org/@cerebras/cerebras_cloud_sdk/-/cerebras_cloud_sdk-1.23.0.tgz",
+      "integrity": "sha512-1krbmU4nTbJICUbcJGQGGo+MtB0nzHx/jwW24ZhoBzuC5QT8H/WzNjLdKtvdf3TB8GS1AtdWUkUHNJf1EZfvJA==",
+      "license": "Apache-2.0",
+      "dependencies": {
+        "@types/node": "^18.11.18",
+        "@types/node-fetch": "^2.6.4",
+        "abort-controller": "^3.0.0",
+        "agentkeepalive": "^4.2.1",
+        "form-data-encoder": "1.7.2",
+        "formdata-node": "^4.3.2",
+        "node-fetch": "^2.6.7"
+      }
+    },
+    "node_modules/@cerebras/cerebras_cloud_sdk/node_modules/@types/node": {
+      "version": "18.19.76",
+      "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.76.tgz",
+      "integrity": "sha512-yvR7Q9LdPz2vGpmpJX5LolrgRdWvB67MJKDPSgIIzpFbaf9a1j/f5DnLp5VDyHGMR0QZHlTr1afsD87QCXFHKw==",
+      "license": "MIT",
+      "dependencies": {
+        "undici-types": "~5.26.4"
+      }
+    },
+    "node_modules/@cerebras/cerebras_cloud_sdk/node_modules/undici-types": {
+      "version": "5.26.5",
+      "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
+      "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
+      "license": "MIT"
+    },
     "node_modules/@drizzle-team/brocli": {
       "version": "0.10.2",
       "resolved": "https://registry.npmjs.org/@drizzle-team/brocli/-/brocli-0.10.2.tgz",
diff --git a/package.json b/package.json
index 7541a76b0d..44799ccbfc 100644
--- a/package.json
+++ b/package.json
@@ -18,6 +18,7 @@
   },
   "dependencies": {
     "@anthropic-ai/sdk": "^0.38.0",
+    "@cerebras/cerebras_cloud_sdk": "^1.23.0",
     "@radix-ui/react-alert-dialog": "^1.1.5",
     "@radix-ui/react-checkbox": "^1.1.3",
     "@radix-ui/react-dialog": "^1.1.5",
diff --git a/providers/cerebras/index.ts b/providers/cerebras/index.ts
index a2be7f0fef..6228380f82 100644
--- a/providers/cerebras/index.ts
+++ b/providers/cerebras/index.ts
@@ -1,4 +1,4 @@
-import OpenAI from 'openai'
+import { Cerebras } from '@cerebras/cerebras_cloud_sdk'
 import { executeTool } from '@/tools'
 import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'
 
@@ -7,188 +7,240 @@ export const cerebrasProvider: ProviderConfig = {
   name: 'Cerebras',
   description: 'Cerebras Cloud LLMs',
   version: '1.0.0',
-  models: ['llama-3.3-70b'],
-  defaultModel: 'llama-3.3-70b',
-
+  models: ['cerebras/llama-3.3-70b'],
+  defaultModel: 'cerebras/llama-3.3-70b',
   executeRequest: async (request: ProviderRequest): Promise<ProviderResponse> => {
     if (!request.apiKey) {
       throw new Error('API key is required for Cerebras')
     }
-    const openai = new OpenAI({
-      apiKey: request.apiKey,
-      baseURL: 'https://api.cerebras.ai/v1',
-      dangerouslyAllowBrowser: true,
-    })
-
-    // Start with an empty array for all messages
-    const allMessages = []
-
-    // Add system prompt if present
-    if (request.systemPrompt) {
-      allMessages.push({
-        role: 'system',
-        content: request.systemPrompt,
-      })
-    }
-
-    // Add context if present
-    if (request.context) {
-      allMessages.push({
-        role: 'user',
-        content: request.context,
-      })
-    }
-
-    // Add remaining messages
-    if (request.messages) {
-      allMessages.push(...request.messages)
-    }
-
-    // Transform tools to OpenAI format if provided
-    const tools = request.tools?.length
-      ? request.tools.map((tool) => ({
-          type: 'function',
-          function: {
-            name: tool.id,
-            description: tool.description,
-            parameters: tool.parameters,
-          },
-        }))
-      : undefined
-
-    // Build the request payload
-    const payload: any = {
-      model: request.model || 'llama-3.3-70b',
-      messages: allMessages,
-    }
-
-    // Add optional parameters
-    if (request.temperature !== undefined) payload.temperature = request.temperature
-    if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
-
-    // Add response format for structured output if specified
-    if (request.responseFormat) {
-      payload.response_format = { type: 'json_object' }
-    }
-
-    // Add tools if provided
-    if (tools?.length) {
-      payload.tools = tools
-      payload.tool_choice = 'auto'
-    }
-
-    // Add local execution flag if specified by Cerebras
-    if (request.local_execution) {
-      payload.local_execution = true
-    }
-
-    // Make the initial API request
-    let currentResponse = await openai.chat.completions.create(payload)
-    let content = currentResponse.choices[0]?.message?.content || ''
-    let tokens = {
-      prompt: currentResponse.usage?.prompt_tokens || 0,
-      completion: currentResponse.usage?.completion_tokens || 0,
-      total: currentResponse.usage?.total_tokens || 0,
-    }
-    let toolCalls = []
-    let toolResults = []
-    let currentMessages = [...allMessages]
-    let iterationCount = 0
-    const MAX_ITERATIONS = 10 // Prevent infinite loops
-
-    try {
-      while (iterationCount < MAX_ITERATIONS) {
-        // Check for tool calls
-        const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
-        if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
-          break
-        }
-
-        // Process each tool call
-        for (const toolCall of toolCallsInResponse) {
-          try {
-            const toolName = toolCall.function.name
-            const toolArgs = JSON.parse(toolCall.function.arguments)
-
-            // Get the tool from the tools registry
-            const tool = request.tools?.find((t) => t.id === toolName)
-            if (!tool) continue
-
-            // Execute the tool
-            const mergedArgs = { ...tool.params, ...toolArgs }
-            const result = await executeTool(toolName, mergedArgs, true)
-
-            if (!result.success) continue
-
-            toolResults.push(result.output)
-            toolCalls.push({
-              name: toolName,
-              arguments: toolArgs,
-            })
-
-            // Add the tool call and result to messages
-            currentMessages.push({
-              role: 'assistant',
-              content: null,
-              tool_calls: [
-                {
-                  id: toolCall.id,
-                  type: 'function',
-                  function: {
-                    name: toolName,
-                    arguments: toolCall.function.arguments,
-                  },
-                },
-              ],
-            })
-
-            const toolResultContent = JSON.stringify(result.output)
-
-            currentMessages.push({
-              role: 'tool',
-              tool_call_id: toolCall.id,
-              content: toolResultContent,
-            })
-          } catch (error) {
-            console.error('Error processing tool call:', error)
-          }
-        }
-
-        // Make the next request with updated messages
-        const nextPayload = {
-          ...payload,
-          messages: currentMessages,
-        }
-
-        // Make the next request
-        currentResponse = await openai.chat.completions.create(nextPayload)
-
-        // Update content if we have a text response
-        if (currentResponse.choices[0]?.message?.content) {
-          content = currentResponse.choices[0].message.content
-        }
-
-        // Update token counts
-        if (currentResponse.usage) {
-          tokens.prompt += currentResponse.usage.prompt_tokens || 0
-          tokens.completion += currentResponse.usage.completion_tokens || 0
-          tokens.total += currentResponse.usage.total_tokens || 0
-        }
-
-        iterationCount++
-      }
-    } catch (error) {
-      console.error('Error in Cerebras request:', error)
-      throw error
-    }
-
-    return {
-      content,
-      model: request.model,
-      tokens,
-      toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
-      toolResults: toolResults.length > 0 ? toolResults : undefined,
-    }
+    const client = new Cerebras({
+      apiKey: request.apiKey,
+    })
+
+    // Start with an empty array for all messages
+    const allMessages = []
+
+    // Add system prompt if present
+    if (request.systemPrompt) {
+      allMessages.push({
+        role: 'system',
+        content: request.systemPrompt,
+      })
+    }
+
+    // Add context if present
+    if (request.context) {
+      allMessages.push({
+        role: 'user',
+        content: request.context,
+      })
+    }
+
+    // Add remaining messages
+    if (request.messages) {
+      allMessages.push(...request.messages)
+    }
+
+    // Transform tools to Cerebras format if provided
+    const tools = request.tools?.length
+      ? request.tools.map((tool) => ({
+          type: 'function',
+          function: {
+            name: tool.id,
+            description: tool.description,
+            parameters: tool.parameters,
+          },
+        }))
+      : undefined
+
+    // Build the request payload
+    const payload: any = {
+      model: (request.model || 'cerebras/llama-3.3-70b').replace('cerebras/', ''),
+      messages: allMessages,
+    }
+
+    // Add optional parameters
+    if (request.temperature !== undefined) payload.temperature = request.temperature
+    if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
+
+    // Add response format for structured output if specified
+    if (request.responseFormat) {
+      payload.response_format = { type: 'json_object' }
+    }
+
+    // Add tools if provided
+    if (tools?.length) {
+      payload.tools = tools
+      payload.tool_choice = 'auto'
+    }
+
+    // Make the initial API request
+    let currentResponse = (await client.chat.completions.create(payload)) as CerebrasResponse
+
+    let content = currentResponse.choices[0]?.message?.content || ''
+    let tokens = {
+      prompt: currentResponse.usage?.prompt_tokens || 0,
+      completion: currentResponse.usage?.completion_tokens || 0,
+      total: currentResponse.usage?.total_tokens || 0,
+    }
+    let toolCalls = []
+    let toolResults = []
+    let currentMessages = [...allMessages]
+    let iterationCount = 0
+    const MAX_ITERATIONS = 10 // Prevent infinite loops
+
+    // Keep track of processed tool calls to avoid duplicates
+    const processedToolCallIds = new Set()
+    // Keep track of tool call signatures to detect repeats
+    const toolCallSignatures = new Set()
+
+    try {
+      while (iterationCount < MAX_ITERATIONS) {
+        // Check for tool calls
+        const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
+
+        // Break if no tool calls
+        if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
+          if (currentResponse.choices[0]?.message?.content) {
+            content = currentResponse.choices[0].message.content
+          }
+          break
+        }
+
+        // Process each tool call
+        let processedAnyToolCall = false
+        let hasRepeatedToolCalls = false
+
+        for (const toolCall of toolCallsInResponse) {
+          // Skip if we've already processed this tool call
+          if (processedToolCallIds.has(toolCall.id)) {
+            continue
+          }
+
+          // Create a signature for this tool call to detect repeats
+          const toolCallSignature = `${toolCall.function.name}-${toolCall.function.arguments}`
+          if (toolCallSignatures.has(toolCallSignature)) {
+            hasRepeatedToolCalls = true
+            continue
+          }
+
+          try {
+            processedToolCallIds.add(toolCall.id)
+            toolCallSignatures.add(toolCallSignature)
+            processedAnyToolCall = true
+
+            const toolName = toolCall.function.name
+            const toolArgs = JSON.parse(toolCall.function.arguments)
+
+            // Get the tool from the tools registry
+            const tool = request.tools?.find((t) => t.id === toolName)
+            if (!tool) continue
+
+            // Execute the tool
+            const mergedArgs = { ...tool.params, ...toolArgs }
+            const result = await executeTool(toolName, mergedArgs, true)
+
+            if (!result.success) continue
+
+            toolResults.push(result.output)
+            toolCalls.push({
+              name: toolName,
+              arguments: toolArgs,
+            })
+
+            // Add the tool call and result to messages
+            currentMessages.push({
+              role: 'assistant',
+              content: null,
+              tool_calls: [
+                {
+                  id: toolCall.id,
+                  type: 'function',
+                  function: {
+                    name: toolName,
+                    arguments: toolCall.function.arguments,
+                  },
+                },
+              ],
+            })
+
+            currentMessages.push({
+              role: 'tool',
+              tool_call_id: toolCall.id,
+              content: JSON.stringify(result.output),
+            })
+          } catch (error) {
+            console.error('Error processing tool call:', error)
+          }
+        }
+
+        // After processing tool calls, get a final response
+        if (processedAnyToolCall || hasRepeatedToolCalls) {
+          // Make the final request
+          const finalPayload = {
+            ...payload,
+            messages: currentMessages,
+            tool_choice: 'none',
+          }
+
+          const finalResponse = (await client.chat.completions.create(
+            finalPayload
+          )) as CerebrasResponse
+
+          if (finalResponse.choices[0]?.message?.content) {
+            content = finalResponse.choices[0].message.content
+          }
+
+          // Update final token counts
+          if (finalResponse.usage) {
+            tokens.prompt += finalResponse.usage.prompt_tokens || 0
+            tokens.completion += finalResponse.usage.completion_tokens || 0
+            tokens.total += finalResponse.usage.total_tokens || 0
+          }
+
+          break
+        }
+
+        // Only continue if we haven't processed any tool calls and haven't seen repeats
+        if (!processedAnyToolCall && !hasRepeatedToolCalls) {
+          // Make the next request with updated messages
+          const nextPayload = {
+            ...payload,
+            messages: currentMessages,
+          }
+
+          // Make the next request
+          currentResponse = (await client.chat.completions.create(
+            nextPayload
+          )) as CerebrasResponse
+
+          // Update token counts
+          if (currentResponse.usage) {
+            tokens.prompt += currentResponse.usage.prompt_tokens || 0
+            tokens.completion += currentResponse.usage.completion_tokens || 0
+            tokens.total += currentResponse.usage.total_tokens || 0
+          }
+
+          iterationCount++
+        }
+      }
+    } catch (error) {
+      console.error('Error in Cerebras tool processing:', error)
+      // Don't throw here, return what we have so far
+    }
+
+    return {
+      content,
+      model: request.model,
+      tokens,
+      toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
+      toolResults: toolResults.length > 0 ? toolResults : undefined,
+    }
   },
 }
diff --git a/providers/cerebras/types.ts b/providers/cerebras/types.ts
new file mode 100644
index 0000000000..085687941c
--- /dev/null
+++ b/providers/cerebras/types.ts
@@ -0,0 +1,34 @@
+interface CerebrasMessage {
+  role: string
+  content: string | null
+  tool_calls?: Array<{
+    id: string
+    type: 'function'
+    function: {
+      name: string
+      arguments: string
+    }
+  }>
+  tool_call_id?: string
+}
+
+interface CerebrasChoice {
+  message: CerebrasMessage
+  index: number
+  finish_reason: string
+}
+
+interface CerebrasUsage {
+  prompt_tokens: number
+  completion_tokens: number
+  total_tokens: number
+}
+
+interface CerebrasResponse {
+  id: string
+  object: string
+  created: number
+  model: string
+  choices: CerebrasChoice[]
+  usage: CerebrasUsage
+}
diff --git a/providers/utils.ts b/providers/utils.ts
index de1f6b228a..2034f46776 100644
--- a/providers/utils.ts
+++ b/providers/utils.ts
@@ -44,8 +44,8 @@ export const providers: Record<
   },
   cerebras: {
     ...cerebrasProvider,
-    models: ['llama-3.3-70b'],
-    modelPatterns: [/^llama/],
+    models: ['cerebras/llama-3.3-70b'],
+    modelPatterns: [/^cerebras\/llama/],
   },
 }
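For reviewers who want to exercise the new provider locally, here is a minimal, hypothetical usage sketch (not part of the diff). It assumes the ProviderRequest fields that appear in the diff (apiKey, model, systemPrompt, messages, temperature, maxTokens) and an '@/providers/cerebras' import alias; the real request and response shapes live in providers/types.ts.

// Hypothetical usage sketch -- field names follow the ProviderRequest usage visible above.
import { cerebrasProvider } from '@/providers/cerebras'

async function main() {
  const response = await cerebrasProvider.executeRequest({
    apiKey: process.env.CEREBRAS_API_KEY ?? '', // executeRequest throws if this is missing
    model: 'cerebras/llama-3.3-70b', // prefixed id; the 'cerebras/' prefix is stripped before the API call
    systemPrompt: 'You are a helpful assistant.',
    messages: [{ role: 'user', content: 'Hello!' }],
    temperature: 0.7,
    maxTokens: 256,
  })

  console.log(response.content) // final text after any tool-call iterations
  console.log(response.tokens) // { prompt, completion, total } aggregated across requests
}

main().catch(console.error)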