Mirror of https://github.com/simstudioai/sim.git (synced 2026-01-09 15:07:55 -05:00)
feature(models): added vllm provider (#2103)
* Add vLLM self-hosted provider
* updated vllm to have pull parity with openai, dynamically fetch models

Co-authored-by: MagellaX <alphacr792@gmail.com>
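For orientation, a brief sketch (not part of the diff) of the flow this commit adds; the model ID is a hypothetical example, and the only assumption is a vLLM server reachable at VLLM_BASE_URL:

// 1. The server-side code lists models from `${VLLM_BASE_URL}/v1/models`.
// 2. Each discovered ID is namespaced as 'vllm/<id>' and pushed into the providers store.
// 3. When a request is executed, the prefix is stripped again before calling vLLM.
const discoveredId = 'vllm/my-model' // hypothetical ID as it appears in the UI
const wireModel = discoveredId.replace(/^vllm\//, '') // 'my-model', what vLLM actually receives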
@@ -13,6 +13,7 @@ const modelProviderIcons = [
  { icon: Icons.OllamaIcon, label: 'Ollama' },
  { icon: Icons.DeepseekIcon, label: 'Deepseek' },
  { icon: Icons.ElevenLabsIcon, label: 'ElevenLabs' },
  { icon: Icons.VllmIcon, label: 'vLLM' },
]

const communicationIcons = [

@@ -88,7 +89,6 @@ interface TickerRowProps {
}

function TickerRow({ direction, offset, showOdd, icons }: TickerRowProps) {
  // Create multiple copies of the icons array for seamless looping
  const extendedIcons = [...icons, ...icons, ...icons, ...icons]

  return (
@@ -20,4 +20,6 @@ INTERNAL_API_SECRET=your_internal_api_secret # Use `openssl rand -hex 32` to gen
# If left commented out, emails will be logged to console instead

# Local AI Models (Optional)
# OLLAMA_URL=http://localhost:11434 # URL for local Ollama server - uncomment if using local models
# VLLM_BASE_URL=http://localhost:8000 # Base URL for your self-hosted vLLM (OpenAI-compatible)
# VLLM_API_KEY= # Optional bearer token if your vLLM instance requires auth
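As a quick sanity check (a sketch, not part of the commit), the configured base URL should expose the OpenAI-compatible model listing that the new code relies on; the URL is the example value from the lines above, and the Authorization header is only an assumption for servers that enforce the optional token:

const baseUrl = 'http://localhost:8000' // example VLLM_BASE_URL from above
const res = await fetch(`${baseUrl}/v1/models`, {
  headers: { Authorization: 'Bearer ' + (process.env.VLLM_API_KEY ?? 'empty') }, // only if auth is enabled
})
const { data } = (await res.json()) as { data: Array<{ id: string }> }
console.log(data.map((m) => m.id)) // raw model IDs served by vLLM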
apps/sim/app/api/providers/vllm/models/route.ts (new file, 56 lines)
@@ -0,0 +1,56 @@
import { type NextRequest, NextResponse } from 'next/server'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'

const logger = createLogger('VLLMModelsAPI')

/**
 * Get available vLLM models
 */
export async function GET(request: NextRequest) {
  const baseUrl = (env.VLLM_BASE_URL || '').replace(/\/$/, '')

  if (!baseUrl) {
    logger.info('VLLM_BASE_URL not configured')
    return NextResponse.json({ models: [] })
  }

  try {
    logger.info('Fetching vLLM models', {
      baseUrl,
    })

    const response = await fetch(`${baseUrl}/v1/models`, {
      headers: {
        'Content-Type': 'application/json',
      },
      next: { revalidate: 60 },
    })

    if (!response.ok) {
      logger.warn('vLLM service is not available', {
        status: response.status,
        statusText: response.statusText,
      })
      return NextResponse.json({ models: [] })
    }

    const data = (await response.json()) as { data: Array<{ id: string }> }
    const models = data.data.map((model) => `vllm/${model.id}`)

    logger.info('Successfully fetched vLLM models', {
      count: models.length,
      models,
    })

    return NextResponse.json({ models })
  } catch (error) {
    logger.error('Failed to fetch vLLM models', {
      error: error instanceof Error ? error.message : 'Unknown error',
      baseUrl,
    })

    // Return empty array instead of error to avoid breaking the UI
    return NextResponse.json({ models: [] })
  }
}
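A small usage sketch (assumed client code, not part of the commit) of what this route returns: a { models: string[] } payload of 'vllm/'-prefixed IDs, or an empty list when VLLM_BASE_URL is unset or the server is unreachable:

const res = await fetch('/api/providers/vllm/models')
const { models } = (await res.json()) as { models: string[] }
// e.g. ['vllm/my-model'] (hypothetical ID); [] when the vLLM server is not available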
@@ -3,7 +3,11 @@
import { useEffect } from 'react'
import { createLogger } from '@/lib/logs/console/logger'
import { useProviderModels } from '@/hooks/queries/providers'
import { updateOllamaProviderModels, updateOpenRouterProviderModels } from '@/providers/utils'
import {
  updateOllamaProviderModels,
  updateOpenRouterProviderModels,
  updateVLLMProviderModels,
} from '@/providers/utils'
import { useProvidersStore } from '@/stores/providers/store'
import type { ProviderName } from '@/stores/providers/types'

@@ -24,6 +28,8 @@ function useSyncProvider(provider: ProviderName) {
    try {
      if (provider === 'ollama') {
        updateOllamaProviderModels(data)
      } else if (provider === 'vllm') {
        updateVLLMProviderModels(data)
      } else if (provider === 'openrouter') {
        void updateOpenRouterProviderModels(data)
      }

@@ -44,6 +50,7 @@ function useSyncProvider(provider: ProviderName) {
export function ProviderModelsLoader() {
  useSyncProvider('base')
  useSyncProvider('ollama')
  useSyncProvider('vllm')
  useSyncProvider('openrouter')
  return null
}
@@ -18,6 +18,10 @@ const getCurrentOllamaModels = () => {
  return useProvidersStore.getState().providers.ollama.models
}

const getCurrentVLLMModels = () => {
  return useProvidersStore.getState().providers.vllm.models
}

import { useProvidersStore } from '@/stores/providers/store'
import type { ToolResponse } from '@/tools/types'

@@ -90,8 +94,11 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
        const providersState = useProvidersStore.getState()
        const baseModels = providersState.providers.base.models
        const ollamaModels = providersState.providers.ollama.models
        const vllmModels = providersState.providers.vllm.models
        const openrouterModels = providersState.providers.openrouter.models
        const allModels = Array.from(new Set([...baseModels, ...ollamaModels, ...openrouterModels]))
        const allModels = Array.from(
          new Set([...baseModels, ...ollamaModels, ...vllmModels, ...openrouterModels])
        )

        return allModels.map((model) => {
          const icon = getProviderIcon(model)

@@ -172,7 +179,7 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
      password: true,
      connectionDroppable: false,
      required: true,
      // Hide API key for hosted models and Ollama models
      // Hide API key for hosted models, Ollama models, and vLLM models
      condition: isHosted
        ? {
            field: 'model',

@@ -181,8 +188,8 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
          }
        : () => ({
            field: 'model',
            value: getCurrentOllamaModels(),
            not: true, // Show for all models EXCEPT Ollama models
            value: [...getCurrentOllamaModels(), ...getCurrentVLLMModels()],
            not: true, // Show for all models EXCEPT Ollama and vLLM models
          }),
    },
    {
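For clarity, the non-hosted branch of the condition above resolves to roughly the object below (a sketch of the intent, not code from the commit), so the API key field is hidden whenever the selected model comes from Ollama or vLLM:

const apiKeyCondition = {
  field: 'model',
  value: [...getCurrentOllamaModels(), ...getCurrentVLLMModels()],
  not: true, // show the API key field only for models NOT in this list
}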
@@ -4150,3 +4150,13 @@ export function VideoIcon(props: SVGProps<SVGSVGElement>) {
    </svg>
  )
}

export function VllmIcon(props: SVGProps<SVGSVGElement>) {
  return (
    <svg {...props} fill='currentColor' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'>
      <title>vLLM</title>
      <path d='M0 4.973h9.324V23L0 4.973z' fill='#FDB515' />
      <path d='M13.986 4.351L22.378 0l-6.216 23H9.324l4.662-18.649z' fill='#30A2FF' />
    </svg>
  )
}
@@ -7,6 +7,7 @@ const logger = createLogger('ProviderModelsQuery')
const providerEndpoints: Record<ProviderName, string> = {
  base: '/api/providers/base/models',
  ollama: '/api/providers/ollama/models',
  vllm: '/api/providers/vllm/models',
  openrouter: '/api/providers/openrouter/models',
}
@@ -77,6 +77,8 @@ export const env = createEnv({
    ANTHROPIC_API_KEY_2: z.string().min(1).optional(), // Additional Anthropic API key for load balancing
    ANTHROPIC_API_KEY_3: z.string().min(1).optional(), // Additional Anthropic API key for load balancing
    OLLAMA_URL: z.string().url().optional(), // Ollama local LLM server URL
    VLLM_BASE_URL: z.string().url().optional(), // vLLM self-hosted base URL (OpenAI-compatible)
    VLLM_API_KEY: z.string().optional(), // Optional bearer token for vLLM
    ELEVENLABS_API_KEY: z.string().min(1).optional(), // ElevenLabs API key for text-to-speech in deployed chat
    SERPER_API_KEY: z.string().min(1).optional(), // Serper API key for online search
    EXA_API_KEY: z.string().min(1).optional(), // Exa AI API key for enhanced online search
@@ -19,6 +19,7 @@ import {
  OllamaIcon,
  OpenAIIcon,
  OpenRouterIcon,
  VllmIcon,
  xAIIcon,
} from '@/components/icons'

@@ -82,6 +83,19 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
    contextInformationAvailable: false,
    models: [],
  },
  vllm: {
    id: 'vllm',
    name: 'vLLM',
    icon: VllmIcon,
    description: 'Self-hosted vLLM with an OpenAI-compatible API',
    defaultModel: 'vllm/generic',
    modelPatterns: [/^vllm\//],
    capabilities: {
      temperature: { min: 0, max: 2 },
      toolUsageControl: true,
    },
    models: [],
  },
  openai: {
    id: 'openai',
    name: 'OpenAI',

@@ -1366,6 +1380,21 @@ export function updateOllamaModels(models: string[]): void {
  }))
}

/**
 * Update vLLM models dynamically
 */
export function updateVLLMModels(models: string[]): void {
  PROVIDER_DEFINITIONS.vllm.models = models.map((modelId) => ({
    id: modelId,
    pricing: {
      input: 0,
      output: 0,
      updatedAt: new Date().toISOString().split('T')[0],
    },
    capabilities: {},
  }))
}

/**
 * Update OpenRouter models dynamically
 */
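A brief usage sketch of the new helper above (the model IDs are hypothetical): discovered IDs land in PROVIDER_DEFINITIONS.vllm.models with zero-cost pricing, and the /^vllm\// pattern is what routes such IDs back to this provider:

updateVLLMModels(['vllm/my-model', 'vllm/another-model']) // hypothetical IDs
console.log(PROVIDER_DEFINITIONS.vllm.models.map((m) => m.id)) // ['vllm/my-model', 'vllm/another-model']
console.log(PROVIDER_DEFINITIONS.vllm.modelPatterns[0].test('vllm/my-model')) // true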
@@ -12,6 +12,7 @@ export type ProviderId =
  | 'mistral'
  | 'ollama'
  | 'openrouter'
  | 'vllm'

/**
 * Model pricing information per million tokens
@@ -30,6 +30,7 @@ import { ollamaProvider } from '@/providers/ollama'
import { openaiProvider } from '@/providers/openai'
import { openRouterProvider } from '@/providers/openrouter'
import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types'
import { vllmProvider } from '@/providers/vllm'
import { xAIProvider } from '@/providers/xai'
import { useCustomToolsStore } from '@/stores/custom-tools/store'
import { useProvidersStore } from '@/stores/providers/store'

@@ -86,6 +87,11 @@ export const providers: Record<
    models: getProviderModelsFromDefinitions('groq'),
    modelPatterns: PROVIDER_DEFINITIONS.groq.modelPatterns,
  },
  vllm: {
    ...vllmProvider,
    models: getProviderModelsFromDefinitions('vllm'),
    modelPatterns: PROVIDER_DEFINITIONS.vllm.modelPatterns,
  },
  mistral: {
    ...mistralProvider,
    models: getProviderModelsFromDefinitions('mistral'),

@@ -123,6 +129,12 @@ export function updateOllamaProviderModels(models: string[]): void {
  providers.ollama.models = getProviderModelsFromDefinitions('ollama')
}

export function updateVLLMProviderModels(models: string[]): void {
  const { updateVLLMModels } = require('@/providers/models')
  updateVLLMModels(models)
  providers.vllm.models = getProviderModelsFromDefinitions('vllm')
}

export async function updateOpenRouterProviderModels(models: string[]): Promise<void> {
  const { updateOpenRouterModels } = await import('@/providers/models')
  updateOpenRouterModels(models)

@@ -131,7 +143,10 @@ export async function updateOpenRouterProviderModels(models: string[]): Promise<

export function getBaseModelProviders(): Record<string, ProviderId> {
  const allProviders = Object.entries(providers)
    .filter(([providerId]) => providerId !== 'ollama' && providerId !== 'openrouter')
    .filter(
      ([providerId]) =>
        providerId !== 'ollama' && providerId !== 'vllm' && providerId !== 'openrouter'
    )
    .reduce(
      (map, [providerId, config]) => {
        config.models.forEach((model) => {
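As with Ollama, vLLM models are synced at runtime rather than listed in the base provider map, which is why getBaseModelProviders() now skips 'vllm' as well. A minimal sketch (not from the commit, hypothetical ID) of the loader path that feeds this registry:

// Called from the ProviderModelsLoader effect with IDs fetched from /api/providers/vllm/models.
updateVLLMProviderModels(['vllm/my-model'])
// providers.vllm.models now mirrors PROVIDER_DEFINITIONS.vllm.models, while
// getBaseModelProviders() continues to exclude 'ollama', 'vllm', and 'openrouter'.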
apps/sim/providers/vllm/index.ts (new file, 635 lines)
@@ -0,0 +1,635 @@
import OpenAI from 'openai'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
  ProviderConfig,
  ProviderRequest,
  ProviderResponse,
  TimeSegment,
} from '@/providers/types'
import {
  prepareToolExecution,
  prepareToolsWithUsageControl,
  trackForcedToolUsage,
} from '@/providers/utils'
import { useProvidersStore } from '@/stores/providers/store'
import { executeTool } from '@/tools'

const logger = createLogger('VLLMProvider')
const VLLM_VERSION = '1.0.0'

/**
 * Helper function to convert a vLLM stream to a standard ReadableStream
 * and collect completion metrics
 */
function createReadableStreamFromVLLMStream(
  vllmStream: any,
  onComplete?: (content: string, usage?: any) => void
): ReadableStream {
  let fullContent = ''
  let usageData: any = null

  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of vllmStream) {
          if (chunk.usage) {
            usageData = chunk.usage
          }

          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }

        if (onComplete) {
          onComplete(fullContent, usageData)
        }

        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
}

export const vllmProvider: ProviderConfig = {
  id: 'vllm',
  name: 'vLLM',
  description: 'Self-hosted vLLM with OpenAI-compatible API',
  version: VLLM_VERSION,
  models: getProviderModels('vllm'),
  defaultModel: getProviderDefaultModel('vllm'),

  async initialize() {
    if (typeof window !== 'undefined') {
      logger.info('Skipping vLLM initialization on client side to avoid CORS issues')
      return
    }

    const baseUrl = (env.VLLM_BASE_URL || '').replace(/\/$/, '')
    if (!baseUrl) {
      logger.info('VLLM_BASE_URL not configured, skipping initialization')
      return
    }

    try {
      const response = await fetch(`${baseUrl}/v1/models`)
      if (!response.ok) {
        useProvidersStore.getState().setProviderModels('vllm', [])
        logger.warn('vLLM service is not available. The provider will be disabled.')
        return
      }

      const data = (await response.json()) as { data: Array<{ id: string }> }
      const models = data.data.map((model) => `vllm/${model.id}`)

      this.models = models
      useProvidersStore.getState().setProviderModels('vllm', models)

      logger.info(`Discovered ${models.length} vLLM model(s):`, { models })
    } catch (error) {
      logger.warn('vLLM model instantiation failed. The provider will be disabled.', {
        error: error instanceof Error ? error.message : 'Unknown error',
      })
    }
  },

  executeRequest: async (
    request: ProviderRequest
  ): Promise<ProviderResponse | StreamingExecution> => {
    logger.info('Preparing vLLM request', {
      model: request.model,
      hasSystemPrompt: !!request.systemPrompt,
      hasMessages: !!request.messages?.length,
      hasTools: !!request.tools?.length,
      toolCount: request.tools?.length || 0,
      hasResponseFormat: !!request.responseFormat,
      stream: !!request.stream,
    })

    const baseUrl = (request.azureEndpoint || env.VLLM_BASE_URL || '').replace(/\/$/, '')
    if (!baseUrl) {
      throw new Error('VLLM_BASE_URL is required for vLLM provider')
    }

    const apiKey = request.apiKey || env.VLLM_API_KEY || 'empty'
    const vllm = new OpenAI({
      apiKey,
      baseURL: `${baseUrl}/v1`,
    })

    const allMessages = [] as any[]

    if (request.systemPrompt) {
      allMessages.push({
        role: 'system',
        content: request.systemPrompt,
      })
    }

    if (request.context) {
      allMessages.push({
        role: 'user',
        content: request.context,
      })
    }

    if (request.messages) {
      allMessages.push(...request.messages)
    }

    const tools = request.tools?.length
      ? request.tools.map((tool) => ({
          type: 'function',
          function: {
            name: tool.id,
            description: tool.description,
            parameters: tool.parameters,
          },
        }))
      : undefined

    const payload: any = {
      model: (request.model || getProviderDefaultModel('vllm')).replace(/^vllm\//, ''),
      messages: allMessages,
    }

    if (request.temperature !== undefined) payload.temperature = request.temperature
    if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens

    if (request.responseFormat) {
      payload.response_format = {
        type: 'json_schema',
        json_schema: {
          name: request.responseFormat.name || 'response_schema',
          schema: request.responseFormat.schema || request.responseFormat,
          strict: request.responseFormat.strict !== false,
        },
      }

      logger.info('Added JSON schema response format to vLLM request')
    }

    let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
    let hasActiveTools = false

    if (tools?.length) {
      preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'vllm')
      const { tools: filteredTools, toolChoice } = preparedTools

      if (filteredTools?.length && toolChoice) {
        payload.tools = filteredTools
        payload.tool_choice = toolChoice
        hasActiveTools = true

        logger.info('vLLM request configuration:', {
          toolCount: filteredTools.length,
          toolChoice:
            typeof toolChoice === 'string'
              ? toolChoice
              : toolChoice.type === 'function'
                ? `force:${toolChoice.function.name}`
                : 'unknown',
          model: payload.model,
        })
      }
    }

    const providerStartTime = Date.now()
    const providerStartTimeISO = new Date(providerStartTime).toISOString()

    try {
      if (request.stream && (!tools || tools.length === 0 || !hasActiveTools)) {
        logger.info('Using streaming response for vLLM request')

        const streamResponse = await vllm.chat.completions.create({
          ...payload,
          stream: true,
          stream_options: { include_usage: true },
        })

        const tokenUsage = {
          prompt: 0,
          completion: 0,
          total: 0,
        }

        const streamingResult = {
          stream: createReadableStreamFromVLLMStream(streamResponse, (content, usage) => {
            let cleanContent = content
            if (cleanContent && request.responseFormat) {
              cleanContent = cleanContent.replace(/```json\n?|\n?```/g, '').trim()
            }

            streamingResult.execution.output.content = cleanContent

            const streamEndTime = Date.now()
            const streamEndTimeISO = new Date(streamEndTime).toISOString()

            if (streamingResult.execution.output.providerTiming) {
              streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
              streamingResult.execution.output.providerTiming.duration =
                streamEndTime - providerStartTime

              if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
                streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
                  streamEndTime
                streamingResult.execution.output.providerTiming.timeSegments[0].duration =
                  streamEndTime - providerStartTime
              }
            }

            if (usage) {
              const newTokens = {
                prompt: usage.prompt_tokens || tokenUsage.prompt,
                completion: usage.completion_tokens || tokenUsage.completion,
                total: usage.total_tokens || tokenUsage.total,
              }

              streamingResult.execution.output.tokens = newTokens
            }
          }),
          execution: {
            success: true,
            output: {
              content: '',
              model: request.model,
              tokens: tokenUsage,
              toolCalls: undefined,
              providerTiming: {
                startTime: providerStartTimeISO,
                endTime: new Date().toISOString(),
                duration: Date.now() - providerStartTime,
                timeSegments: [
                  {
                    type: 'model',
                    name: 'Streaming response',
                    startTime: providerStartTime,
                    endTime: Date.now(),
                    duration: Date.now() - providerStartTime,
                  },
                ],
              },
            },
            logs: [],
            metadata: {
              startTime: providerStartTimeISO,
              endTime: new Date().toISOString(),
              duration: Date.now() - providerStartTime,
            },
          },
        } as StreamingExecution

        return streamingResult as StreamingExecution
      }

      const initialCallTime = Date.now()

      const originalToolChoice = payload.tool_choice

      const forcedTools = preparedTools?.forcedTools || []
      let usedForcedTools: string[] = []

      const checkForForcedToolUsage = (
        response: any,
        toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }
      ) => {
        if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
          const toolCallsResponse = response.choices[0].message.tool_calls
          const result = trackForcedToolUsage(
            toolCallsResponse,
            toolChoice,
            logger,
            'vllm',
            forcedTools,
            usedForcedTools
          )
          hasUsedForcedTool = result.hasUsedForcedTool
          usedForcedTools = result.usedForcedTools
        }
      }

      let currentResponse = await vllm.chat.completions.create(payload)
      const firstResponseTime = Date.now() - initialCallTime

      let content = currentResponse.choices[0]?.message?.content || ''

      if (content && request.responseFormat) {
        content = content.replace(/```json\n?|\n?```/g, '').trim()
      }

      const tokens = {
        prompt: currentResponse.usage?.prompt_tokens || 0,
        completion: currentResponse.usage?.completion_tokens || 0,
        total: currentResponse.usage?.total_tokens || 0,
      }
      const toolCalls = []
      const toolResults = []
      const currentMessages = [...allMessages]
      let iterationCount = 0
      const MAX_ITERATIONS = 10

      let modelTime = firstResponseTime
      let toolsTime = 0

      let hasUsedForcedTool = false

      const timeSegments: TimeSegment[] = [
        {
          type: 'model',
          name: 'Initial response',
          startTime: initialCallTime,
          endTime: initialCallTime + firstResponseTime,
          duration: firstResponseTime,
        },
      ]

      checkForForcedToolUsage(currentResponse, originalToolChoice)

      while (iterationCount < MAX_ITERATIONS) {
        const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
        if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
          break
        }

        logger.info(
          `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
        )

        const toolsStartTime = Date.now()

        for (const toolCall of toolCallsInResponse) {
          try {
            const toolName = toolCall.function.name
            const toolArgs = JSON.parse(toolCall.function.arguments)

            const tool = request.tools?.find((t) => t.id === toolName)
            if (!tool) continue

            const toolCallStartTime = Date.now()

            const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
            const result = await executeTool(toolName, executionParams, true)
            const toolCallEndTime = Date.now()
            const toolCallDuration = toolCallEndTime - toolCallStartTime

            timeSegments.push({
              type: 'tool',
              name: toolName,
              startTime: toolCallStartTime,
              endTime: toolCallEndTime,
              duration: toolCallDuration,
            })

            let resultContent: any
            if (result.success) {
              toolResults.push(result.output)
              resultContent = result.output
            } else {
              resultContent = {
                error: true,
                message: result.error || 'Tool execution failed',
                tool: toolName,
              }
            }

            toolCalls.push({
              name: toolName,
              arguments: toolParams,
              startTime: new Date(toolCallStartTime).toISOString(),
              endTime: new Date(toolCallEndTime).toISOString(),
              duration: toolCallDuration,
              result: resultContent,
              success: result.success,
            })

            currentMessages.push({
              role: 'assistant',
              content: null,
              tool_calls: [
                {
                  id: toolCall.id,
                  type: 'function',
                  function: {
                    name: toolName,
                    arguments: toolCall.function.arguments,
                  },
                },
              ],
            })

            currentMessages.push({
              role: 'tool',
              tool_call_id: toolCall.id,
              content: JSON.stringify(resultContent),
            })
          } catch (error) {
            logger.error('Error processing tool call:', {
              error,
              toolName: toolCall?.function?.name,
            })
          }
        }

        const thisToolsTime = Date.now() - toolsStartTime
        toolsTime += thisToolsTime

        const nextPayload = {
          ...payload,
          messages: currentMessages,
        }

        if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) {
          const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))

          if (remainingTools.length > 0) {
            nextPayload.tool_choice = {
              type: 'function',
              function: { name: remainingTools[0] },
            }
            logger.info(`Forcing next tool: ${remainingTools[0]}`)
          } else {
            nextPayload.tool_choice = 'auto'
            logger.info('All forced tools have been used, switching to auto tool_choice')
          }
        }

        const nextModelStartTime = Date.now()

        currentResponse = await vllm.chat.completions.create(nextPayload)

        checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)

        const nextModelEndTime = Date.now()
        const thisModelTime = nextModelEndTime - nextModelStartTime

        timeSegments.push({
          type: 'model',
          name: `Model response (iteration ${iterationCount + 1})`,
          startTime: nextModelStartTime,
          endTime: nextModelEndTime,
          duration: thisModelTime,
        })

        modelTime += thisModelTime

        if (currentResponse.choices[0]?.message?.content) {
          content = currentResponse.choices[0].message.content
          if (request.responseFormat) {
            content = content.replace(/```json\n?|\n?```/g, '').trim()
          }
        }

        if (currentResponse.usage) {
          tokens.prompt += currentResponse.usage.prompt_tokens || 0
          tokens.completion += currentResponse.usage.completion_tokens || 0
          tokens.total += currentResponse.usage.total_tokens || 0
        }

        iterationCount++
      }

      if (request.stream) {
        logger.info('Using streaming for final response after tool processing')

        const streamingPayload = {
          ...payload,
          messages: currentMessages,
          tool_choice: 'auto',
          stream: true,
          stream_options: { include_usage: true },
        }

        const streamResponse = await vllm.chat.completions.create(streamingPayload)

        const streamingResult = {
          stream: createReadableStreamFromVLLMStream(streamResponse, (content, usage) => {
            let cleanContent = content
            if (cleanContent && request.responseFormat) {
              cleanContent = cleanContent.replace(/```json\n?|\n?```/g, '').trim()
            }

            streamingResult.execution.output.content = cleanContent

            if (usage) {
              const newTokens = {
                prompt: usage.prompt_tokens || tokens.prompt,
                completion: usage.completion_tokens || tokens.completion,
                total: usage.total_tokens || tokens.total,
              }

              streamingResult.execution.output.tokens = newTokens
            }
          }),
          execution: {
            success: true,
            output: {
              content: '',
              model: request.model,
              tokens: {
                prompt: tokens.prompt,
                completion: tokens.completion,
                total: tokens.total,
              },
              toolCalls:
                toolCalls.length > 0
                  ? {
                      list: toolCalls,
                      count: toolCalls.length,
                    }
                  : undefined,
              providerTiming: {
                startTime: providerStartTimeISO,
                endTime: new Date().toISOString(),
                duration: Date.now() - providerStartTime,
                modelTime: modelTime,
                toolsTime: toolsTime,
                firstResponseTime: firstResponseTime,
                iterations: iterationCount + 1,
                timeSegments: timeSegments,
              },
            },
            logs: [],
            metadata: {
              startTime: providerStartTimeISO,
              endTime: new Date().toISOString(),
              duration: Date.now() - providerStartTime,
            },
          },
        } as StreamingExecution

        return streamingResult as StreamingExecution
      }

      const providerEndTime = Date.now()
      const providerEndTimeISO = new Date(providerEndTime).toISOString()
      const totalDuration = providerEndTime - providerStartTime

      return {
        content,
        model: request.model,
        tokens,
        toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
        toolResults: toolResults.length > 0 ? toolResults : undefined,
        timing: {
          startTime: providerStartTimeISO,
          endTime: providerEndTimeISO,
          duration: totalDuration,
          modelTime: modelTime,
          toolsTime: toolsTime,
          firstResponseTime: firstResponseTime,
          iterations: iterationCount + 1,
          timeSegments: timeSegments,
        },
      }
    } catch (error) {
      const providerEndTime = Date.now()
      const providerEndTimeISO = new Date(providerEndTime).toISOString()
      const totalDuration = providerEndTime - providerStartTime

      let errorMessage = error instanceof Error ? error.message : String(error)
      let errorType: string | undefined
      let errorCode: number | undefined

      if (error && typeof error === 'object' && 'error' in error) {
        const vllmError = error.error as any
        if (vllmError && typeof vllmError === 'object') {
          errorMessage = vllmError.message || errorMessage
          errorType = vllmError.type
          errorCode = vllmError.code
        }
      }

      logger.error('Error in vLLM request:', {
        error: errorMessage,
        errorType,
        errorCode,
        duration: totalDuration,
      })

      const enhancedError = new Error(errorMessage)
      // @ts-ignore - Adding timing and vLLM error properties
      enhancedError.timing = {
        startTime: providerStartTimeISO,
        endTime: providerEndTimeISO,
        duration: totalDuration,
      }
      if (errorType) {
        // @ts-ignore
        enhancedError.vllmErrorType = errorType
      }
      if (errorCode) {
        // @ts-ignore
        enhancedError.vllmErrorCode = errorCode
      }

      throw enhancedError
    }
  },
}
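Finally, a hedged sketch (not part of the commit) of calling the provider directly; the model ID is hypothetical and only request fields that the implementation above actually reads are shown:

const result = await vllmProvider.executeRequest({
  model: 'vllm/my-model', // hypothetical; the 'vllm/' prefix is stripped before the API call
  systemPrompt: 'You are a helpful assistant.',
  messages: [{ role: 'user', content: 'Hello!' }],
  temperature: 0.7,
  stream: false,
} as ProviderRequest)
// With stream: false and no tools, this resolves to a ProviderResponse whose
// content, tokens, and timing come from a single chat.completions call against vLLM.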
@@ -8,6 +8,7 @@ export const useProvidersStore = create<ProvidersStore>((set, get) => ({
  providers: {
    base: { models: [], isLoading: false },
    ollama: { models: [], isLoading: false },
    vllm: { models: [], isLoading: false },
    openrouter: { models: [], isLoading: false },
  },
@@ -1,4 +1,4 @@
export type ProviderName = 'ollama' | 'openrouter' | 'base'
export type ProviderName = 'ollama' | 'vllm' | 'openrouter' | 'base'

export interface ProviderState {
  models: string[]