From 6114c213d2a9b15336312fbb40f004a02c9561bc Mon Sep 17 00:00:00 2001 From: Waleed Date: Sat, 22 Nov 2025 14:50:43 -0800 Subject: [PATCH] feature(models): added vllm provider (#2103) * Add vLLM self-hosted provider * updated vllm to have pull parity with openai, dynamically fetch models --------- Co-authored-by: MagellaX --- .../components/integrations/integrations.tsx | 2 +- apps/sim/app/.env.example | 4 +- .../app/api/providers/vllm/models/route.ts | 56 ++ .../providers/provider-models-loader.tsx | 9 +- apps/sim/blocks/blocks/agent.ts | 15 +- apps/sim/components/icons.tsx | 10 + apps/sim/hooks/queries/providers.ts | 1 + apps/sim/lib/env.ts | 2 + apps/sim/providers/models.ts | 29 + apps/sim/providers/types.ts | 1 + apps/sim/providers/utils.ts | 17 +- apps/sim/providers/vllm/index.ts | 635 ++++++++++++++++++ apps/sim/stores/providers/store.ts | 1 + apps/sim/stores/providers/types.ts | 2 +- 14 files changed, 775 insertions(+), 9 deletions(-) create mode 100644 apps/sim/app/api/providers/vllm/models/route.ts create mode 100644 apps/sim/providers/vllm/index.ts diff --git a/apps/sim/app/(landing)/components/integrations/integrations.tsx b/apps/sim/app/(landing)/components/integrations/integrations.tsx index 5da1786bb..04e9487a2 100644 --- a/apps/sim/app/(landing)/components/integrations/integrations.tsx +++ b/apps/sim/app/(landing)/components/integrations/integrations.tsx @@ -13,6 +13,7 @@ const modelProviderIcons = [ { icon: Icons.OllamaIcon, label: 'Ollama' }, { icon: Icons.DeepseekIcon, label: 'Deepseek' }, { icon: Icons.ElevenLabsIcon, label: 'ElevenLabs' }, + { icon: Icons.VllmIcon, label: 'vLLM' }, ] const communicationIcons = [ @@ -88,7 +89,6 @@ interface TickerRowProps { } function TickerRow({ direction, offset, showOdd, icons }: TickerRowProps) { - // Create multiple copies of the icons array for seamless looping const extendedIcons = [...icons, ...icons, ...icons, ...icons] return ( diff --git a/apps/sim/app/.env.example b/apps/sim/app/.env.example index 0f9db4fe6..2c390dac1 100644 --- a/apps/sim/app/.env.example +++ b/apps/sim/app/.env.example @@ -20,4 +20,6 @@ INTERNAL_API_SECRET=your_internal_api_secret # Use `openssl rand -hex 32` to gen # If left commented out, emails will be logged to console instead # Local AI Models (Optional) -# OLLAMA_URL=http://localhost:11434 # URL for local Ollama server - uncomment if using local models \ No newline at end of file +# OLLAMA_URL=http://localhost:11434 # URL for local Ollama server - uncomment if using local models +# VLLM_BASE_URL=http://localhost:8000 # Base URL for your self-hosted vLLM (OpenAI-compatible) +# VLLM_API_KEY= # Optional bearer token if your vLLM instance requires auth diff --git a/apps/sim/app/api/providers/vllm/models/route.ts b/apps/sim/app/api/providers/vllm/models/route.ts new file mode 100644 index 000000000..71c4dd04a --- /dev/null +++ b/apps/sim/app/api/providers/vllm/models/route.ts @@ -0,0 +1,56 @@ +import { type NextRequest, NextResponse } from 'next/server' +import { env } from '@/lib/env' +import { createLogger } from '@/lib/logs/console/logger' + +const logger = createLogger('VLLMModelsAPI') + +/** + * Get available vLLM models + */ +export async function GET(request: NextRequest) { + const baseUrl = (env.VLLM_BASE_URL || '').replace(/\/$/, '') + + if (!baseUrl) { + logger.info('VLLM_BASE_URL not configured') + return NextResponse.json({ models: [] }) + } + + try { + logger.info('Fetching vLLM models', { + baseUrl, + }) + + const response = await fetch(`${baseUrl}/v1/models`, { + headers: { + 'Content-Type': 
'application/json', + }, + next: { revalidate: 60 }, + }) + + if (!response.ok) { + logger.warn('vLLM service is not available', { + status: response.status, + statusText: response.statusText, + }) + return NextResponse.json({ models: [] }) + } + + const data = (await response.json()) as { data: Array<{ id: string }> } + const models = data.data.map((model) => `vllm/${model.id}`) + + logger.info('Successfully fetched vLLM models', { + count: models.length, + models, + }) + + return NextResponse.json({ models }) + } catch (error) { + logger.error('Failed to fetch vLLM models', { + error: error instanceof Error ? error.message : 'Unknown error', + baseUrl, + }) + + // Return empty array instead of error to avoid breaking the UI + return NextResponse.json({ models: [] }) + } +} diff --git a/apps/sim/app/workspace/[workspaceId]/providers/provider-models-loader.tsx b/apps/sim/app/workspace/[workspaceId]/providers/provider-models-loader.tsx index ab7506321..478492def 100644 --- a/apps/sim/app/workspace/[workspaceId]/providers/provider-models-loader.tsx +++ b/apps/sim/app/workspace/[workspaceId]/providers/provider-models-loader.tsx @@ -3,7 +3,11 @@ import { useEffect } from 'react' import { createLogger } from '@/lib/logs/console/logger' import { useProviderModels } from '@/hooks/queries/providers' -import { updateOllamaProviderModels, updateOpenRouterProviderModels } from '@/providers/utils' +import { + updateOllamaProviderModels, + updateOpenRouterProviderModels, + updateVLLMProviderModels, +} from '@/providers/utils' import { useProvidersStore } from '@/stores/providers/store' import type { ProviderName } from '@/stores/providers/types' @@ -24,6 +28,8 @@ function useSyncProvider(provider: ProviderName) { try { if (provider === 'ollama') { updateOllamaProviderModels(data) + } else if (provider === 'vllm') { + updateVLLMProviderModels(data) } else if (provider === 'openrouter') { void updateOpenRouterProviderModels(data) } @@ -44,6 +50,7 @@ function useSyncProvider(provider: ProviderName) { export function ProviderModelsLoader() { useSyncProvider('base') useSyncProvider('ollama') + useSyncProvider('vllm') useSyncProvider('openrouter') return null } diff --git a/apps/sim/blocks/blocks/agent.ts b/apps/sim/blocks/blocks/agent.ts index 7cd45cc36..7ade8d876 100644 --- a/apps/sim/blocks/blocks/agent.ts +++ b/apps/sim/blocks/blocks/agent.ts @@ -18,6 +18,10 @@ const getCurrentOllamaModels = () => { return useProvidersStore.getState().providers.ollama.models } +const getCurrentVLLMModels = () => { + return useProvidersStore.getState().providers.vllm.models +} + import { useProvidersStore } from '@/stores/providers/store' import type { ToolResponse } from '@/tools/types' @@ -90,8 +94,11 @@ export const AgentBlock: BlockConfig = { const providersState = useProvidersStore.getState() const baseModels = providersState.providers.base.models const ollamaModels = providersState.providers.ollama.models + const vllmModels = providersState.providers.vllm.models const openrouterModels = providersState.providers.openrouter.models - const allModels = Array.from(new Set([...baseModels, ...ollamaModels, ...openrouterModels])) + const allModels = Array.from( + new Set([...baseModels, ...ollamaModels, ...vllmModels, ...openrouterModels]) + ) return allModels.map((model) => { const icon = getProviderIcon(model) @@ -172,7 +179,7 @@ export const AgentBlock: BlockConfig = { password: true, connectionDroppable: false, required: true, - // Hide API key for hosted models and Ollama models + // Hide API key for hosted models, 
Ollama models, and vLLM models condition: isHosted ? { field: 'model', @@ -181,8 +188,8 @@ export const AgentBlock: BlockConfig = { } : () => ({ field: 'model', - value: getCurrentOllamaModels(), - not: true, // Show for all models EXCEPT Ollama models + value: [...getCurrentOllamaModels(), ...getCurrentVLLMModels()], + not: true, // Show for all models EXCEPT Ollama and vLLM models }), }, { diff --git a/apps/sim/components/icons.tsx b/apps/sim/components/icons.tsx index b4dd059b7..fb3bf4cf6 100644 --- a/apps/sim/components/icons.tsx +++ b/apps/sim/components/icons.tsx @@ -4150,3 +4150,13 @@ export function VideoIcon(props: SVGProps) { ) } + +export function VllmIcon(props: SVGProps) { + return ( + + vLLM + + + + ) +} diff --git a/apps/sim/hooks/queries/providers.ts b/apps/sim/hooks/queries/providers.ts index 21455e351..b99e98f7c 100644 --- a/apps/sim/hooks/queries/providers.ts +++ b/apps/sim/hooks/queries/providers.ts @@ -7,6 +7,7 @@ const logger = createLogger('ProviderModelsQuery') const providerEndpoints: Record = { base: '/api/providers/base/models', ollama: '/api/providers/ollama/models', + vllm: '/api/providers/vllm/models', openrouter: '/api/providers/openrouter/models', } diff --git a/apps/sim/lib/env.ts b/apps/sim/lib/env.ts index 1f58a44ba..d870fae28 100644 --- a/apps/sim/lib/env.ts +++ b/apps/sim/lib/env.ts @@ -77,6 +77,8 @@ export const env = createEnv({ ANTHROPIC_API_KEY_2: z.string().min(1).optional(), // Additional Anthropic API key for load balancing ANTHROPIC_API_KEY_3: z.string().min(1).optional(), // Additional Anthropic API key for load balancing OLLAMA_URL: z.string().url().optional(), // Ollama local LLM server URL + VLLM_BASE_URL: z.string().url().optional(), // vLLM self-hosted base URL (OpenAI-compatible) + VLLM_API_KEY: z.string().optional(), // Optional bearer token for vLLM ELEVENLABS_API_KEY: z.string().min(1).optional(), // ElevenLabs API key for text-to-speech in deployed chat SERPER_API_KEY: z.string().min(1).optional(), // Serper API key for online search EXA_API_KEY: z.string().min(1).optional(), // Exa AI API key for enhanced online search diff --git a/apps/sim/providers/models.ts b/apps/sim/providers/models.ts index afe2a8719..a9eefb23f 100644 --- a/apps/sim/providers/models.ts +++ b/apps/sim/providers/models.ts @@ -19,6 +19,7 @@ import { OllamaIcon, OpenAIIcon, OpenRouterIcon, + VllmIcon, xAIIcon, } from '@/components/icons' @@ -82,6 +83,19 @@ export const PROVIDER_DEFINITIONS: Record = { contextInformationAvailable: false, models: [], }, + vllm: { + id: 'vllm', + name: 'vLLM', + icon: VllmIcon, + description: 'Self-hosted vLLM with an OpenAI-compatible API', + defaultModel: 'vllm/generic', + modelPatterns: [/^vllm\//], + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + models: [], + }, openai: { id: 'openai', name: 'OpenAI', @@ -1366,6 +1380,21 @@ export function updateOllamaModels(models: string[]): void { })) } +/** + * Update vLLM models dynamically + */ +export function updateVLLMModels(models: string[]): void { + PROVIDER_DEFINITIONS.vllm.models = models.map((modelId) => ({ + id: modelId, + pricing: { + input: 0, + output: 0, + updatedAt: new Date().toISOString().split('T')[0], + }, + capabilities: {}, + })) +} + /** * Update OpenRouter models dynamically */ diff --git a/apps/sim/providers/types.ts b/apps/sim/providers/types.ts index 000f7781f..6c2fd1f00 100644 --- a/apps/sim/providers/types.ts +++ b/apps/sim/providers/types.ts @@ -12,6 +12,7 @@ export type ProviderId = | 'mistral' | 'ollama' | 'openrouter' + | 
'vllm' /** * Model pricing information per million tokens diff --git a/apps/sim/providers/utils.ts b/apps/sim/providers/utils.ts index 4f0c6f58c..3380f6fde 100644 --- a/apps/sim/providers/utils.ts +++ b/apps/sim/providers/utils.ts @@ -30,6 +30,7 @@ import { ollamaProvider } from '@/providers/ollama' import { openaiProvider } from '@/providers/openai' import { openRouterProvider } from '@/providers/openrouter' import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types' +import { vllmProvider } from '@/providers/vllm' import { xAIProvider } from '@/providers/xai' import { useCustomToolsStore } from '@/stores/custom-tools/store' import { useProvidersStore } from '@/stores/providers/store' @@ -86,6 +87,11 @@ export const providers: Record< models: getProviderModelsFromDefinitions('groq'), modelPatterns: PROVIDER_DEFINITIONS.groq.modelPatterns, }, + vllm: { + ...vllmProvider, + models: getProviderModelsFromDefinitions('vllm'), + modelPatterns: PROVIDER_DEFINITIONS.vllm.modelPatterns, + }, mistral: { ...mistralProvider, models: getProviderModelsFromDefinitions('mistral'), @@ -123,6 +129,12 @@ export function updateOllamaProviderModels(models: string[]): void { providers.ollama.models = getProviderModelsFromDefinitions('ollama') } +export function updateVLLMProviderModels(models: string[]): void { + const { updateVLLMModels } = require('@/providers/models') + updateVLLMModels(models) + providers.vllm.models = getProviderModelsFromDefinitions('vllm') +} + export async function updateOpenRouterProviderModels(models: string[]): Promise { const { updateOpenRouterModels } = await import('@/providers/models') updateOpenRouterModels(models) @@ -131,7 +143,10 @@ export async function updateOpenRouterProviderModels(models: string[]): Promise< export function getBaseModelProviders(): Record { const allProviders = Object.entries(providers) - .filter(([providerId]) => providerId !== 'ollama' && providerId !== 'openrouter') + .filter( + ([providerId]) => + providerId !== 'ollama' && providerId !== 'vllm' && providerId !== 'openrouter' + ) .reduce( (map, [providerId, config]) => { config.models.forEach((model) => { diff --git a/apps/sim/providers/vllm/index.ts b/apps/sim/providers/vllm/index.ts new file mode 100644 index 000000000..7a9ee6afe --- /dev/null +++ b/apps/sim/providers/vllm/index.ts @@ -0,0 +1,635 @@ +import OpenAI from 'openai' +import { env } from '@/lib/env' +import { createLogger } from '@/lib/logs/console/logger' +import type { StreamingExecution } from '@/executor/types' +import { getProviderDefaultModel, getProviderModels } from '@/providers/models' +import type { + ProviderConfig, + ProviderRequest, + ProviderResponse, + TimeSegment, +} from '@/providers/types' +import { + prepareToolExecution, + prepareToolsWithUsageControl, + trackForcedToolUsage, +} from '@/providers/utils' +import { useProvidersStore } from '@/stores/providers/store' +import { executeTool } from '@/tools' + +const logger = createLogger('VLLMProvider') +const VLLM_VERSION = '1.0.0' + +/** + * Helper function to convert a vLLM stream to a standard ReadableStream + * and collect completion metrics + */ +function createReadableStreamFromVLLMStream( + vllmStream: any, + onComplete?: (content: string, usage?: any) => void +): ReadableStream { + let fullContent = '' + let usageData: any = null + + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of vllmStream) { + if (chunk.usage) { + usageData = chunk.usage + } + + const content = 
chunk.choices[0]?.delta?.content || '' + if (content) { + fullContent += content + controller.enqueue(new TextEncoder().encode(content)) + } + } + + if (onComplete) { + onComplete(fullContent, usageData) + } + + controller.close() + } catch (error) { + controller.error(error) + } + }, + }) +} + +export const vllmProvider: ProviderConfig = { + id: 'vllm', + name: 'vLLM', + description: 'Self-hosted vLLM with OpenAI-compatible API', + version: VLLM_VERSION, + models: getProviderModels('vllm'), + defaultModel: getProviderDefaultModel('vllm'), + + async initialize() { + if (typeof window !== 'undefined') { + logger.info('Skipping vLLM initialization on client side to avoid CORS issues') + return + } + + const baseUrl = (env.VLLM_BASE_URL || '').replace(/\/$/, '') + if (!baseUrl) { + logger.info('VLLM_BASE_URL not configured, skipping initialization') + return + } + + try { + const response = await fetch(`${baseUrl}/v1/models`) + if (!response.ok) { + useProvidersStore.getState().setProviderModels('vllm', []) + logger.warn('vLLM service is not available. The provider will be disabled.') + return + } + + const data = (await response.json()) as { data: Array<{ id: string }> } + const models = data.data.map((model) => `vllm/${model.id}`) + + this.models = models + useProvidersStore.getState().setProviderModels('vllm', models) + + logger.info(`Discovered ${models.length} vLLM model(s):`, { models }) + } catch (error) { + logger.warn('vLLM model instantiation failed. The provider will be disabled.', { + error: error instanceof Error ? error.message : 'Unknown error', + }) + } + }, + + executeRequest: async ( + request: ProviderRequest + ): Promise => { + logger.info('Preparing vLLM request', { + model: request.model, + hasSystemPrompt: !!request.systemPrompt, + hasMessages: !!request.messages?.length, + hasTools: !!request.tools?.length, + toolCount: request.tools?.length || 0, + hasResponseFormat: !!request.responseFormat, + stream: !!request.stream, + }) + + const baseUrl = (request.azureEndpoint || env.VLLM_BASE_URL || '').replace(/\/$/, '') + if (!baseUrl) { + throw new Error('VLLM_BASE_URL is required for vLLM provider') + } + + const apiKey = request.apiKey || env.VLLM_API_KEY || 'empty' + const vllm = new OpenAI({ + apiKey, + baseURL: `${baseUrl}/v1`, + }) + + const allMessages = [] as any[] + + if (request.systemPrompt) { + allMessages.push({ + role: 'system', + content: request.systemPrompt, + }) + } + + if (request.context) { + allMessages.push({ + role: 'user', + content: request.context, + }) + } + + if (request.messages) { + allMessages.push(...request.messages) + } + + const tools = request.tools?.length + ? 
request.tools.map((tool) => ({ + type: 'function', + function: { + name: tool.id, + description: tool.description, + parameters: tool.parameters, + }, + })) + : undefined + + const payload: any = { + model: (request.model || getProviderDefaultModel('vllm')).replace(/^vllm\//, ''), + messages: allMessages, + } + + if (request.temperature !== undefined) payload.temperature = request.temperature + if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens + + if (request.responseFormat) { + payload.response_format = { + type: 'json_schema', + json_schema: { + name: request.responseFormat.name || 'response_schema', + schema: request.responseFormat.schema || request.responseFormat, + strict: request.responseFormat.strict !== false, + }, + } + + logger.info('Added JSON schema response format to vLLM request') + } + + let preparedTools: ReturnType | null = null + let hasActiveTools = false + + if (tools?.length) { + preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'vllm') + const { tools: filteredTools, toolChoice } = preparedTools + + if (filteredTools?.length && toolChoice) { + payload.tools = filteredTools + payload.tool_choice = toolChoice + hasActiveTools = true + + logger.info('vLLM request configuration:', { + toolCount: filteredTools.length, + toolChoice: + typeof toolChoice === 'string' + ? toolChoice + : toolChoice.type === 'function' + ? `force:${toolChoice.function.name}` + : 'unknown', + model: payload.model, + }) + } + } + + const providerStartTime = Date.now() + const providerStartTimeISO = new Date(providerStartTime).toISOString() + + try { + if (request.stream && (!tools || tools.length === 0 || !hasActiveTools)) { + logger.info('Using streaming response for vLLM request') + + const streamResponse = await vllm.chat.completions.create({ + ...payload, + stream: true, + stream_options: { include_usage: true }, + }) + + const tokenUsage = { + prompt: 0, + completion: 0, + total: 0, + } + + const streamingResult = { + stream: createReadableStreamFromVLLMStream(streamResponse, (content, usage) => { + let cleanContent = content + if (cleanContent && request.responseFormat) { + cleanContent = cleanContent.replace(/```json\n?|\n?```/g, '').trim() + } + + streamingResult.execution.output.content = cleanContent + + const streamEndTime = Date.now() + const streamEndTimeISO = new Date(streamEndTime).toISOString() + + if (streamingResult.execution.output.providerTiming) { + streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO + streamingResult.execution.output.providerTiming.duration = + streamEndTime - providerStartTime + + if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) { + streamingResult.execution.output.providerTiming.timeSegments[0].endTime = + streamEndTime + streamingResult.execution.output.providerTiming.timeSegments[0].duration = + streamEndTime - providerStartTime + } + } + + if (usage) { + const newTokens = { + prompt: usage.prompt_tokens || tokenUsage.prompt, + completion: usage.completion_tokens || tokenUsage.completion, + total: usage.total_tokens || tokenUsage.total, + } + + streamingResult.execution.output.tokens = newTokens + } + }), + execution: { + success: true, + output: { + content: '', + model: request.model, + tokens: tokenUsage, + toolCalls: undefined, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + timeSegments: [ + { + type: 'model', + name: 'Streaming response', + startTime: 
providerStartTime, + endTime: Date.now(), + duration: Date.now() - providerStartTime, + }, + ], + }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + }, + } as StreamingExecution + + return streamingResult as StreamingExecution + } + + const initialCallTime = Date.now() + + const originalToolChoice = payload.tool_choice + + const forcedTools = preparedTools?.forcedTools || [] + let usedForcedTools: string[] = [] + + const checkForForcedToolUsage = ( + response: any, + toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any } + ) => { + if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) { + const toolCallsResponse = response.choices[0].message.tool_calls + const result = trackForcedToolUsage( + toolCallsResponse, + toolChoice, + logger, + 'vllm', + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = result.hasUsedForcedTool + usedForcedTools = result.usedForcedTools + } + } + + let currentResponse = await vllm.chat.completions.create(payload) + const firstResponseTime = Date.now() - initialCallTime + + let content = currentResponse.choices[0]?.message?.content || '' + + if (content && request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() + } + + const tokens = { + prompt: currentResponse.usage?.prompt_tokens || 0, + completion: currentResponse.usage?.completion_tokens || 0, + total: currentResponse.usage?.total_tokens || 0, + } + const toolCalls = [] + const toolResults = [] + const currentMessages = [...allMessages] + let iterationCount = 0 + const MAX_ITERATIONS = 10 + + let modelTime = firstResponseTime + let toolsTime = 0 + + let hasUsedForcedTool = false + + const timeSegments: TimeSegment[] = [ + { + type: 'model', + name: 'Initial response', + startTime: initialCallTime, + endTime: initialCallTime + firstResponseTime, + duration: firstResponseTime, + }, + ] + + checkForForcedToolUsage(currentResponse, originalToolChoice) + + while (iterationCount < MAX_ITERATIONS) { + const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls + if (!toolCallsInResponse || toolCallsInResponse.length === 0) { + break + } + + logger.info( + `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})` + ) + + const toolsStartTime = Date.now() + + for (const toolCall of toolCallsInResponse) { + try { + const toolName = toolCall.function.name + const toolArgs = JSON.parse(toolCall.function.arguments) + + const tool = request.tools?.find((t) => t.id === toolName) + if (!tool) continue + + const toolCallStartTime = Date.now() + + const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) + const result = await executeTool(toolName, executionParams, true) + const toolCallEndTime = Date.now() + const toolCallDuration = toolCallEndTime - toolCallStartTime + + timeSegments.push({ + type: 'tool', + name: toolName, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallDuration, + }) + + let resultContent: any + if (result.success) { + toolResults.push(result.output) + resultContent = result.output + } else { + resultContent = { + error: true, + message: result.error || 'Tool execution failed', + tool: toolName, + } + } + + toolCalls.push({ + name: toolName, + arguments: toolParams, + startTime: new Date(toolCallStartTime).toISOString(), + endTime: new Date(toolCallEndTime).toISOString(), + 
duration: toolCallDuration, + result: resultContent, + success: result.success, + }) + + currentMessages.push({ + role: 'assistant', + content: null, + tool_calls: [ + { + id: toolCall.id, + type: 'function', + function: { + name: toolName, + arguments: toolCall.function.arguments, + }, + }, + ], + }) + + currentMessages.push({ + role: 'tool', + tool_call_id: toolCall.id, + content: JSON.stringify(resultContent), + }) + } catch (error) { + logger.error('Error processing tool call:', { + error, + toolName: toolCall?.function?.name, + }) + } + } + + const thisToolsTime = Date.now() - toolsStartTime + toolsTime += thisToolsTime + + const nextPayload = { + ...payload, + messages: currentMessages, + } + + if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) { + const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool)) + + if (remainingTools.length > 0) { + nextPayload.tool_choice = { + type: 'function', + function: { name: remainingTools[0] }, + } + logger.info(`Forcing next tool: ${remainingTools[0]}`) + } else { + nextPayload.tool_choice = 'auto' + logger.info('All forced tools have been used, switching to auto tool_choice') + } + } + + const nextModelStartTime = Date.now() + + currentResponse = await vllm.chat.completions.create(nextPayload) + + checkForForcedToolUsage(currentResponse, nextPayload.tool_choice) + + const nextModelEndTime = Date.now() + const thisModelTime = nextModelEndTime - nextModelStartTime + + timeSegments.push({ + type: 'model', + name: `Model response (iteration ${iterationCount + 1})`, + startTime: nextModelStartTime, + endTime: nextModelEndTime, + duration: thisModelTime, + }) + + modelTime += thisModelTime + + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() + } + } + + if (currentResponse.usage) { + tokens.prompt += currentResponse.usage.prompt_tokens || 0 + tokens.completion += currentResponse.usage.completion_tokens || 0 + tokens.total += currentResponse.usage.total_tokens || 0 + } + + iterationCount++ + } + + if (request.stream) { + logger.info('Using streaming for final response after tool processing') + + const streamingPayload = { + ...payload, + messages: currentMessages, + tool_choice: 'auto', + stream: true, + stream_options: { include_usage: true }, + } + + const streamResponse = await vllm.chat.completions.create(streamingPayload) + + const streamingResult = { + stream: createReadableStreamFromVLLMStream(streamResponse, (content, usage) => { + let cleanContent = content + if (cleanContent && request.responseFormat) { + cleanContent = cleanContent.replace(/```json\n?|\n?```/g, '').trim() + } + + streamingResult.execution.output.content = cleanContent + + if (usage) { + const newTokens = { + prompt: usage.prompt_tokens || tokens.prompt, + completion: usage.completion_tokens || tokens.completion, + total: usage.total_tokens || tokens.total, + } + + streamingResult.execution.output.tokens = newTokens + } + }), + execution: { + success: true, + output: { + content: '', + model: request.model, + tokens: { + prompt: tokens.prompt, + completion: tokens.completion, + total: tokens.total, + }, + toolCalls: + toolCalls.length > 0 + ? 
{ + list: toolCalls, + count: toolCalls.length, + } + : undefined, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + modelTime: modelTime, + toolsTime: toolsTime, + firstResponseTime: firstResponseTime, + iterations: iterationCount + 1, + timeSegments: timeSegments, + }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + }, + } as StreamingExecution + + return streamingResult as StreamingExecution + } + + const providerEndTime = Date.now() + const providerEndTimeISO = new Date(providerEndTime).toISOString() + const totalDuration = providerEndTime - providerStartTime + + return { + content, + model: request.model, + tokens, + toolCalls: toolCalls.length > 0 ? toolCalls : undefined, + toolResults: toolResults.length > 0 ? toolResults : undefined, + timing: { + startTime: providerStartTimeISO, + endTime: providerEndTimeISO, + duration: totalDuration, + modelTime: modelTime, + toolsTime: toolsTime, + firstResponseTime: firstResponseTime, + iterations: iterationCount + 1, + timeSegments: timeSegments, + }, + } + } catch (error) { + const providerEndTime = Date.now() + const providerEndTimeISO = new Date(providerEndTime).toISOString() + const totalDuration = providerEndTime - providerStartTime + + let errorMessage = error instanceof Error ? error.message : String(error) + let errorType: string | undefined + let errorCode: number | undefined + + if (error && typeof error === 'object' && 'error' in error) { + const vllmError = error.error as any + if (vllmError && typeof vllmError === 'object') { + errorMessage = vllmError.message || errorMessage + errorType = vllmError.type + errorCode = vllmError.code + } + } + + logger.error('Error in vLLM request:', { + error: errorMessage, + errorType, + errorCode, + duration: totalDuration, + }) + + const enhancedError = new Error(errorMessage) + // @ts-ignore - Adding timing and vLLM error properties + enhancedError.timing = { + startTime: providerStartTimeISO, + endTime: providerEndTimeISO, + duration: totalDuration, + } + if (errorType) { + // @ts-ignore + enhancedError.vllmErrorType = errorType + } + if (errorCode) { + // @ts-ignore + enhancedError.vllmErrorCode = errorCode + } + + throw enhancedError + } + }, +} diff --git a/apps/sim/stores/providers/store.ts b/apps/sim/stores/providers/store.ts index 403175d40..fc353e101 100644 --- a/apps/sim/stores/providers/store.ts +++ b/apps/sim/stores/providers/store.ts @@ -8,6 +8,7 @@ export const useProvidersStore = create((set, get) => ({ providers: { base: { models: [], isLoading: false }, ollama: { models: [], isLoading: false }, + vllm: { models: [], isLoading: false }, openrouter: { models: [], isLoading: false }, }, diff --git a/apps/sim/stores/providers/types.ts b/apps/sim/stores/providers/types.ts index 80555826a..d89c41b3d 100644 --- a/apps/sim/stores/providers/types.ts +++ b/apps/sim/stores/providers/types.ts @@ -1,4 +1,4 @@ -export type ProviderName = 'ollama' | 'openrouter' | 'base' +export type ProviderName = 'ollama' | 'vllm' | 'openrouter' | 'base' export interface ProviderState { models: string[]
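
For anyone testing this locally, here is a minimal standalone sketch of the model-discovery flow the patch relies on. It assumes VLLM_BASE_URL points at a reachable OpenAI-compatible vLLM server (e.g. http://localhost:8000, as in the updated .env.example) and that VLLM_API_KEY is only needed when the server enforces auth; the Authorization header and the listVllmModels name are illustrative and not part of the diff itself.

// Mirrors what apps/sim/app/api/providers/vllm/models/route.ts does: fetch /v1/models
// and prefix each id with 'vllm/' so it matches the /^vllm\// model pattern in models.ts.
const baseUrl = (process.env.VLLM_BASE_URL || '').replace(/\/$/, '')
const apiKey = process.env.VLLM_API_KEY // optional; the provider falls back to 'empty' when unset

async function listVllmModels(): Promise<string[]> {
  if (!baseUrl) return [] // same behaviour as the route when VLLM_BASE_URL is not configured
  const response = await fetch(`${baseUrl}/v1/models`, {
    headers: {
      'Content-Type': 'application/json',
      // Hypothetical: the route itself sends no auth header; include one only if your server requires it.
      ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}),
    },
  })
  if (!response.ok) return [] // the route returns { models: [] } rather than surfacing the error
  const data = (await response.json()) as { data: Array<{ id: string }> }
  return data.data.map((model) => `vllm/${model.id}`)
}

listVllmModels().then((models) => console.log('vLLM models:', models))

If this prints an empty array while the server is running, check that the instance exposes the OpenAI-compatible /v1/models endpoint and that any auth token matches what the server was started with; an empty list here is also why the provider silently disables itself rather than erroring in the UI.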