diff --git a/package-lock.json b/package-lock.json index c27c9ac84..b818dd8e0 100644 --- a/package-lock.json +++ b/package-lock.json @@ -33,6 +33,7 @@ "croner": "^9.0.0", "date-fns": "^3.6.0", "drizzle-orm": "^0.39.3", + "groq-sdk": "^0.15.0", "lodash.debounce": "^4.0.8", "lucide-react": "^0.469.0", "next": "^15.2.0", @@ -6419,6 +6420,36 @@ "dev": true, "license": "ISC" }, + "node_modules/groq-sdk": { + "version": "0.15.0", + "resolved": "https://registry.npmjs.org/groq-sdk/-/groq-sdk-0.15.0.tgz", + "integrity": "sha512-aYDEdr4qczx3cLCRRe+Beb37I7g/9bD5kHF+EEDxcrREWw1vKoRcfP3vHEkJB7Ud/8oOuF0scRwDpwWostTWuQ==", + "license": "Apache-2.0", + "dependencies": { + "@types/node": "^18.11.18", + "@types/node-fetch": "^2.6.4", + "abort-controller": "^3.0.0", + "agentkeepalive": "^4.2.1", + "form-data-encoder": "1.7.2", + "formdata-node": "^4.3.2", + "node-fetch": "^2.6.7" + } + }, + "node_modules/groq-sdk/node_modules/@types/node": { + "version": "18.19.79", + "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.79.tgz", + "integrity": "sha512-90K8Oayimbctc5zTPHPfZloc/lGVs7f3phUAAMcTgEPtg8kKquGZDERC8K4vkBYkQQh48msiYUslYtxTWvqcAg==", + "license": "MIT", + "dependencies": { + "undici-types": "~5.26.4" + } + }, + "node_modules/groq-sdk/node_modules/undici-types": { + "version": "5.26.5", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", + "license": "MIT" + }, "node_modules/has-flag": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", diff --git a/package.json b/package.json index 42ab1ab86..0f2d61aa2 100644 --- a/package.json +++ b/package.json @@ -42,6 +42,7 @@ "croner": "^9.0.0", "date-fns": "^3.6.0", "drizzle-orm": "^0.39.3", + "groq-sdk": "^0.15.0", "lodash.debounce": "^4.0.8", "lucide-react": "^0.469.0", "next": "^15.2.0", diff --git a/providers/groq/index.ts b/providers/groq/index.ts new file mode 100644 index 000000000..dab4092d7 --- /dev/null +++ b/providers/groq/index.ts @@ -0,0 +1,186 @@ +import { Groq } from 'groq-sdk' +import { executeTool } from '@/tools' +import { ProviderConfig, ProviderRequest, ProviderResponse } from '../types' + +export const groqProvider: ProviderConfig = { + id: 'groq', + name: 'Groq', + description: "Groq's LLM models with high-performance inference", + version: '1.0.0', + models: ['groq/llama-3.3-70b-specdec', 'groq/deepseek-r1-distill-llama-70b', 'groq/qwen-2.5-32b'], + defaultModel: 'groq/llama-3.3-70b-specdec', + + executeRequest: async (request: ProviderRequest): Promise => { + if (!request.apiKey) { + throw new Error('API key is required for Groq') + } + + const groq = new Groq({ + apiKey: request.apiKey, + dangerouslyAllowBrowser: true, + }) + + // Start with an empty array for all messages + const allMessages = [] + + // Add system prompt if present + if (request.systemPrompt) { + allMessages.push({ + role: 'system', + content: request.systemPrompt, + }) + } + + // Add context if present + if (request.context) { + allMessages.push({ + role: 'user', + content: request.context, + }) + } + + // Add remaining messages + if (request.messages) { + allMessages.push(...request.messages) + } + + // Transform tools to function format if provided + const tools = request.tools?.length + ? request.tools.map((tool) => ({ + type: 'function', + function: { + name: tool.id, + description: tool.description, + parameters: tool.parameters, + }, + })) + : undefined + + // Build the request payload + const payload: any = { + model: (request.model || 'groq/llama-3.3-70b-specdec').replace('groq/', ''), + messages: allMessages, + } + + // Add optional parameters + if (request.temperature !== undefined) payload.temperature = request.temperature + if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens + + // Add response format for structured output if specified + if (request.responseFormat) { + payload.response_format = { type: 'json_object' } + } + + // Add tools if provided + if (tools?.length) { + payload.tools = tools + payload.tool_choice = 'auto' + } + + // Make the initial API request + let currentResponse = await groq.chat.completions.create(payload) + let content = currentResponse.choices[0]?.message?.content || '' + let tokens = { + prompt: currentResponse.usage?.prompt_tokens || 0, + completion: currentResponse.usage?.completion_tokens || 0, + total: currentResponse.usage?.total_tokens || 0, + } + let toolCalls = [] + let toolResults = [] + let currentMessages = [...allMessages] + let iterationCount = 0 + const MAX_ITERATIONS = 10 // Prevent infinite loops + + try { + while (iterationCount < MAX_ITERATIONS) { + // Check for tool calls + const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls + if (!toolCallsInResponse || toolCallsInResponse.length === 0) { + break + } + + // Process each tool call + for (const toolCall of toolCallsInResponse) { + try { + const toolName = toolCall.function.name + const toolArgs = JSON.parse(toolCall.function.arguments) + + // Get the tool from the tools registry + const tool = request.tools?.find((t) => t.id === toolName) + if (!tool) continue + + // Execute the tool + const mergedArgs = { ...tool.params, ...toolArgs } + const result = await executeTool(toolName, mergedArgs) + + if (!result.success) continue + + toolResults.push(result.output) + toolCalls.push({ + name: toolName, + arguments: toolArgs, + }) + + // Add the tool call and result to messages + currentMessages.push({ + role: 'assistant', + content: null, + tool_calls: [ + { + id: toolCall.id, + type: 'function', + function: { + name: toolName, + arguments: toolCall.function.arguments, + }, + }, + ], + }) + + currentMessages.push({ + role: 'tool', + tool_call_id: toolCall.id, + content: JSON.stringify(result.output), + }) + } catch (error) { + console.error('Error processing tool call:', error) + } + } + + // Make the next request with updated messages + const nextPayload = { + ...payload, + messages: currentMessages, + } + + // Make the next request + currentResponse = await groq.chat.completions.create(nextPayload) + + // Update content if we have a text response + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + } + + // Update token counts + if (currentResponse.usage) { + tokens.prompt += currentResponse.usage.prompt_tokens || 0 + tokens.completion += currentResponse.usage.completion_tokens || 0 + tokens.total += currentResponse.usage.total_tokens || 0 + } + + iterationCount++ + } + } catch (error) { + console.error('Error in Groq request:', error) + throw error + } + + return { + content, + model: request.model, + tokens, + toolCalls: toolCalls.length > 0 ? toolCalls : undefined, + toolResults: toolResults.length > 0 ? toolResults : undefined, + } + }, +} diff --git a/providers/types.ts b/providers/types.ts index 81c46260c..4c6801177 100644 --- a/providers/types.ts +++ b/providers/types.ts @@ -1,4 +1,11 @@ -export type ProviderId = 'openai' | 'anthropic' | 'google' | 'deepseek' | 'xai' | 'cerebras' +export type ProviderId = + | 'openai' + | 'anthropic' + | 'google' + | 'deepseek' + | 'xai' + | 'cerebras' + | 'groq' export interface TokenInfo { prompt?: number diff --git a/providers/utils.ts b/providers/utils.ts index f1c9419d4..788dd67c1 100644 --- a/providers/utils.ts +++ b/providers/utils.ts @@ -3,6 +3,7 @@ import { anthropicProvider } from './anthropic' import { cerebrasProvider } from './cerebras' import { deepseekProvider } from './deepseek' import { googleProvider } from './google' +import { groqProvider } from './groq' import { openaiProvider } from './openai' import { ProviderConfig, ProviderId, ProviderToolConfig } from './types' import { xAIProvider } from './xai' @@ -45,7 +46,16 @@ export const providers: Record< cerebras: { ...cerebrasProvider, models: ['cerebras/llama-3.3-70b'], - modelPatterns: [/^cerebras\/llama/], + modelPatterns: [/^cerebras/], + }, + groq: { + ...groqProvider, + models: [ + 'groq/llama-3.3-70b-specdec', + 'groq/deepseek-r1-distill-llama-70b', + 'groq/qwen-2.5-32b', + ], + modelPatterns: [/^groq/], }, }