diff --git a/apps/docs/package.json b/apps/docs/package.json index a589e671e..76e0fffa8 100644 --- a/apps/docs/package.json +++ b/apps/docs/package.json @@ -4,7 +4,7 @@ "private": true, "license": "Apache-2.0", "scripts": { - "dev": "next dev --port 3001", + "dev": "next dev --port 7322", "build": "fumadocs-mdx && NODE_OPTIONS='--max-old-space-size=8192' next build", "start": "next start", "postinstall": "fumadocs-mdx", diff --git a/apps/sim/app/api/copilot/chat/route.ts b/apps/sim/app/api/copilot/chat/route.ts index 7c3ead718..eb7331e0e 100644 --- a/apps/sim/app/api/copilot/chat/route.ts +++ b/apps/sim/app/api/copilot/chat/route.ts @@ -303,6 +303,14 @@ export async function POST(req: NextRequest) { apiVersion: 'preview', endpoint: env.AZURE_OPENAI_ENDPOINT, } + } else if (providerEnv === 'vertex') { + providerConfig = { + provider: 'vertex', + model: modelToUse, + apiKey: env.COPILOT_API_KEY, + vertexProject: env.VERTEX_PROJECT, + vertexLocation: env.VERTEX_LOCATION, + } } else { providerConfig = { provider: providerEnv, diff --git a/apps/sim/app/api/copilot/context-usage/route.ts b/apps/sim/app/api/copilot/context-usage/route.ts index edb2b31c5..fba208bb4 100644 --- a/apps/sim/app/api/copilot/context-usage/route.ts +++ b/apps/sim/app/api/copilot/context-usage/route.ts @@ -66,6 +66,14 @@ export async function POST(req: NextRequest) { apiVersion: env.AZURE_OPENAI_API_VERSION, endpoint: env.AZURE_OPENAI_ENDPOINT, } + } else if (providerEnv === 'vertex') { + providerConfig = { + provider: 'vertex', + model: modelToUse, + apiKey: env.COPILOT_API_KEY, + vertexProject: env.VERTEX_PROJECT, + vertexLocation: env.VERTEX_LOCATION, + } } else { providerConfig = { provider: providerEnv, diff --git a/apps/sim/app/api/providers/route.ts b/apps/sim/app/api/providers/route.ts index 6b95f67e9..ada02eb09 100644 --- a/apps/sim/app/api/providers/route.ts +++ b/apps/sim/app/api/providers/route.ts @@ -35,6 +35,8 @@ export async function POST(request: NextRequest) { apiKey, azureEndpoint, azureApiVersion, + vertexProject, + vertexLocation, responseFormat, workflowId, workspaceId, @@ -58,6 +60,8 @@ export async function POST(request: NextRequest) { hasApiKey: !!apiKey, hasAzureEndpoint: !!azureEndpoint, hasAzureApiVersion: !!azureApiVersion, + hasVertexProject: !!vertexProject, + hasVertexLocation: !!vertexLocation, hasResponseFormat: !!responseFormat, workflowId, stream: !!stream, @@ -104,6 +108,8 @@ export async function POST(request: NextRequest) { apiKey: finalApiKey, azureEndpoint, azureApiVersion, + vertexProject, + vertexLocation, responseFormat, workflowId, workspaceId, diff --git a/apps/sim/blocks/blocks/agent.ts b/apps/sim/blocks/blocks/agent.ts index 3e321d2cd..d9cbed2b5 100644 --- a/apps/sim/blocks/blocks/agent.ts +++ b/apps/sim/blocks/blocks/agent.ts @@ -8,6 +8,8 @@ import { getHostedModels, getMaxTemperature, getProviderIcon, + getReasoningEffortValuesForModel, + getVerbosityValuesForModel, MODELS_WITH_REASONING_EFFORT, MODELS_WITH_VERBOSITY, providers, @@ -114,12 +116,47 @@ export const AgentBlock: BlockConfig = { type: 'dropdown', placeholder: 'Select reasoning effort...', options: [ - { label: 'none', id: 'none' }, - { label: 'minimal', id: 'minimal' }, { label: 'low', id: 'low' }, { label: 'medium', id: 'medium' }, { label: 'high', id: 'high' }, ], + dependsOn: ['model'], + fetchOptions: async (blockId: string) => { + const { useSubBlockStore } = await import('@/stores/workflows/subblock/store') + const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store') + + const 
activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId + if (!activeWorkflowId) { + return [ + { label: 'low', id: 'low' }, + { label: 'medium', id: 'medium' }, + { label: 'high', id: 'high' }, + ] + } + + const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId] + const blockValues = workflowValues?.[blockId] + const modelValue = blockValues?.model as string + + if (!modelValue) { + return [ + { label: 'low', id: 'low' }, + { label: 'medium', id: 'medium' }, + { label: 'high', id: 'high' }, + ] + } + + const validOptions = getReasoningEffortValuesForModel(modelValue) + if (!validOptions) { + return [ + { label: 'low', id: 'low' }, + { label: 'medium', id: 'medium' }, + { label: 'high', id: 'high' }, + ] + } + + return validOptions.map((opt) => ({ label: opt, id: opt })) + }, value: () => 'medium', condition: { field: 'model', @@ -136,6 +173,43 @@ export const AgentBlock: BlockConfig = { { label: 'medium', id: 'medium' }, { label: 'high', id: 'high' }, ], + dependsOn: ['model'], + fetchOptions: async (blockId: string) => { + const { useSubBlockStore } = await import('@/stores/workflows/subblock/store') + const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store') + + const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId + if (!activeWorkflowId) { + return [ + { label: 'low', id: 'low' }, + { label: 'medium', id: 'medium' }, + { label: 'high', id: 'high' }, + ] + } + + const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId] + const blockValues = workflowValues?.[blockId] + const modelValue = blockValues?.model as string + + if (!modelValue) { + return [ + { label: 'low', id: 'low' }, + { label: 'medium', id: 'medium' }, + { label: 'high', id: 'high' }, + ] + } + + const validOptions = getVerbosityValuesForModel(modelValue) + if (!validOptions) { + return [ + { label: 'low', id: 'low' }, + { label: 'medium', id: 'medium' }, + { label: 'high', id: 'high' }, + ] + } + + return validOptions.map((opt) => ({ label: opt, id: opt })) + }, value: () => 'medium', condition: { field: 'model', @@ -166,6 +240,28 @@ export const AgentBlock: BlockConfig = { value: providers['azure-openai'].models, }, }, + { + id: 'vertexProject', + title: 'Vertex AI Project', + type: 'short-input', + placeholder: 'your-gcp-project-id', + connectionDroppable: false, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, + { + id: 'vertexLocation', + title: 'Vertex AI Location', + type: 'short-input', + placeholder: 'us-central1', + connectionDroppable: false, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, { id: 'tools', title: 'Tools', @@ -465,6 +561,8 @@ Example 3 (Array Input): apiKey: { type: 'string', description: 'Provider API key' }, azureEndpoint: { type: 'string', description: 'Azure OpenAI endpoint URL' }, azureApiVersion: { type: 'string', description: 'Azure API version' }, + vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' }, + vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' }, responseFormat: { type: 'json', description: 'JSON response format schema', diff --git a/apps/sim/blocks/blocks/evaluator.ts b/apps/sim/blocks/blocks/evaluator.ts index e809ed047..63ea9c74c 100644 --- a/apps/sim/blocks/blocks/evaluator.ts +++ b/apps/sim/blocks/blocks/evaluator.ts @@ -239,6 +239,28 @@ export const EvaluatorBlock: BlockConfig = { value: providers['azure-openai'].models, }, }, + { + id: 
'vertexProject', + title: 'Vertex AI Project', + type: 'short-input', + placeholder: 'your-gcp-project-id', + connectionDroppable: false, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, + { + id: 'vertexLocation', + title: 'Vertex AI Location', + type: 'short-input', + placeholder: 'us-central1', + connectionDroppable: false, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, { id: 'temperature', title: 'Temperature', @@ -356,6 +378,14 @@ export const EvaluatorBlock: BlockConfig = { apiKey: { type: 'string' as ParamType, description: 'Provider API key' }, azureEndpoint: { type: 'string' as ParamType, description: 'Azure OpenAI endpoint URL' }, azureApiVersion: { type: 'string' as ParamType, description: 'Azure API version' }, + vertexProject: { + type: 'string' as ParamType, + description: 'Google Cloud project ID for Vertex AI', + }, + vertexLocation: { + type: 'string' as ParamType, + description: 'Google Cloud location for Vertex AI', + }, temperature: { type: 'number' as ParamType, description: 'Response randomness level (low for consistent evaluation)', diff --git a/apps/sim/blocks/blocks/router.ts b/apps/sim/blocks/blocks/router.ts index 744aa5395..0c6006a43 100644 --- a/apps/sim/blocks/blocks/router.ts +++ b/apps/sim/blocks/blocks/router.ts @@ -188,6 +188,28 @@ export const RouterBlock: BlockConfig = { value: providers['azure-openai'].models, }, }, + { + id: 'vertexProject', + title: 'Vertex AI Project', + type: 'short-input', + placeholder: 'your-gcp-project-id', + connectionDroppable: false, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, + { + id: 'vertexLocation', + title: 'Vertex AI Location', + type: 'short-input', + placeholder: 'us-central1', + connectionDroppable: false, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, { id: 'temperature', title: 'Temperature', @@ -235,6 +257,8 @@ export const RouterBlock: BlockConfig = { apiKey: { type: 'string', description: 'Provider API key' }, azureEndpoint: { type: 'string', description: 'Azure OpenAI endpoint URL' }, azureApiVersion: { type: 'string', description: 'Azure API version' }, + vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' }, + vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' }, temperature: { type: 'number', description: 'Response randomness level (low for consistent routing)', diff --git a/apps/sim/blocks/blocks/translate.ts b/apps/sim/blocks/blocks/translate.ts index bd984b860..1ecfc7a20 100644 --- a/apps/sim/blocks/blocks/translate.ts +++ b/apps/sim/blocks/blocks/translate.ts @@ -99,6 +99,28 @@ export const TranslateBlock: BlockConfig = { value: providers['azure-openai'].models, }, }, + { + id: 'vertexProject', + title: 'Vertex AI Project', + type: 'short-input', + placeholder: 'your-gcp-project-id', + connectionDroppable: false, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, + { + id: 'vertexLocation', + title: 'Vertex AI Location', + type: 'short-input', + placeholder: 'us-central1', + connectionDroppable: false, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, { id: 'systemPrompt', title: 'System Prompt', @@ -120,6 +142,8 @@ export const TranslateBlock: BlockConfig = { apiKey: params.apiKey, azureEndpoint: params.azureEndpoint, azureApiVersion: params.azureApiVersion, + vertexProject: params.vertexProject, + vertexLocation: params.vertexLocation, }), }, }, @@ -129,6 +153,8 
@@ export const TranslateBlock: BlockConfig = {
     apiKey: { type: 'string', description: 'Provider API key' },
     azureEndpoint: { type: 'string', description: 'Azure OpenAI endpoint URL' },
     azureApiVersion: { type: 'string', description: 'Azure API version' },
+    vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' },
+    vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' },
     systemPrompt: { type: 'string', description: 'Translation instructions' },
   },
   outputs: {
diff --git a/apps/sim/components/icons.tsx b/apps/sim/components/icons.tsx
index 2e668f913..6c7f64138 100644
--- a/apps/sim/components/icons.tsx
+++ b/apps/sim/components/icons.tsx
@@ -2452,6 +2452,56 @@ export const GeminiIcon = (props: SVGProps<SVGSVGElement>) => (
 )
+export const VertexIcon = (props: SVGProps<SVGSVGElement>) => (
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+)
+
 export const CerebrasIcon = (props: SVGProps<SVGSVGElement>) => (
+      provider: 'vertex'
+      model: string
+      apiKey?: string
+      vertexProject?: string
+      vertexLocation?: string
+    }
+  | {
+      provider: Exclude
       model?: string
       apiKey?: string
     }
diff --git a/apps/sim/lib/core/config/env.ts b/apps/sim/lib/core/config/env.ts
index 290b163d8..39780d841 100644
--- a/apps/sim/lib/core/config/env.ts
+++ b/apps/sim/lib/core/config/env.ts
@@ -98,6 +98,10 @@ export const env = createEnv({
     OCR_AZURE_MODEL_NAME: z.string().optional(), // Azure Mistral OCR model name for document processing
     OCR_AZURE_API_KEY: z.string().min(1).optional(), // Azure Mistral OCR API key

+    // Vertex AI Configuration
+    VERTEX_PROJECT: z.string().optional(), // Google Cloud project ID for Vertex AI
+    VERTEX_LOCATION: z.string().optional(), // Google Cloud location/region for Vertex AI (defaults to us-central1)
+
     // Monitoring & Analytics
     TELEMETRY_ENDPOINT: z.string().url().optional(), // Custom telemetry/analytics endpoint
     COST_MULTIPLIER: z.number().optional(), // Multiplier for cost calculations
diff --git a/apps/sim/lib/mcp/service.ts b/apps/sim/lib/mcp/service.ts
index cfadec8f4..1e95dd706 100644
--- a/apps/sim/lib/mcp/service.ts
+++ b/apps/sim/lib/mcp/service.ts
@@ -404,15 +404,11 @@ class McpService {
         failedCount++
         const errorMessage = result.reason instanceof Error ?
result.reason.message : 'Unknown error' - logger.warn( - `[${requestId}] Failed to discover tools from server ${server.name}:`, - result.reason - ) + logger.warn(`[${requestId}] Failed to discover tools from server ${server.name}:`) statusUpdates.push(this.updateServerStatus(server.id!, workspaceId, false, errorMessage)) } }) - // Update server statuses in parallel (don't block on this) Promise.allSettled(statusUpdates).catch((err) => { logger.error(`[${requestId}] Error updating server statuses:`, err) }) diff --git a/apps/sim/package.json b/apps/sim/package.json index b34c4a80f..b7aff8168 100644 --- a/apps/sim/package.json +++ b/apps/sim/package.json @@ -8,7 +8,7 @@ "node": ">=20.0.0" }, "scripts": { - "dev": "next dev --port 3000", + "dev": "next dev --port 7321", "dev:webpack": "next dev --webpack", "dev:sockets": "bun run socket-server/index.ts", "dev:full": "concurrently -n \"App,Realtime\" -c \"cyan,magenta\" \"bun run dev\" \"bun run dev:sockets\"", diff --git a/apps/sim/providers/anthropic/index.ts b/apps/sim/providers/anthropic/index.ts index 8afa26446..5e9f2d26f 100644 --- a/apps/sim/providers/anthropic/index.ts +++ b/apps/sim/providers/anthropic/index.ts @@ -1,35 +1,24 @@ import Anthropic from '@anthropic-ai/sdk' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { + checkForForcedToolUsage, + createReadableStreamFromAnthropicStream, + generateToolUseId, +} from '@/providers/anthropic/utils' +import { getProviderDefaultModel, getProviderModels } from '@/providers/models' +import type { + ProviderConfig, + ProviderRequest, + ProviderResponse, + TimeSegment, +} from '@/providers/types' +import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils' import { executeTool } from '@/tools' -import { getProviderDefaultModel, getProviderModels } from '../models' -import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types' -import { prepareToolExecution, prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils' const logger = createLogger('AnthropicProvider') -/** - * Helper to wrap Anthropic streaming into a browser-friendly ReadableStream - */ -function createReadableStreamFromAnthropicStream( - anthropicStream: AsyncIterable -): ReadableStream { - return new ReadableStream({ - async start(controller) { - try { - for await (const event of anthropicStream) { - if (event.type === 'content_block_delta' && event.delta?.text) { - controller.enqueue(new TextEncoder().encode(event.delta.text)) - } - } - controller.close() - } catch (err) { - controller.error(err) - } - }, - }) -} - export const anthropicProvider: ProviderConfig = { id: 'anthropic', name: 'Anthropic', @@ -47,11 +36,6 @@ export const anthropicProvider: ProviderConfig = { const anthropic = new Anthropic({ apiKey: request.apiKey }) - // Helper function to generate a simple unique ID for tool uses - const generateToolUseId = (toolName: string) => { - return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}` - } - // Transform messages to Anthropic format const messages: any[] = [] @@ -373,7 +357,6 @@ ${fieldDescriptions} const toolResults = [] const currentMessages = [...messages] let iterationCount = 0 - const MAX_ITERATIONS = 10 // Prevent infinite loops // Track if a forced tool has been used let hasUsedForcedTool = false @@ -393,47 +376,20 @@ ${fieldDescriptions} }, ] - // Helper function to check for forced tool 
usage in Anthropic responses - const checkForForcedToolUsage = (response: any, toolChoice: any) => { - if ( - typeof toolChoice === 'object' && - toolChoice !== null && - Array.isArray(response.content) - ) { - const toolUses = response.content.filter((item: any) => item.type === 'tool_use') - - if (toolUses.length > 0) { - // Convert Anthropic tool_use format to a format trackForcedToolUsage can understand - const adaptedToolCalls = toolUses.map((tool: any) => ({ - name: tool.name, - })) - - // Convert Anthropic tool_choice format to match OpenAI format for tracking - const adaptedToolChoice = - toolChoice.type === 'tool' ? { function: { name: toolChoice.name } } : toolChoice - - const result = trackForcedToolUsage( - adaptedToolCalls, - adaptedToolChoice, - logger, - 'anthropic', - forcedTools, - usedForcedTools - ) - // Make the behavior consistent with the initial check - hasUsedForcedTool = result.hasUsedForcedTool - usedForcedTools = result.usedForcedTools - return result - } - } - return null + // Check if a forced tool was used in the first response + const firstCheckResult = checkForForcedToolUsage( + currentResponse, + originalToolChoice, + forcedTools, + usedForcedTools + ) + if (firstCheckResult) { + hasUsedForcedTool = firstCheckResult.hasUsedForcedTool + usedForcedTools = firstCheckResult.usedForcedTools } - // Check if a forced tool was used in the first response - checkForForcedToolUsage(currentResponse, originalToolChoice) - try { - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { // Check for tool calls const toolUses = currentResponse.content.filter((item) => item.type === 'tool_use') if (!toolUses || toolUses.length === 0) { @@ -576,7 +532,16 @@ ${fieldDescriptions} currentResponse = await anthropic.messages.create(nextPayload) // Check if any forced tools were used in this response - checkForForcedToolUsage(currentResponse, nextPayload.tool_choice) + const nextCheckResult = checkForForcedToolUsage( + currentResponse, + nextPayload.tool_choice, + forcedTools, + usedForcedTools + ) + if (nextCheckResult) { + hasUsedForcedTool = nextCheckResult.hasUsedForcedTool + usedForcedTools = nextCheckResult.usedForcedTools + } const nextModelEndTime = Date.now() const thisModelTime = nextModelEndTime - nextModelStartTime @@ -727,7 +692,6 @@ ${fieldDescriptions} const toolResults = [] const currentMessages = [...messages] let iterationCount = 0 - const MAX_ITERATIONS = 10 // Prevent infinite loops // Track if a forced tool has been used let hasUsedForcedTool = false @@ -747,47 +711,20 @@ ${fieldDescriptions} }, ] - // Helper function to check for forced tool usage in Anthropic responses - const checkForForcedToolUsage = (response: any, toolChoice: any) => { - if ( - typeof toolChoice === 'object' && - toolChoice !== null && - Array.isArray(response.content) - ) { - const toolUses = response.content.filter((item: any) => item.type === 'tool_use') - - if (toolUses.length > 0) { - // Convert Anthropic tool_use format to a format trackForcedToolUsage can understand - const adaptedToolCalls = toolUses.map((tool: any) => ({ - name: tool.name, - })) - - // Convert Anthropic tool_choice format to match OpenAI format for tracking - const adaptedToolChoice = - toolChoice.type === 'tool' ? 
{ function: { name: toolChoice.name } } : toolChoice - - const result = trackForcedToolUsage( - adaptedToolCalls, - adaptedToolChoice, - logger, - 'anthropic', - forcedTools, - usedForcedTools - ) - // Make the behavior consistent with the initial check - hasUsedForcedTool = result.hasUsedForcedTool - usedForcedTools = result.usedForcedTools - return result - } - } - return null + // Check if a forced tool was used in the first response + const firstCheckResult = checkForForcedToolUsage( + currentResponse, + originalToolChoice, + forcedTools, + usedForcedTools + ) + if (firstCheckResult) { + hasUsedForcedTool = firstCheckResult.hasUsedForcedTool + usedForcedTools = firstCheckResult.usedForcedTools } - // Check if a forced tool was used in the first response - checkForForcedToolUsage(currentResponse, originalToolChoice) - try { - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { // Check for tool calls const toolUses = currentResponse.content.filter((item) => item.type === 'tool_use') if (!toolUses || toolUses.length === 0) { @@ -926,7 +863,16 @@ ${fieldDescriptions} currentResponse = await anthropic.messages.create(nextPayload) // Check if any forced tools were used in this response - checkForForcedToolUsage(currentResponse, nextPayload.tool_choice) + const nextCheckResult = checkForForcedToolUsage( + currentResponse, + nextPayload.tool_choice, + forcedTools, + usedForcedTools + ) + if (nextCheckResult) { + hasUsedForcedTool = nextCheckResult.hasUsedForcedTool + usedForcedTools = nextCheckResult.usedForcedTools + } const nextModelEndTime = Date.now() const thisModelTime = nextModelEndTime - nextModelStartTime diff --git a/apps/sim/providers/anthropic/utils.ts b/apps/sim/providers/anthropic/utils.ts new file mode 100644 index 000000000..d45a0e2a0 --- /dev/null +++ b/apps/sim/providers/anthropic/utils.ts @@ -0,0 +1,70 @@ +import { createLogger } from '@/lib/logs/console/logger' +import { trackForcedToolUsage } from '@/providers/utils' + +const logger = createLogger('AnthropicUtils') + +/** + * Helper to wrap Anthropic streaming into a browser-friendly ReadableStream + */ +export function createReadableStreamFromAnthropicStream( + anthropicStream: AsyncIterable +): ReadableStream { + return new ReadableStream({ + async start(controller) { + try { + for await (const event of anthropicStream) { + if (event.type === 'content_block_delta' && event.delta?.text) { + controller.enqueue(new TextEncoder().encode(event.delta.text)) + } + } + controller.close() + } catch (err) { + controller.error(err) + } + }, + }) +} + +/** + * Helper function to generate a simple unique ID for tool uses + */ +export function generateToolUseId(toolName: string): string { + return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}` +} + +/** + * Helper function to check for forced tool usage in Anthropic responses + */ +export function checkForForcedToolUsage( + response: any, + toolChoice: any, + forcedTools: string[], + usedForcedTools: string[] +): { hasUsedForcedTool: boolean; usedForcedTools: string[] } | null { + if (typeof toolChoice === 'object' && toolChoice !== null && Array.isArray(response.content)) { + const toolUses = response.content.filter((item: any) => item.type === 'tool_use') + + if (toolUses.length > 0) { + // Convert Anthropic tool_use format to a format trackForcedToolUsage can understand + const adaptedToolCalls = toolUses.map((tool: any) => ({ + name: tool.name, + })) + + // Convert Anthropic tool_choice format to match OpenAI 
format for tracking + const adaptedToolChoice = + toolChoice.type === 'tool' ? { function: { name: toolChoice.name } } : toolChoice + + const result = trackForcedToolUsage( + adaptedToolCalls, + adaptedToolChoice, + logger, + 'anthropic', + forcedTools, + usedForcedTools + ) + + return result + } + } + return null +} diff --git a/apps/sim/providers/azure-openai/index.ts b/apps/sim/providers/azure-openai/index.ts index fd4f71a56..c2cf28339 100644 --- a/apps/sim/providers/azure-openai/index.ts +++ b/apps/sim/providers/azure-openai/index.ts @@ -2,6 +2,11 @@ import { AzureOpenAI } from 'openai' import { env } from '@/lib/core/config/env' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { + checkForForcedToolUsage, + createReadableStreamFromAzureOpenAIStream, +} from '@/providers/azure-openai/utils' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import type { ProviderConfig, @@ -9,55 +14,11 @@ import type { ProviderResponse, TimeSegment, } from '@/providers/types' -import { - prepareToolExecution, - prepareToolsWithUsageControl, - trackForcedToolUsage, -} from '@/providers/utils' +import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils' import { executeTool } from '@/tools' const logger = createLogger('AzureOpenAIProvider') -/** - * Helper function to convert an Azure OpenAI stream to a standard ReadableStream - * and collect completion metrics - */ -function createReadableStreamFromAzureOpenAIStream( - azureOpenAIStream: any, - onComplete?: (content: string, usage?: any) => void -): ReadableStream { - let fullContent = '' - let usageData: any = null - - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of azureOpenAIStream) { - // Check for usage data in the final chunk - if (chunk.usage) { - usageData = chunk.usage - } - - const content = chunk.choices[0]?.delta?.content || '' - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - - // Once stream is complete, call the completion callback with the final content and usage - if (onComplete) { - onComplete(fullContent, usageData) - } - - controller.close() - } catch (error) { - controller.error(error) - } - }, - }) -} - /** * Azure OpenAI provider configuration */ @@ -303,26 +264,6 @@ export const azureOpenAIProvider: ProviderConfig = { const forcedTools = preparedTools?.forcedTools || [] let usedForcedTools: string[] = [] - // Helper function to check for forced tool usage in responses - const checkForForcedToolUsage = ( - response: any, - toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any } - ) => { - if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) { - const toolCallsResponse = response.choices[0].message.tool_calls - const result = trackForcedToolUsage( - toolCallsResponse, - toolChoice, - logger, - 'azure-openai', - forcedTools, - usedForcedTools - ) - hasUsedForcedTool = result.hasUsedForcedTool - usedForcedTools = result.usedForcedTools - } - } - let currentResponse = await azureOpenAI.chat.completions.create(payload) const firstResponseTime = Date.now() - initialCallTime @@ -337,7 +278,6 @@ export const azureOpenAIProvider: ProviderConfig = { const toolResults = [] const currentMessages = [...allMessages] let iterationCount = 0 - const MAX_ITERATIONS = 10 // Prevent infinite loops // Track 
time spent in model vs tools let modelTime = firstResponseTime @@ -358,9 +298,17 @@ export const azureOpenAIProvider: ProviderConfig = { ] // Check if a forced tool was used in the first response - checkForForcedToolUsage(currentResponse, originalToolChoice) + const firstCheckResult = checkForForcedToolUsage( + currentResponse, + originalToolChoice, + logger, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = firstCheckResult.hasUsedForcedTool + usedForcedTools = firstCheckResult.usedForcedTools - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { // Check for tool calls const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls if (!toolCallsInResponse || toolCallsInResponse.length === 0) { @@ -368,7 +316,7 @@ export const azureOpenAIProvider: ProviderConfig = { } logger.info( - `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})` + `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` ) // Track time for tool calls in this batch @@ -491,7 +439,15 @@ export const azureOpenAIProvider: ProviderConfig = { currentResponse = await azureOpenAI.chat.completions.create(nextPayload) // Check if any forced tools were used in this response - checkForForcedToolUsage(currentResponse, nextPayload.tool_choice) + const nextCheckResult = checkForForcedToolUsage( + currentResponse, + nextPayload.tool_choice, + logger, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = nextCheckResult.hasUsedForcedTool + usedForcedTools = nextCheckResult.usedForcedTools const nextModelEndTime = Date.now() const thisModelTime = nextModelEndTime - nextModelStartTime diff --git a/apps/sim/providers/azure-openai/utils.ts b/apps/sim/providers/azure-openai/utils.ts new file mode 100644 index 000000000..b8baf9978 --- /dev/null +++ b/apps/sim/providers/azure-openai/utils.ts @@ -0,0 +1,70 @@ +import type { Logger } from '@/lib/logs/console/logger' +import { trackForcedToolUsage } from '@/providers/utils' + +/** + * Helper function to convert an Azure OpenAI stream to a standard ReadableStream + * and collect completion metrics + */ +export function createReadableStreamFromAzureOpenAIStream( + azureOpenAIStream: any, + onComplete?: (content: string, usage?: any) => void +): ReadableStream { + let fullContent = '' + let usageData: any = null + + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of azureOpenAIStream) { + if (chunk.usage) { + usageData = chunk.usage + } + + const content = chunk.choices[0]?.delta?.content || '' + if (content) { + fullContent += content + controller.enqueue(new TextEncoder().encode(content)) + } + } + + if (onComplete) { + onComplete(fullContent, usageData) + } + + controller.close() + } catch (error) { + controller.error(error) + } + }, + }) +} + +/** + * Helper function to check for forced tool usage in responses + */ +export function checkForForcedToolUsage( + response: any, + toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }, + logger: Logger, + forcedTools: string[], + usedForcedTools: string[] +): { hasUsedForcedTool: boolean; usedForcedTools: string[] } { + let hasUsedForcedTool = false + let updatedUsedForcedTools = [...usedForcedTools] + + if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) { + const toolCallsResponse = response.choices[0].message.tool_calls + const result = trackForcedToolUsage( + 
toolCallsResponse, + toolChoice, + logger, + 'azure-openai', + forcedTools, + updatedUsedForcedTools + ) + hasUsedForcedTool = result.hasUsedForcedTool + updatedUsedForcedTools = result.usedForcedTools + } + + return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools } +} diff --git a/apps/sim/providers/cerebras/index.ts b/apps/sim/providers/cerebras/index.ts index 3ebc8b412..f017565ab 100644 --- a/apps/sim/providers/cerebras/index.ts +++ b/apps/sim/providers/cerebras/index.ts @@ -1,6 +1,9 @@ import { Cerebras } from '@cerebras/cerebras_cloud_sdk' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import type { CerebrasResponse } from '@/providers/cerebras/types' +import { createReadableStreamFromCerebrasStream } from '@/providers/cerebras/utils' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import type { ProviderConfig, @@ -14,35 +17,9 @@ import { trackForcedToolUsage, } from '@/providers/utils' import { executeTool } from '@/tools' -import type { CerebrasResponse } from './types' const logger = createLogger('CerebrasProvider') -/** - * Helper to convert a Cerebras streaming response (async iterable) into a ReadableStream. - * Enqueues only the model's text delta chunks as UTF-8 encoded bytes. - */ -function createReadableStreamFromCerebrasStream( - cerebrasStream: AsyncIterable -): ReadableStream { - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of cerebrasStream) { - // Expecting delta content similar to OpenAI: chunk.choices[0]?.delta?.content - const content = chunk.choices?.[0]?.delta?.content || '' - if (content) { - controller.enqueue(new TextEncoder().encode(content)) - } - } - controller.close() - } catch (error) { - controller.error(error) - } - }, - }) -} - export const cerebrasProvider: ProviderConfig = { id: 'cerebras', name: 'Cerebras', @@ -223,7 +200,6 @@ export const cerebrasProvider: ProviderConfig = { const toolResults = [] const currentMessages = [...allMessages] let iterationCount = 0 - const MAX_ITERATIONS = 10 // Prevent infinite loops // Track time spent in model vs tools let modelTime = firstResponseTime @@ -246,7 +222,7 @@ export const cerebrasProvider: ProviderConfig = { const toolCallSignatures = new Set() try { - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { // Check for tool calls const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls diff --git a/apps/sim/providers/cerebras/utils.ts b/apps/sim/providers/cerebras/utils.ts new file mode 100644 index 000000000..01dcfd5fe --- /dev/null +++ b/apps/sim/providers/cerebras/utils.ts @@ -0,0 +1,23 @@ +/** + * Helper to convert a Cerebras streaming response (async iterable) into a ReadableStream. + * Enqueues only the model's text delta chunks as UTF-8 encoded bytes. 
+ */ +export function createReadableStreamFromCerebrasStream( + cerebrasStream: AsyncIterable +): ReadableStream { + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of cerebrasStream) { + const content = chunk.choices?.[0]?.delta?.content || '' + if (content) { + controller.enqueue(new TextEncoder().encode(content)) + } + } + controller.close() + } catch (error) { + controller.error(error) + } + }, + }) +} diff --git a/apps/sim/providers/deepseek/index.ts b/apps/sim/providers/deepseek/index.ts index a303b70b6..7425d84fa 100644 --- a/apps/sim/providers/deepseek/index.ts +++ b/apps/sim/providers/deepseek/index.ts @@ -1,6 +1,8 @@ import OpenAI from 'openai' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { createReadableStreamFromDeepseekStream } from '@/providers/deepseek/utils' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import type { ProviderConfig, @@ -17,28 +19,6 @@ import { executeTool } from '@/tools' const logger = createLogger('DeepseekProvider') -/** - * Helper function to convert a DeepSeek (OpenAI-compatible) stream to a ReadableStream - * of text chunks that can be consumed by the browser. - */ -function createReadableStreamFromDeepseekStream(deepseekStream: any): ReadableStream { - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of deepseekStream) { - const content = chunk.choices[0]?.delta?.content || '' - if (content) { - controller.enqueue(new TextEncoder().encode(content)) - } - } - controller.close() - } catch (error) { - controller.error(error) - } - }, - }) -} - export const deepseekProvider: ProviderConfig = { id: 'deepseek', name: 'Deepseek', @@ -231,7 +211,6 @@ export const deepseekProvider: ProviderConfig = { const toolResults = [] const currentMessages = [...allMessages] let iterationCount = 0 - const MAX_ITERATIONS = 10 // Prevent infinite loops // Track if a forced tool has been used let hasUsedForcedTool = false @@ -270,7 +249,7 @@ export const deepseekProvider: ProviderConfig = { } try { - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { // Check for tool calls const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls if (!toolCallsInResponse || toolCallsInResponse.length === 0) { diff --git a/apps/sim/providers/deepseek/utils.ts b/apps/sim/providers/deepseek/utils.ts new file mode 100644 index 000000000..228f5e346 --- /dev/null +++ b/apps/sim/providers/deepseek/utils.ts @@ -0,0 +1,21 @@ +/** + * Helper function to convert a DeepSeek (OpenAI-compatible) stream to a ReadableStream + * of text chunks that can be consumed by the browser. 
+ */ +export function createReadableStreamFromDeepseekStream(deepseekStream: any): ReadableStream { + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of deepseekStream) { + const content = chunk.choices[0]?.delta?.content || '' + if (content) { + controller.enqueue(new TextEncoder().encode(content)) + } + } + controller.close() + } catch (error) { + controller.error(error) + } + }, + }) +} diff --git a/apps/sim/providers/google/index.ts b/apps/sim/providers/google/index.ts index 0ff67344f..fdd225a46 100644 --- a/apps/sim/providers/google/index.ts +++ b/apps/sim/providers/google/index.ts @@ -1,5 +1,12 @@ import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { + cleanSchemaForGemini, + convertToGeminiFormat, + extractFunctionCall, + extractTextContent, +} from '@/providers/google/utils' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import type { ProviderConfig, @@ -19,7 +26,13 @@ const logger = createLogger('GoogleProvider') /** * Creates a ReadableStream from Google's Gemini stream response */ -function createReadableStreamFromGeminiStream(response: Response): ReadableStream { +function createReadableStreamFromGeminiStream( + response: Response, + onComplete?: ( + content: string, + usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number } + ) => void +): ReadableStream { const reader = response.body?.getReader() if (!reader) { throw new Error('Failed to get reader from response body') @@ -29,18 +42,24 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea async start(controller) { try { let buffer = '' + let fullContent = '' + let usageData: { + promptTokenCount?: number + candidatesTokenCount?: number + totalTokenCount?: number + } | null = null while (true) { const { done, value } = await reader.read() if (done) { - // Try to parse any remaining buffer as complete JSON if (buffer.trim()) { - // Processing final buffer try { const data = JSON.parse(buffer.trim()) + if (data.usageMetadata) { + usageData = data.usageMetadata + } const candidate = data.candidates?.[0] if (candidate?.content?.parts) { - // Check if this is a function call const functionCall = extractFunctionCall(candidate) if (functionCall) { logger.debug( @@ -49,26 +68,27 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea functionName: functionCall.name, } ) - // Function calls should not be streamed - end the stream early + if (onComplete) onComplete(fullContent, usageData || undefined) controller.close() return } const content = extractTextContent(candidate) if (content) { + fullContent += content controller.enqueue(new TextEncoder().encode(content)) } } } catch (e) { - // Final buffer not valid JSON, checking if it contains JSON array - // Try parsing as JSON array if it starts with [ if (buffer.trim().startsWith('[')) { try { const dataArray = JSON.parse(buffer.trim()) if (Array.isArray(dataArray)) { for (const item of dataArray) { + if (item.usageMetadata) { + usageData = item.usageMetadata + } const candidate = item.candidates?.[0] if (candidate?.content?.parts) { - // Check if this is a function call const functionCall = extractFunctionCall(candidate) if (functionCall) { logger.debug( @@ -77,11 +97,13 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea functionName: functionCall.name, } ) + if 
(onComplete) onComplete(fullContent, usageData || undefined) controller.close() return } const content = extractTextContent(candidate) if (content) { + fullContent += content controller.enqueue(new TextEncoder().encode(content)) } } @@ -93,6 +115,7 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea } } } + if (onComplete) onComplete(fullContent, usageData || undefined) controller.close() break } @@ -100,14 +123,11 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea const text = new TextDecoder().decode(value) buffer += text - // Try to find complete JSON objects in buffer - // Look for patterns like: {...}\n{...} or just a single {...} let searchIndex = 0 while (searchIndex < buffer.length) { const openBrace = buffer.indexOf('{', searchIndex) if (openBrace === -1) break - // Try to find the matching closing brace let braceCount = 0 let inString = false let escaped = false @@ -138,28 +158,34 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea } if (closeBrace !== -1) { - // Found a complete JSON object const jsonStr = buffer.substring(openBrace, closeBrace + 1) try { const data = JSON.parse(jsonStr) - // JSON parsed successfully from stream + + if (data.usageMetadata) { + usageData = data.usageMetadata + } const candidate = data.candidates?.[0] - // Handle specific finish reasons if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') { logger.warn('Gemini returned UNEXPECTED_TOOL_CALL in streaming mode', { finishReason: candidate.finishReason, hasContent: !!candidate?.content, hasParts: !!candidate?.content?.parts, }) - // This indicates a configuration issue - tools might be improperly configured for streaming - continue + const textContent = extractTextContent(candidate) + if (textContent) { + fullContent += textContent + controller.enqueue(new TextEncoder().encode(textContent)) + } + if (onComplete) onComplete(fullContent, usageData || undefined) + controller.close() + return } if (candidate?.content?.parts) { - // Check if this is a function call const functionCall = extractFunctionCall(candidate) if (functionCall) { logger.debug( @@ -168,13 +194,13 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea functionName: functionCall.name, } ) - // Function calls should not be streamed - we need to end the stream - // and let the non-streaming tool execution flow handle this + if (onComplete) onComplete(fullContent, usageData || undefined) controller.close() return } const content = extractTextContent(candidate) if (content) { + fullContent += content controller.enqueue(new TextEncoder().encode(content)) } } @@ -185,7 +211,6 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea }) } - // Remove processed JSON from buffer and continue searching buffer = buffer.substring(closeBrace + 1) searchIndex = 0 } else { @@ -232,45 +257,36 @@ export const googleProvider: ProviderConfig = { streaming: !!request.stream, }) - // Start execution timer for the entire provider execution const providerStartTime = Date.now() const providerStartTimeISO = new Date(providerStartTime).toISOString() try { - // Convert messages to Gemini format const { contents, tools, systemInstruction } = convertToGeminiFormat(request) const requestedModel = request.model || 'gemini-2.5-pro' - // Build request payload const payload: any = { contents, generationConfig: {}, } - // Add temperature if specified if (request.temperature !== undefined && request.temperature !== null) 
{ payload.generationConfig.temperature = request.temperature } - // Add max tokens if specified if (request.maxTokens !== undefined) { payload.generationConfig.maxOutputTokens = request.maxTokens } - // Add system instruction if provided if (systemInstruction) { payload.systemInstruction = systemInstruction } - // Add structured output format if requested (but not when tools are present) if (request.responseFormat && !tools?.length) { const responseFormatSchema = request.responseFormat.schema || request.responseFormat - // Clean the schema using our helper function const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - // Use Gemini's native structured output approach payload.generationConfig.responseMimeType = 'application/json' payload.generationConfig.responseSchema = cleanSchema @@ -284,7 +300,6 @@ export const googleProvider: ProviderConfig = { ) } - // Handle tools and tool usage control let preparedTools: ReturnType | null = null if (tools?.length) { @@ -298,7 +313,6 @@ export const googleProvider: ProviderConfig = { }, ] - // Add Google-specific tool configuration if (toolConfig) { payload.toolConfig = toolConfig } @@ -313,14 +327,10 @@ export const googleProvider: ProviderConfig = { } } - // Make the API request const initialCallTime = Date.now() - // Disable streaming for initial requests when tools are present to avoid function calls in streams - // Only enable streaming for the final response after tool execution const shouldStream = request.stream && !tools?.length - // Use streamGenerateContent for streaming requests const endpoint = shouldStream ? `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:streamGenerateContent?key=${request.apiKey}` : `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:generateContent?key=${request.apiKey}` @@ -352,16 +362,11 @@ export const googleProvider: ProviderConfig = { const firstResponseTime = Date.now() - initialCallTime - // Handle streaming response if (shouldStream) { logger.info('Handling Google Gemini streaming response') - // Create a ReadableStream from the Google Gemini stream - const stream = createReadableStreamFromGeminiStream(response) - - // Create an object that combines the stream with execution metadata - const streamingExecution: StreamingExecution = { - stream, + const streamingResult: StreamingExecution = { + stream: null as any, execution: { success: true, output: { @@ -389,7 +394,6 @@ export const googleProvider: ProviderConfig = { duration: firstResponseTime, }, ], - // Cost will be calculated in logger }, }, logs: [], @@ -402,18 +406,49 @@ export const googleProvider: ProviderConfig = { }, } - return streamingExecution + streamingResult.stream = createReadableStreamFromGeminiStream( + response, + (content, usage) => { + streamingResult.execution.output.content = content + + const streamEndTime = Date.now() + const streamEndTimeISO = new Date(streamEndTime).toISOString() + + if (streamingResult.execution.output.providerTiming) { + streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO + streamingResult.execution.output.providerTiming.duration = + streamEndTime - providerStartTime + + if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) { + streamingResult.execution.output.providerTiming.timeSegments[0].endTime = + streamEndTime + streamingResult.execution.output.providerTiming.timeSegments[0].duration = + streamEndTime - providerStartTime + } + } + + if (usage) { + streamingResult.execution.output.tokens = { + prompt: 
usage.promptTokenCount || 0, + completion: usage.candidatesTokenCount || 0, + total: + usage.totalTokenCount || + (usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0), + } + } + } + ) + + return streamingResult } let geminiResponse = await response.json() - // Check structured output format if (payload.generationConfig?.responseSchema) { const candidate = geminiResponse.candidates?.[0] if (candidate?.content?.parts?.[0]?.text) { const text = candidate.content.parts[0].text try { - // Validate JSON structure JSON.parse(text) logger.info('Successfully received structured JSON output') } catch (_e) { @@ -422,7 +457,6 @@ export const googleProvider: ProviderConfig = { } } - // Initialize response tracking variables let content = '' let tokens = { prompt: 0, @@ -432,16 +466,13 @@ export const googleProvider: ProviderConfig = { const toolCalls = [] const toolResults = [] let iterationCount = 0 - const MAX_ITERATIONS = 10 // Prevent infinite loops - // Track forced tools and their usage (similar to OpenAI pattern) const originalToolConfig = preparedTools?.toolConfig const forcedTools = preparedTools?.forcedTools || [] let usedForcedTools: string[] = [] let hasUsedForcedTool = false let currentToolConfig = originalToolConfig - // Helper function to check for forced tool usage in responses const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => { if (currentToolConfig && forcedTools.length > 0) { const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }] @@ -466,11 +497,9 @@ export const googleProvider: ProviderConfig = { } } - // Track time spent in model vs tools let modelTime = firstResponseTime let toolsTime = 0 - // Track each model and tool call segment with timestamps const timeSegments: TimeSegment[] = [ { type: 'model', @@ -482,46 +511,50 @@ export const googleProvider: ProviderConfig = { ] try { - // Extract content or function calls from initial response const candidate = geminiResponse.candidates?.[0] - // Check if response contains function calls + if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') { + logger.warn( + 'Gemini returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided', + { + finishReason: candidate.finishReason, + hasContent: !!candidate?.content, + hasParts: !!candidate?.content?.parts, + } + ) + content = extractTextContent(candidate) + } + const functionCall = extractFunctionCall(candidate) if (functionCall) { logger.info(`Received function call from Gemini: ${functionCall.name}`) - // Process function calls in a loop - while (iterationCount < MAX_ITERATIONS) { - // Get the latest function calls + while (iterationCount < MAX_TOOL_ITERATIONS) { const latestResponse = geminiResponse.candidates?.[0] const latestFunctionCall = extractFunctionCall(latestResponse) if (!latestFunctionCall) { - // No more function calls - extract final text content content = extractTextContent(latestResponse) break } logger.info( - `Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_ITERATIONS})` + `Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` ) - // Track time for tool calls const toolsStartTime = Date.now() try { const toolName = latestFunctionCall.name const toolArgs = latestFunctionCall.args || {} - // Get the tool from the tools registry const tool = request.tools?.find((t) => t.id === toolName) if (!tool) { logger.warn(`Tool ${toolName} not found in registry, skipping`) break } - // 
Execute the tool const toolCallStartTime = Date.now() const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) @@ -529,7 +562,6 @@ export const googleProvider: ProviderConfig = { const toolCallEndTime = Date.now() const toolCallDuration = toolCallEndTime - toolCallStartTime - // Add to time segments for both success and failure timeSegments.push({ type: 'tool', name: toolName, @@ -538,13 +570,11 @@ export const googleProvider: ProviderConfig = { duration: toolCallDuration, }) - // Prepare result content for the LLM let resultContent: any if (result.success) { toolResults.push(result.output) resultContent = result.output } else { - // Include error information so LLM can respond appropriately resultContent = { error: true, message: result.error || 'Tool execution failed', @@ -562,14 +592,10 @@ export const googleProvider: ProviderConfig = { success: result.success, }) - // Prepare for next request with simplified messages - // Use simple format: original query + most recent function call + result const simplifiedMessages = [ - // Original user request - find the first user request ...(contents.filter((m) => m.role === 'user').length > 0 ? [contents.filter((m) => m.role === 'user')[0]] : [contents[0]]), - // Function call from model { role: 'model', parts: [ @@ -581,7 +607,6 @@ export const googleProvider: ProviderConfig = { }, ], }, - // Function response - but use USER role since Gemini only accepts user or model { role: 'user', parts: [ @@ -592,35 +617,27 @@ export const googleProvider: ProviderConfig = { }, ] - // Calculate tool call time const thisToolsTime = Date.now() - toolsStartTime toolsTime += thisToolsTime - // Check for forced tool usage and update configuration checkForForcedToolUsage(latestFunctionCall) - // Make the next request with updated messages const nextModelStartTime = Date.now() try { - // Check if we should stream the final response after tool calls if (request.stream) { - // Create a payload for the streaming response after tool calls const streamingPayload = { ...payload, contents: simplifiedMessages, } - // Check if we should remove tools and enable structured output for final response const allForcedToolsUsed = forcedTools.length > 0 && usedForcedTools.length === forcedTools.length if (allForcedToolsUsed && request.responseFormat) { - // All forced tools have been used, we can now remove tools and enable structured output streamingPayload.tools = undefined streamingPayload.toolConfig = undefined - // Add structured output format for final response const responseFormatSchema = request.responseFormat.schema || request.responseFormat const cleanSchema = cleanSchemaForGemini(responseFormatSchema) @@ -633,7 +650,6 @@ export const googleProvider: ProviderConfig = { logger.info('Using structured output for final response after tool execution') } else { - // Use updated tool configuration if available, otherwise default to AUTO if (currentToolConfig) { streamingPayload.toolConfig = currentToolConfig } else { @@ -641,11 +657,8 @@ export const googleProvider: ProviderConfig = { } } - // Check if we should handle this as a potential forced tool call - // First make a non-streaming request to see if we get a function call const checkPayload = { ...streamingPayload, - // Remove stream property to get non-streaming response } checkPayload.stream = undefined @@ -677,7 +690,6 @@ export const googleProvider: ProviderConfig = { const checkFunctionCall = extractFunctionCall(checkCandidate) if (checkFunctionCall) { - // We have a function call - handle 
it in non-streaming mode logger.info( 'Function call detected in follow-up, handling in non-streaming mode', { @@ -685,10 +697,8 @@ export const googleProvider: ProviderConfig = { } ) - // Update geminiResponse to continue the tool execution loop geminiResponse = checkResult - // Update token counts if available if (checkResult.usageMetadata) { tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0 tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0 @@ -697,12 +707,10 @@ export const googleProvider: ProviderConfig = { (checkResult.usageMetadata.candidatesTokenCount || 0) } - // Calculate timing for this model call const nextModelEndTime = Date.now() const thisModelTime = nextModelEndTime - nextModelStartTime modelTime += thisModelTime - // Add to time segments timeSegments.push({ type: 'model', name: `Model response (iteration ${iterationCount + 1})`, @@ -711,14 +719,32 @@ export const googleProvider: ProviderConfig = { duration: thisModelTime, }) - // Continue the loop to handle the function call iterationCount++ continue } - // No function call - proceed with streaming logger.info('No function call detected, proceeding with streaming response') - // Make the streaming request with streamGenerateContent endpoint + // Apply structured output for the final response if responseFormat is specified + // This works regardless of whether tools were forced or auto + if (request.responseFormat) { + streamingPayload.tools = undefined + streamingPayload.toolConfig = undefined + + const responseFormatSchema = + request.responseFormat.schema || request.responseFormat + const cleanSchema = cleanSchemaForGemini(responseFormatSchema) + + if (!streamingPayload.generationConfig) { + streamingPayload.generationConfig = {} + } + streamingPayload.generationConfig.responseMimeType = 'application/json' + streamingPayload.generationConfig.responseSchema = cleanSchema + + logger.info( + 'Using structured output for final streaming response after tool execution' + ) + } + const streamingResponse = await fetch( `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:streamGenerateContent?key=${request.apiKey}`, { @@ -742,15 +768,10 @@ export const googleProvider: ProviderConfig = { ) } - // Create a stream from the response - const stream = createReadableStreamFromGeminiStream(streamingResponse) - - // Calculate timing information const nextModelEndTime = Date.now() const thisModelTime = nextModelEndTime - nextModelStartTime modelTime += thisModelTime - // Add to time segments timeSegments.push({ type: 'model', name: 'Final streaming response after tool calls', @@ -759,9 +780,8 @@ export const googleProvider: ProviderConfig = { duration: thisModelTime, }) - // Return a streaming execution with tool call information const streamingExecution: StreamingExecution = { - stream, + stream: null as any, execution: { success: true, output: { @@ -786,7 +806,6 @@ export const googleProvider: ProviderConfig = { iterations: iterationCount + 1, timeSegments, }, - // Cost will be calculated in logger }, logs: [], metadata: { @@ -798,25 +817,55 @@ export const googleProvider: ProviderConfig = { }, } + streamingExecution.stream = createReadableStreamFromGeminiStream( + streamingResponse, + (content, usage) => { + streamingExecution.execution.output.content = content + + const streamEndTime = Date.now() + const streamEndTimeISO = new Date(streamEndTime).toISOString() + + if (streamingExecution.execution.output.providerTiming) { + 
streamingExecution.execution.output.providerTiming.endTime = + streamEndTimeISO + streamingExecution.execution.output.providerTiming.duration = + streamEndTime - providerStartTime + } + + if (usage) { + const existingTokens = streamingExecution.execution.output.tokens || { + prompt: 0, + completion: 0, + total: 0, + } + streamingExecution.execution.output.tokens = { + prompt: (existingTokens.prompt || 0) + (usage.promptTokenCount || 0), + completion: + (existingTokens.completion || 0) + (usage.candidatesTokenCount || 0), + total: + (existingTokens.total || 0) + + (usage.totalTokenCount || + (usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0)), + } + } + } + ) + return streamingExecution } - // Make the next request for non-streaming response const nextPayload = { ...payload, contents: simplifiedMessages, } - // Check if we should remove tools and enable structured output for final response const allForcedToolsUsed = forcedTools.length > 0 && usedForcedTools.length === forcedTools.length if (allForcedToolsUsed && request.responseFormat) { - // All forced tools have been used, we can now remove tools and enable structured output nextPayload.tools = undefined nextPayload.toolConfig = undefined - // Add structured output format for final response const responseFormatSchema = request.responseFormat.schema || request.responseFormat const cleanSchema = cleanSchemaForGemini(responseFormatSchema) @@ -831,7 +880,6 @@ export const googleProvider: ProviderConfig = { 'Using structured output for final non-streaming response after tool execution' ) } else { - // Add updated tool configuration if available if (currentToolConfig) { nextPayload.toolConfig = currentToolConfig } @@ -864,7 +912,6 @@ export const googleProvider: ProviderConfig = { const nextModelEndTime = Date.now() const thisModelTime = nextModelEndTime - nextModelStartTime - // Add to time segments timeSegments.push({ type: 'model', name: `Model response (iteration ${iterationCount + 1})`, @@ -873,15 +920,65 @@ export const googleProvider: ProviderConfig = { duration: thisModelTime, }) - // Add to model time modelTime += thisModelTime - // Check if we need to continue or break const nextCandidate = geminiResponse.candidates?.[0] const nextFunctionCall = extractFunctionCall(nextCandidate) if (!nextFunctionCall) { - content = extractTextContent(nextCandidate) + // If responseFormat is specified, make one final request with structured output + if (request.responseFormat) { + const finalPayload = { + ...payload, + contents: nextPayload.contents, + tools: undefined, + toolConfig: undefined, + } + + const responseFormatSchema = + request.responseFormat.schema || request.responseFormat + const cleanSchema = cleanSchemaForGemini(responseFormatSchema) + + if (!finalPayload.generationConfig) { + finalPayload.generationConfig = {} + } + finalPayload.generationConfig.responseMimeType = 'application/json' + finalPayload.generationConfig.responseSchema = cleanSchema + + logger.info('Making final request with structured output after tool execution') + + const finalResponse = await fetch( + `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:generateContent?key=${request.apiKey}`, + { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(finalPayload), + } + ) + + if (finalResponse.ok) { + const finalResult = await finalResponse.json() + const finalCandidate = finalResult.candidates?.[0] + content = extractTextContent(finalCandidate) + + if (finalResult.usageMetadata) { + 
tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0 + tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0 + tokens.total += + (finalResult.usageMetadata.promptTokenCount || 0) + + (finalResult.usageMetadata.candidatesTokenCount || 0) + } + } else { + logger.warn( + 'Failed to get structured output, falling back to regular response' + ) + content = extractTextContent(nextCandidate) + } + } else { + content = extractTextContent(nextCandidate) + } break } @@ -902,7 +999,6 @@ export const googleProvider: ProviderConfig = { } } } else { - // Regular text response content = extractTextContent(candidate) } } catch (error) { @@ -911,18 +1007,15 @@ export const googleProvider: ProviderConfig = { iterationCount, }) - // Don't rethrow, so we can still return partial results if (!content && toolCalls.length > 0) { content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.` } } - // Calculate overall timing const providerEndTime = Date.now() const providerEndTimeISO = new Date(providerEndTime).toISOString() const totalDuration = providerEndTime - providerStartTime - // Extract token usage if available if (geminiResponse.usageMetadata) { tokens = { prompt: geminiResponse.usageMetadata.promptTokenCount || 0, @@ -949,10 +1042,8 @@ export const googleProvider: ProviderConfig = { iterations: iterationCount + 1, timeSegments: timeSegments, }, - // Cost will be calculated in logger } } catch (error) { - // Include timing information even for errors const providerEndTime = Date.now() const providerEndTimeISO = new Date(providerEndTime).toISOString() const totalDuration = providerEndTime - providerStartTime @@ -962,7 +1053,6 @@ export const googleProvider: ProviderConfig = { duration: totalDuration, }) - // Create a new error with timing information const enhancedError = new Error(error instanceof Error ? 
error.message : String(error)) // @ts-ignore - Adding timing property to the error enhancedError.timing = { @@ -975,200 +1065,3 @@ export const googleProvider: ProviderConfig = { } }, } - -/** - * Helper function to remove additionalProperties from a schema object - * and perform a deep copy of the schema to avoid modifying the original - */ -function cleanSchemaForGemini(schema: any): any { - // Handle base cases - if (schema === null || schema === undefined) return schema - if (typeof schema !== 'object') return schema - if (Array.isArray(schema)) { - return schema.map((item) => cleanSchemaForGemini(item)) - } - - // Create a new object for the deep copy - const cleanedSchema: any = {} - - // Process each property in the schema - for (const key in schema) { - // Skip additionalProperties - if (key === 'additionalProperties') continue - - // Deep copy nested objects - cleanedSchema[key] = cleanSchemaForGemini(schema[key]) - } - - return cleanedSchema -} - -/** - * Helper function to extract content from a Gemini response, handling structured output - */ -function extractTextContent(candidate: any): string { - if (!candidate?.content?.parts) return '' - - // Check for JSON response (typically from structured output) - if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) { - const text = candidate.content.parts[0].text - if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) { - try { - JSON.parse(text) // Validate JSON - return text // Return valid JSON as-is - } catch (_e) { - /* Not valid JSON, continue with normal extraction */ - } - } - } - - // Standard text extraction - return candidate.content.parts - .filter((part: any) => part.text) - .map((part: any) => part.text) - .join('\n') -} - -/** - * Helper function to extract a function call from a Gemini response - */ -function extractFunctionCall(candidate: any): { name: string; args: any } | null { - if (!candidate?.content?.parts) return null - - // Check for functionCall in parts - for (const part of candidate.content.parts) { - if (part.functionCall) { - const args = part.functionCall.args || {} - // Parse string args if they look like JSON - if ( - typeof part.functionCall.args === 'string' && - part.functionCall.args.trim().startsWith('{') - ) { - try { - return { name: part.functionCall.name, args: JSON.parse(part.functionCall.args) } - } catch (_e) { - return { name: part.functionCall.name, args: part.functionCall.args } - } - } - return { name: part.functionCall.name, args } - } - } - - // Check for alternative function_call format - if (candidate.content.function_call) { - const args = - typeof candidate.content.function_call.arguments === 'string' - ? 
JSON.parse(candidate.content.function_call.arguments || '{}') - : candidate.content.function_call.arguments || {} - return { name: candidate.content.function_call.name, args } - } - - return null -} - -/** - * Convert OpenAI-style request format to Gemini format - */ -function convertToGeminiFormat(request: ProviderRequest): { - contents: any[] - tools: any[] | undefined - systemInstruction: any | undefined -} { - const contents = [] - let systemInstruction - - // Handle system prompt - if (request.systemPrompt) { - systemInstruction = { parts: [{ text: request.systemPrompt }] } - } - - // Add context as user message if present - if (request.context) { - contents.push({ role: 'user', parts: [{ text: request.context }] }) - } - - // Process messages - if (request.messages && request.messages.length > 0) { - for (const message of request.messages) { - if (message.role === 'system') { - // Add to system instruction - if (!systemInstruction) { - systemInstruction = { parts: [{ text: message.content }] } - } else { - // Append to existing system instruction - systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}` - } - } else if (message.role === 'user' || message.role === 'assistant') { - // Convert to Gemini role format - const geminiRole = message.role === 'user' ? 'user' : 'model' - - // Add text content - if (message.content) { - contents.push({ role: geminiRole, parts: [{ text: message.content }] }) - } - - // Handle tool calls - if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) { - const functionCalls = message.tool_calls.map((toolCall) => ({ - functionCall: { - name: toolCall.function?.name, - args: JSON.parse(toolCall.function?.arguments || '{}'), - }, - })) - - contents.push({ role: 'model', parts: functionCalls }) - } - } else if (message.role === 'tool') { - // Convert tool response (Gemini only accepts user/model roles) - contents.push({ - role: 'user', - parts: [{ text: `Function result: ${message.content}` }], - }) - } - } - } - - // Convert tools to Gemini function declarations - const tools = request.tools?.map((tool) => { - const toolParameters = { ...(tool.parameters || {}) } - - // Process schema properties - if (toolParameters.properties) { - const properties = { ...toolParameters.properties } - const required = toolParameters.required ? [...toolParameters.required] : [] - - // Remove defaults and optional parameters - for (const key in properties) { - const prop = properties[key] as any - - if (prop.default !== undefined) { - const { default: _, ...cleanProp } = prop - properties[key] = cleanProp - } - } - - // Build Gemini-compatible parameters schema - const parameters = { - type: toolParameters.type || 'object', - properties, - ...(required.length > 0 ? 
{ required } : {}), - } - - // Clean schema for Gemini - return { - name: tool.id, - description: tool.description || `Execute the ${tool.id} function`, - parameters: cleanSchemaForGemini(parameters), - } - } - - // Simple schema case - return { - name: tool.id, - description: tool.description || `Execute the ${tool.id} function`, - parameters: cleanSchemaForGemini(toolParameters), - } - }) - - return { contents, tools, systemInstruction } -} diff --git a/apps/sim/providers/google/utils.ts b/apps/sim/providers/google/utils.ts new file mode 100644 index 000000000..8c14687e6 --- /dev/null +++ b/apps/sim/providers/google/utils.ts @@ -0,0 +1,171 @@ +import type { ProviderRequest } from '@/providers/types' + +/** + * Removes additionalProperties from a schema object (not supported by Gemini) + */ +export function cleanSchemaForGemini(schema: any): any { + if (schema === null || schema === undefined) return schema + if (typeof schema !== 'object') return schema + if (Array.isArray(schema)) { + return schema.map((item) => cleanSchemaForGemini(item)) + } + + const cleanedSchema: any = {} + + for (const key in schema) { + if (key === 'additionalProperties') continue + cleanedSchema[key] = cleanSchemaForGemini(schema[key]) + } + + return cleanedSchema +} + +/** + * Extracts text content from a Gemini response candidate, handling structured output + */ +export function extractTextContent(candidate: any): string { + if (!candidate?.content?.parts) return '' + + if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) { + const text = candidate.content.parts[0].text + if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) { + try { + JSON.parse(text) + return text + } catch (_e) { + /* Not valid JSON, continue with normal extraction */ + } + } + } + + return candidate.content.parts + .filter((part: any) => part.text) + .map((part: any) => part.text) + .join('\n') +} + +/** + * Extracts a function call from a Gemini response candidate + */ +export function extractFunctionCall(candidate: any): { name: string; args: any } | null { + if (!candidate?.content?.parts) return null + + for (const part of candidate.content.parts) { + if (part.functionCall) { + const args = part.functionCall.args || {} + if ( + typeof part.functionCall.args === 'string' && + part.functionCall.args.trim().startsWith('{') + ) { + try { + return { name: part.functionCall.name, args: JSON.parse(part.functionCall.args) } + } catch (_e) { + return { name: part.functionCall.name, args: part.functionCall.args } + } + } + return { name: part.functionCall.name, args } + } + } + + if (candidate.content.function_call) { + const args = + typeof candidate.content.function_call.arguments === 'string' + ? 
JSON.parse(candidate.content.function_call.arguments || '{}') + : candidate.content.function_call.arguments || {} + return { name: candidate.content.function_call.name, args } + } + + return null +} + +/** + * Converts OpenAI-style request format to Gemini format + */ +export function convertToGeminiFormat(request: ProviderRequest): { + contents: any[] + tools: any[] | undefined + systemInstruction: any | undefined +} { + const contents: any[] = [] + let systemInstruction + + if (request.systemPrompt) { + systemInstruction = { parts: [{ text: request.systemPrompt }] } + } + + if (request.context) { + contents.push({ role: 'user', parts: [{ text: request.context }] }) + } + + if (request.messages && request.messages.length > 0) { + for (const message of request.messages) { + if (message.role === 'system') { + if (!systemInstruction) { + systemInstruction = { parts: [{ text: message.content }] } + } else { + systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}` + } + } else if (message.role === 'user' || message.role === 'assistant') { + const geminiRole = message.role === 'user' ? 'user' : 'model' + + if (message.content) { + contents.push({ role: geminiRole, parts: [{ text: message.content }] }) + } + + if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) { + const functionCalls = message.tool_calls.map((toolCall) => ({ + functionCall: { + name: toolCall.function?.name, + args: JSON.parse(toolCall.function?.arguments || '{}'), + }, + })) + + contents.push({ role: 'model', parts: functionCalls }) + } + } else if (message.role === 'tool') { + contents.push({ + role: 'user', + parts: [{ text: `Function result: ${message.content}` }], + }) + } + } + } + + const tools = request.tools?.map((tool) => { + const toolParameters = { ...(tool.parameters || {}) } + + if (toolParameters.properties) { + const properties = { ...toolParameters.properties } + const required = toolParameters.required ? [...toolParameters.required] : [] + + for (const key in properties) { + const prop = properties[key] as any + + if (prop.default !== undefined) { + const { default: _, ...cleanProp } = prop + properties[key] = cleanProp + } + } + + const parameters = { + type: toolParameters.type || 'object', + properties, + ...(required.length > 0 ? { required } : {}), + } + + return { + name: tool.id, + description: tool.description || `Execute the ${tool.id} function`, + parameters: cleanSchemaForGemini(parameters), + } + } + + return { + name: tool.id, + description: tool.description || `Execute the ${tool.id} function`, + parameters: cleanSchemaForGemini(toolParameters), + } + }) + + return { contents, tools, systemInstruction } +} diff --git a/apps/sim/providers/groq/index.ts b/apps/sim/providers/groq/index.ts index 027f50192..97e00ac19 100644 --- a/apps/sim/providers/groq/index.ts +++ b/apps/sim/providers/groq/index.ts @@ -1,6 +1,8 @@ import { Groq } from 'groq-sdk' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { createReadableStreamFromGroqStream } from '@/providers/groq/utils' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import type { ProviderConfig, @@ -17,27 +19,6 @@ import { executeTool } from '@/tools' const logger = createLogger('GroqProvider') -/** - * Helper to wrap Groq streaming into a browser-friendly ReadableStream - * of raw assistant text chunks. 
- */ -function createReadableStreamFromGroqStream(groqStream: any): ReadableStream { - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of groqStream) { - if (chunk.choices[0]?.delta?.content) { - controller.enqueue(new TextEncoder().encode(chunk.choices[0].delta.content)) - } - } - controller.close() - } catch (err) { - controller.error(err) - } - }, - }) -} - export const groqProvider: ProviderConfig = { id: 'groq', name: 'Groq', @@ -225,7 +206,6 @@ export const groqProvider: ProviderConfig = { const toolResults = [] const currentMessages = [...allMessages] let iterationCount = 0 - const MAX_ITERATIONS = 10 // Prevent infinite loops // Track time spent in model vs tools let modelTime = firstResponseTime @@ -243,7 +223,7 @@ export const groqProvider: ProviderConfig = { ] try { - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { // Check for tool calls const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls if (!toolCallsInResponse || toolCallsInResponse.length === 0) { diff --git a/apps/sim/providers/groq/utils.ts b/apps/sim/providers/groq/utils.ts new file mode 100644 index 000000000..845c73af1 --- /dev/null +++ b/apps/sim/providers/groq/utils.ts @@ -0,0 +1,23 @@ +/** + * Helper to wrap Groq streaming into a browser-friendly ReadableStream + * of raw assistant text chunks. + * + * @param groqStream - The Groq streaming response + * @returns A ReadableStream that emits text chunks + */ +export function createReadableStreamFromGroqStream(groqStream: any): ReadableStream { + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of groqStream) { + if (chunk.choices[0]?.delta?.content) { + controller.enqueue(new TextEncoder().encode(chunk.choices[0].delta.content)) + } + } + controller.close() + } catch (err) { + controller.error(err) + } + }, + }) +} diff --git a/apps/sim/providers/index.ts b/apps/sim/providers/index.ts index 72d1423e1..3dbed8f42 100644 --- a/apps/sim/providers/index.ts +++ b/apps/sim/providers/index.ts @@ -12,6 +12,12 @@ import { const logger = createLogger('Providers') +/** + * Maximum number of iterations for tool call loops to prevent infinite loops. + * Used across all providers that support tool/function calling. 
+ */ +export const MAX_TOOL_ITERATIONS = 20 + function sanitizeRequest(request: ProviderRequest): ProviderRequest { const sanitizedRequest = { ...request } @@ -44,7 +50,6 @@ export async function executeProviderRequest( } const sanitizedRequest = sanitizeRequest(request) - // If responseFormat is provided, modify the system prompt to enforce structured output if (sanitizedRequest.responseFormat) { if ( typeof sanitizedRequest.responseFormat === 'string' && @@ -53,12 +58,10 @@ export async function executeProviderRequest( logger.info('Empty response format provided, ignoring it') sanitizedRequest.responseFormat = undefined } else { - // Generate structured output instructions const structuredOutputInstructions = generateStructuredOutputInstructions( sanitizedRequest.responseFormat ) - // Only add additional instructions if they're not empty if (structuredOutputInstructions.trim()) { const originalPrompt = sanitizedRequest.systemPrompt || '' sanitizedRequest.systemPrompt = @@ -69,10 +72,8 @@ export async function executeProviderRequest( } } - // Execute the request using the provider's implementation const response = await provider.executeRequest(sanitizedRequest) - // If we received a StreamingExecution or ReadableStream, just pass it through if (isStreamingExecution(response)) { logger.info('Provider returned StreamingExecution') return response diff --git a/apps/sim/providers/mistral/index.ts b/apps/sim/providers/mistral/index.ts index e2a194962..c4d05fdfb 100644 --- a/apps/sim/providers/mistral/index.ts +++ b/apps/sim/providers/mistral/index.ts @@ -1,6 +1,8 @@ import OpenAI from 'openai' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { createReadableStreamFromMistralStream } from '@/providers/mistral/utils' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import type { ProviderConfig, @@ -17,40 +19,6 @@ import { executeTool } from '@/tools' const logger = createLogger('MistralProvider') -function createReadableStreamFromMistralStream( - mistralStream: any, - onComplete?: (content: string, usage?: any) => void -): ReadableStream { - let fullContent = '' - let usageData: any = null - - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of mistralStream) { - if (chunk.usage) { - usageData = chunk.usage - } - - const content = chunk.choices[0]?.delta?.content || '' - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - - if (onComplete) { - onComplete(fullContent, usageData) - } - - controller.close() - } catch (error) { - controller.error(error) - } - }, - }) -} - /** * Mistral AI provider configuration */ @@ -288,7 +256,6 @@ export const mistralProvider: ProviderConfig = { const toolResults = [] const currentMessages = [...allMessages] let iterationCount = 0 - const MAX_ITERATIONS = 10 let modelTime = firstResponseTime let toolsTime = 0 @@ -307,14 +274,14 @@ export const mistralProvider: ProviderConfig = { checkForForcedToolUsage(currentResponse, originalToolChoice) - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls if (!toolCallsInResponse || toolCallsInResponse.length === 0) { break } logger.info( - `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})` + `Processing 
${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` ) const toolsStartTime = Date.now() diff --git a/apps/sim/providers/mistral/utils.ts b/apps/sim/providers/mistral/utils.ts new file mode 100644 index 000000000..f33f517d0 --- /dev/null +++ b/apps/sim/providers/mistral/utils.ts @@ -0,0 +1,39 @@ +/** + * Creates a ReadableStream from a Mistral AI streaming response + * @param mistralStream - The Mistral AI stream object + * @param onComplete - Optional callback when streaming completes + * @returns A ReadableStream that yields text chunks + */ +export function createReadableStreamFromMistralStream( + mistralStream: any, + onComplete?: (content: string, usage?: any) => void +): ReadableStream { + let fullContent = '' + let usageData: any = null + + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of mistralStream) { + if (chunk.usage) { + usageData = chunk.usage + } + + const content = chunk.choices[0]?.delta?.content || '' + if (content) { + fullContent += content + controller.enqueue(new TextEncoder().encode(content)) + } + } + + if (onComplete) { + onComplete(fullContent, usageData) + } + + controller.close() + } catch (error) { + controller.error(error) + } + }, + }) +} diff --git a/apps/sim/providers/models.ts b/apps/sim/providers/models.ts index aac7c30b4..4183fc720 100644 --- a/apps/sim/providers/models.ts +++ b/apps/sim/providers/models.ts @@ -19,6 +19,7 @@ import { OllamaIcon, OpenAIIcon, OpenRouterIcon, + VertexIcon, VllmIcon, xAIIcon, } from '@/components/icons' @@ -130,7 +131,7 @@ export const PROVIDER_DEFINITIONS: Record = { }, capabilities: { reasoningEffort: { - values: ['none', 'low', 'medium', 'high'], + values: ['none', 'minimal', 'low', 'medium', 'high', 'xhigh'], }, verbosity: { values: ['low', 'medium', 'high'], @@ -283,7 +284,11 @@ export const PROVIDER_DEFINITIONS: Record = { output: 60, updatedAt: '2025-06-17', }, - capabilities: {}, + capabilities: { + reasoningEffort: { + values: ['low', 'medium', 'high'], + }, + }, contextWindow: 200000, }, { @@ -294,7 +299,11 @@ export const PROVIDER_DEFINITIONS: Record = { output: 8, updatedAt: '2025-06-17', }, - capabilities: {}, + capabilities: { + reasoningEffort: { + values: ['low', 'medium', 'high'], + }, + }, contextWindow: 128000, }, { @@ -305,7 +314,11 @@ export const PROVIDER_DEFINITIONS: Record = { output: 4.4, updatedAt: '2025-06-17', }, - capabilities: {}, + capabilities: { + reasoningEffort: { + values: ['low', 'medium', 'high'], + }, + }, contextWindow: 128000, }, { @@ -383,7 +396,7 @@ export const PROVIDER_DEFINITIONS: Record = { }, capabilities: { reasoningEffort: { - values: ['none', 'low', 'medium', 'high'], + values: ['none', 'minimal', 'low', 'medium', 'high', 'xhigh'], }, verbosity: { values: ['low', 'medium', 'high'], @@ -536,7 +549,11 @@ export const PROVIDER_DEFINITIONS: Record = { output: 40, updatedAt: '2025-06-15', }, - capabilities: {}, + capabilities: { + reasoningEffort: { + values: ['low', 'medium', 'high'], + }, + }, contextWindow: 128000, }, { @@ -547,7 +564,11 @@ export const PROVIDER_DEFINITIONS: Record = { output: 4.4, updatedAt: '2025-06-15', }, - capabilities: {}, + capabilities: { + reasoningEffort: { + values: ['low', 'medium', 'high'], + }, + }, contextWindow: 128000, }, { @@ -708,9 +729,22 @@ export const PROVIDER_DEFINITIONS: Record = { id: 'gemini-3-pro-preview', pricing: { input: 2.0, - cachedInput: 1.0, + cachedInput: 0.2, output: 12.0, - updatedAt: '2025-11-18', + updatedAt: '2025-12-17', + }, + 
capabilities: { + temperature: { min: 0, max: 2 }, + }, + contextWindow: 1000000, + }, + { + id: 'gemini-3-flash-preview', + pricing: { + input: 0.5, + cachedInput: 0.05, + output: 3.0, + updatedAt: '2025-12-17', }, capabilities: { temperature: { min: 0, max: 2 }, @@ -756,6 +790,132 @@ export const PROVIDER_DEFINITIONS: Record = { }, contextWindow: 1048576, }, + { + id: 'gemini-2.0-flash', + pricing: { + input: 0.1, + output: 0.4, + updatedAt: '2025-12-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + }, + contextWindow: 1000000, + }, + { + id: 'gemini-2.0-flash-lite', + pricing: { + input: 0.075, + output: 0.3, + updatedAt: '2025-12-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + }, + contextWindow: 1000000, + }, + ], + }, + vertex: { + id: 'vertex', + name: 'Vertex AI', + description: "Google's Vertex AI platform for Gemini models", + defaultModel: 'vertex/gemini-2.5-pro', + modelPatterns: [/^vertex\//], + icon: VertexIcon, + capabilities: { + toolUsageControl: true, + }, + models: [ + { + id: 'vertex/gemini-3-pro-preview', + pricing: { + input: 2.0, + cachedInput: 0.2, + output: 12.0, + updatedAt: '2025-12-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + }, + contextWindow: 1000000, + }, + { + id: 'vertex/gemini-3-flash-preview', + pricing: { + input: 0.5, + cachedInput: 0.05, + output: 3.0, + updatedAt: '2025-12-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + }, + contextWindow: 1000000, + }, + { + id: 'vertex/gemini-2.5-pro', + pricing: { + input: 1.25, + cachedInput: 0.125, + output: 10.0, + updatedAt: '2025-12-02', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + }, + contextWindow: 1048576, + }, + { + id: 'vertex/gemini-2.5-flash', + pricing: { + input: 0.3, + cachedInput: 0.03, + output: 2.5, + updatedAt: '2025-12-02', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + }, + contextWindow: 1048576, + }, + { + id: 'vertex/gemini-2.5-flash-lite', + pricing: { + input: 0.1, + cachedInput: 0.01, + output: 0.4, + updatedAt: '2025-12-02', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + }, + contextWindow: 1048576, + }, + { + id: 'vertex/gemini-2.0-flash', + pricing: { + input: 0.1, + output: 0.4, + updatedAt: '2025-12-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + }, + contextWindow: 1000000, + }, + { + id: 'vertex/gemini-2.0-flash-lite', + pricing: { + input: 0.075, + output: 0.3, + updatedAt: '2025-12-17', + }, + capabilities: { + temperature: { min: 0, max: 2 }, + }, + contextWindow: 1000000, + }, ], }, deepseek: { @@ -1708,6 +1868,20 @@ export function getModelsWithReasoningEffort(): string[] { return models } +/** + * Get the reasoning effort values for a specific model + * Returns the valid options for that model, or null if the model doesn't support reasoning effort + */ +export function getReasoningEffortValuesForModel(modelId: string): string[] | null { + for (const provider of Object.values(PROVIDER_DEFINITIONS)) { + const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase()) + if (model?.capabilities.reasoningEffort) { + return model.capabilities.reasoningEffort.values + } + } + return null +} + /** * Get all models that support verbosity */ @@ -1722,3 +1896,17 @@ export function getModelsWithVerbosity(): string[] { } return models } + +/** + * Get the verbosity values for a specific model + * Returns the valid options for that model, or null if the model doesn't support verbosity + */ +export function 
getVerbosityValuesForModel(modelId: string): string[] | null { + for (const provider of Object.values(PROVIDER_DEFINITIONS)) { + const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase()) + if (model?.capabilities.verbosity) { + return model.capabilities.verbosity.values + } + } + return null +} diff --git a/apps/sim/providers/ollama/index.ts b/apps/sim/providers/ollama/index.ts index 0118e53ff..acdafa91a 100644 --- a/apps/sim/providers/ollama/index.ts +++ b/apps/sim/providers/ollama/index.ts @@ -2,7 +2,9 @@ import OpenAI from 'openai' import { env } from '@/lib/core/config/env' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' import type { ModelsObject } from '@/providers/ollama/types' +import { createReadableStreamFromOllamaStream } from '@/providers/ollama/utils' import type { ProviderConfig, ProviderRequest, @@ -16,46 +18,6 @@ import { executeTool } from '@/tools' const logger = createLogger('OllamaProvider') const OLLAMA_HOST = env.OLLAMA_URL || 'http://localhost:11434' -/** - * Helper function to convert an Ollama stream to a standard ReadableStream - * and collect completion metrics - */ -function createReadableStreamFromOllamaStream( - ollamaStream: any, - onComplete?: (content: string, usage?: any) => void -): ReadableStream { - let fullContent = '' - let usageData: any = null - - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of ollamaStream) { - // Check for usage data in the final chunk - if (chunk.usage) { - usageData = chunk.usage - } - - const content = chunk.choices[0]?.delta?.content || '' - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - - // Once stream is complete, call the completion callback with the final content and usage - if (onComplete) { - onComplete(fullContent, usageData) - } - - controller.close() - } catch (error) { - controller.error(error) - } - }, - }) -} - export const ollamaProvider: ProviderConfig = { id: 'ollama', name: 'Ollama', @@ -334,7 +296,6 @@ export const ollamaProvider: ProviderConfig = { const toolResults = [] const currentMessages = [...allMessages] let iterationCount = 0 - const MAX_ITERATIONS = 10 // Prevent infinite loops // Track time spent in model vs tools let modelTime = firstResponseTime @@ -351,7 +312,7 @@ export const ollamaProvider: ProviderConfig = { }, ] - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { // Check for tool calls const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls if (!toolCallsInResponse || toolCallsInResponse.length === 0) { @@ -359,7 +320,7 @@ export const ollamaProvider: ProviderConfig = { } logger.info( - `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})` + `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` ) // Track time for tool calls in this batch diff --git a/apps/sim/providers/ollama/utils.ts b/apps/sim/providers/ollama/utils.ts new file mode 100644 index 000000000..fc012f366 --- /dev/null +++ b/apps/sim/providers/ollama/utils.ts @@ -0,0 +1,37 @@ +/** + * Helper function to convert an Ollama stream to a standard ReadableStream + * and collect completion metrics + */ +export function createReadableStreamFromOllamaStream( + ollamaStream: any, + onComplete?: (content: string, usage?: 
any) => void +): ReadableStream { + let fullContent = '' + let usageData: any = null + + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of ollamaStream) { + if (chunk.usage) { + usageData = chunk.usage + } + + const content = chunk.choices[0]?.delta?.content || '' + if (content) { + fullContent += content + controller.enqueue(new TextEncoder().encode(content)) + } + } + + if (onComplete) { + onComplete(fullContent, usageData) + } + + controller.close() + } catch (error) { + controller.error(error) + } + }, + }) +} diff --git a/apps/sim/providers/openai/index.ts b/apps/sim/providers/openai/index.ts index b925dc7d1..3758fea1f 100644 --- a/apps/sim/providers/openai/index.ts +++ b/apps/sim/providers/openai/index.ts @@ -1,7 +1,9 @@ import OpenAI from 'openai' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' +import { createReadableStreamFromOpenAIStream } from '@/providers/openai/utils' import type { ProviderConfig, ProviderRequest, @@ -17,46 +19,6 @@ import { executeTool } from '@/tools' const logger = createLogger('OpenAIProvider') -/** - * Helper function to convert an OpenAI stream to a standard ReadableStream - * and collect completion metrics - */ -function createReadableStreamFromOpenAIStream( - openaiStream: any, - onComplete?: (content: string, usage?: any) => void -): ReadableStream { - let fullContent = '' - let usageData: any = null - - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of openaiStream) { - // Check for usage data in the final chunk - if (chunk.usage) { - usageData = chunk.usage - } - - const content = chunk.choices[0]?.delta?.content || '' - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - - // Once stream is complete, call the completion callback with the final content and usage - if (onComplete) { - onComplete(fullContent, usageData) - } - - controller.close() - } catch (error) { - controller.error(error) - } - }, - }) -} - /** * OpenAI provider configuration */ @@ -319,7 +281,6 @@ export const openaiProvider: ProviderConfig = { const toolResults = [] const currentMessages = [...allMessages] let iterationCount = 0 - const MAX_ITERATIONS = 10 // Prevent infinite loops // Track time spent in model vs tools let modelTime = firstResponseTime @@ -342,7 +303,7 @@ export const openaiProvider: ProviderConfig = { // Check if a forced tool was used in the first response checkForForcedToolUsage(currentResponse, originalToolChoice) - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { // Check for tool calls const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls if (!toolCallsInResponse || toolCallsInResponse.length === 0) { @@ -350,7 +311,7 @@ export const openaiProvider: ProviderConfig = { } logger.info( - `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})` + `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` ) // Track time for tool calls in this batch diff --git a/apps/sim/providers/openai/utils.ts b/apps/sim/providers/openai/utils.ts new file mode 100644 index 000000000..1f35bf6c3 --- /dev/null +++ b/apps/sim/providers/openai/utils.ts @@ -0,0 +1,37 @@ +/** + * Helper 
function to convert an OpenAI stream to a standard ReadableStream + * and collect completion metrics + */ +export function createReadableStreamFromOpenAIStream( + openaiStream: any, + onComplete?: (content: string, usage?: any) => void +): ReadableStream { + let fullContent = '' + let usageData: any = null + + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of openaiStream) { + if (chunk.usage) { + usageData = chunk.usage + } + + const content = chunk.choices[0]?.delta?.content || '' + if (content) { + fullContent += content + controller.enqueue(new TextEncoder().encode(content)) + } + } + + if (onComplete) { + onComplete(fullContent, usageData) + } + + controller.close() + } catch (error) { + controller.error(error) + } + }, + }) +} diff --git a/apps/sim/providers/openrouter/index.ts b/apps/sim/providers/openrouter/index.ts index 979b5783a..00fb33db0 100644 --- a/apps/sim/providers/openrouter/index.ts +++ b/apps/sim/providers/openrouter/index.ts @@ -1,56 +1,23 @@ import OpenAI from 'openai' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' +import { + checkForForcedToolUsage, + createReadableStreamFromOpenAIStream, +} from '@/providers/openrouter/utils' import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment, } from '@/providers/types' -import { - prepareToolExecution, - prepareToolsWithUsageControl, - trackForcedToolUsage, -} from '@/providers/utils' +import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils' import { executeTool } from '@/tools' const logger = createLogger('OpenRouterProvider') -function createReadableStreamFromOpenAIStream( - openaiStream: any, - onComplete?: (content: string, usage?: any) => void -): ReadableStream { - let fullContent = '' - let usageData: any = null - - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of openaiStream) { - if (chunk.usage) { - usageData = chunk.usage - } - - const content = chunk.choices[0]?.delta?.content || '' - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - - if (onComplete) { - onComplete(fullContent, usageData) - } - - controller.close() - } catch (error) { - controller.error(error) - } - }, - }) -} - export const openRouterProvider: ProviderConfig = { id: 'openrouter', name: 'OpenRouter', @@ -227,7 +194,6 @@ export const openRouterProvider: ProviderConfig = { const toolResults = [] as any[] const currentMessages = [...allMessages] let iterationCount = 0 - const MAX_ITERATIONS = 10 let modelTime = firstResponseTime let toolsTime = 0 let hasUsedForcedTool = false @@ -241,28 +207,16 @@ export const openRouterProvider: ProviderConfig = { }, ] - const checkForForcedToolUsage = ( - response: any, - toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any } - ) => { - if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) { - const toolCallsResponse = response.choices[0].message.tool_calls - const result = trackForcedToolUsage( - toolCallsResponse, - toolChoice, - logger, - 'openrouter', - forcedTools, - usedForcedTools - ) - hasUsedForcedTool = result.hasUsedForcedTool - usedForcedTools = result.usedForcedTools - } - } + const forcedToolResult = checkForForcedToolUsage( + currentResponse, + 
originalToolChoice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = forcedToolResult.hasUsedForcedTool + usedForcedTools = forcedToolResult.usedForcedTools - checkForForcedToolUsage(currentResponse, originalToolChoice) - - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls if (!toolCallsInResponse || toolCallsInResponse.length === 0) { break @@ -359,7 +313,14 @@ export const openRouterProvider: ProviderConfig = { const nextModelStartTime = Date.now() currentResponse = await client.chat.completions.create(nextPayload) - checkForForcedToolUsage(currentResponse, nextPayload.tool_choice) + const nextForcedToolResult = checkForForcedToolUsage( + currentResponse, + nextPayload.tool_choice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = nextForcedToolResult.hasUsedForcedTool + usedForcedTools = nextForcedToolResult.usedForcedTools const nextModelEndTime = Date.now() const thisModelTime = nextModelEndTime - nextModelStartTime timeSegments.push({ diff --git a/apps/sim/providers/openrouter/utils.ts b/apps/sim/providers/openrouter/utils.ts new file mode 100644 index 000000000..fc9a4254d --- /dev/null +++ b/apps/sim/providers/openrouter/utils.ts @@ -0,0 +1,78 @@ +import { createLogger } from '@/lib/logs/console/logger' +import { trackForcedToolUsage } from '@/providers/utils' + +const logger = createLogger('OpenRouterProvider') + +/** + * Creates a ReadableStream from an OpenAI-compatible stream response + * @param openaiStream - The OpenAI stream to convert + * @param onComplete - Optional callback when streaming is complete with content and usage data + * @returns ReadableStream that emits text chunks + */ +export function createReadableStreamFromOpenAIStream( + openaiStream: any, + onComplete?: (content: string, usage?: any) => void +): ReadableStream { + let fullContent = '' + let usageData: any = null + + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of openaiStream) { + if (chunk.usage) { + usageData = chunk.usage + } + + const content = chunk.choices[0]?.delta?.content || '' + if (content) { + fullContent += content + controller.enqueue(new TextEncoder().encode(content)) + } + } + + if (onComplete) { + onComplete(fullContent, usageData) + } + + controller.close() + } catch (error) { + controller.error(error) + } + }, + }) +} + +/** + * Checks if a forced tool was used in the response and updates tracking + * @param response - The API response containing tool calls + * @param toolChoice - The tool choice configuration (string or object) + * @param forcedTools - Array of forced tool names + * @param usedForcedTools - Array of already used forced tools + * @returns Object with hasUsedForcedTool flag and updated usedForcedTools array + */ +export function checkForForcedToolUsage( + response: any, + toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }, + forcedTools: string[], + usedForcedTools: string[] +): { hasUsedForcedTool: boolean; usedForcedTools: string[] } { + let hasUsedForcedTool = false + let updatedUsedForcedTools = usedForcedTools + + if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) { + const toolCallsResponse = response.choices[0].message.tool_calls + const result = trackForcedToolUsage( + toolCallsResponse, + toolChoice, + logger, + 'openrouter', + forcedTools, + updatedUsedForcedTools + ) + hasUsedForcedTool = result.hasUsedForcedTool 
+ updatedUsedForcedTools = result.usedForcedTools + } + + return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools } +} diff --git a/apps/sim/providers/types.ts b/apps/sim/providers/types.ts index 6c2fd1f00..4ada41589 100644 --- a/apps/sim/providers/types.ts +++ b/apps/sim/providers/types.ts @@ -5,6 +5,7 @@ export type ProviderId = | 'azure-openai' | 'anthropic' | 'google' + | 'vertex' | 'deepseek' | 'xai' | 'cerebras' @@ -163,6 +164,9 @@ export interface ProviderRequest { // Azure OpenAI specific parameters azureEndpoint?: string azureApiVersion?: string + // Vertex AI specific parameters + vertexProject?: string + vertexLocation?: string // GPT-5 specific parameters reasoningEffort?: string verbosity?: string diff --git a/apps/sim/providers/utils.test.ts b/apps/sim/providers/utils.test.ts index 4fa913214..9085908c2 100644 --- a/apps/sim/providers/utils.test.ts +++ b/apps/sim/providers/utils.test.ts @@ -383,6 +383,17 @@ describe('Model Capabilities', () => { expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/gpt-5-mini') expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/gpt-5-nano') + // Should contain gpt-5.2 models + expect(MODELS_WITH_REASONING_EFFORT).toContain('gpt-5.2') + expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/gpt-5.2') + + // Should contain o-series reasoning models (reasoning_effort added Dec 17, 2024) + expect(MODELS_WITH_REASONING_EFFORT).toContain('o1') + expect(MODELS_WITH_REASONING_EFFORT).toContain('o3') + expect(MODELS_WITH_REASONING_EFFORT).toContain('o4-mini') + expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/o3') + expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/o4-mini') + // Should NOT contain non-reasoning GPT-5 models expect(MODELS_WITH_REASONING_EFFORT).not.toContain('gpt-5-chat-latest') expect(MODELS_WITH_REASONING_EFFORT).not.toContain('azure/gpt-5-chat-latest') @@ -390,7 +401,6 @@ describe('Model Capabilities', () => { // Should NOT contain other models expect(MODELS_WITH_REASONING_EFFORT).not.toContain('gpt-4o') expect(MODELS_WITH_REASONING_EFFORT).not.toContain('claude-sonnet-4-0') - expect(MODELS_WITH_REASONING_EFFORT).not.toContain('o1') }) it.concurrent('should have correct models in MODELS_WITH_VERBOSITY', () => { @@ -409,19 +419,37 @@ describe('Model Capabilities', () => { expect(MODELS_WITH_VERBOSITY).toContain('azure/gpt-5-mini') expect(MODELS_WITH_VERBOSITY).toContain('azure/gpt-5-nano') + // Should contain gpt-5.2 models + expect(MODELS_WITH_VERBOSITY).toContain('gpt-5.2') + expect(MODELS_WITH_VERBOSITY).toContain('azure/gpt-5.2') + // Should NOT contain non-reasoning GPT-5 models expect(MODELS_WITH_VERBOSITY).not.toContain('gpt-5-chat-latest') expect(MODELS_WITH_VERBOSITY).not.toContain('azure/gpt-5-chat-latest') + // Should NOT contain o-series models (they support reasoning_effort but not verbosity) + expect(MODELS_WITH_VERBOSITY).not.toContain('o1') + expect(MODELS_WITH_VERBOSITY).not.toContain('o3') + expect(MODELS_WITH_VERBOSITY).not.toContain('o4-mini') + // Should NOT contain other models expect(MODELS_WITH_VERBOSITY).not.toContain('gpt-4o') expect(MODELS_WITH_VERBOSITY).not.toContain('claude-sonnet-4-0') - expect(MODELS_WITH_VERBOSITY).not.toContain('o1') }) - it.concurrent('should have same models in both reasoning effort and verbosity arrays', () => { - // GPT-5 models that support reasoning effort should also support verbosity and vice versa - expect(MODELS_WITH_REASONING_EFFORT.sort()).toEqual(MODELS_WITH_VERBOSITY.sort()) + it.concurrent('should have GPT-5 models in both reasoning effort and 
verbosity arrays', () => { + // GPT-5 series models support both reasoning effort and verbosity + const gpt5ModelsWithReasoningEffort = MODELS_WITH_REASONING_EFFORT.filter( + (m) => m.includes('gpt-5') && !m.includes('chat-latest') + ) + const gpt5ModelsWithVerbosity = MODELS_WITH_VERBOSITY.filter( + (m) => m.includes('gpt-5') && !m.includes('chat-latest') + ) + expect(gpt5ModelsWithReasoningEffort.sort()).toEqual(gpt5ModelsWithVerbosity.sort()) + + // o-series models have reasoning effort but NOT verbosity + expect(MODELS_WITH_REASONING_EFFORT).toContain('o1') + expect(MODELS_WITH_VERBOSITY).not.toContain('o1') }) }) }) diff --git a/apps/sim/providers/utils.ts b/apps/sim/providers/utils.ts index d1cbe1b81..179df6e0b 100644 --- a/apps/sim/providers/utils.ts +++ b/apps/sim/providers/utils.ts @@ -21,6 +21,8 @@ import { getModelsWithVerbosity, getProviderModels as getProviderModelsFromDefinitions, getProvidersWithToolUsageControl, + getReasoningEffortValuesForModel as getReasoningEffortValuesForModelFromDefinitions, + getVerbosityValuesForModel as getVerbosityValuesForModelFromDefinitions, PROVIDER_DEFINITIONS, supportsTemperature as supportsTemperatureFromDefinitions, supportsToolUsageControl as supportsToolUsageControlFromDefinitions, @@ -30,6 +32,7 @@ import { ollamaProvider } from '@/providers/ollama' import { openaiProvider } from '@/providers/openai' import { openRouterProvider } from '@/providers/openrouter' import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types' +import { vertexProvider } from '@/providers/vertex' import { vllmProvider } from '@/providers/vllm' import { xAIProvider } from '@/providers/xai' import { useCustomToolsStore } from '@/stores/custom-tools/store' @@ -67,6 +70,11 @@ export const providers: Record< models: getProviderModelsFromDefinitions('google'), modelPatterns: PROVIDER_DEFINITIONS.google.modelPatterns, }, + vertex: { + ...vertexProvider, + models: getProviderModelsFromDefinitions('vertex'), + modelPatterns: PROVIDER_DEFINITIONS.vertex.modelPatterns, + }, deepseek: { ...deepseekProvider, models: getProviderModelsFromDefinitions('deepseek'), @@ -274,16 +282,12 @@ export function getProviderIcon(model: string): React.ComponentType<{ className? 
} export function generateStructuredOutputInstructions(responseFormat: any): string { - // Handle null/undefined input if (!responseFormat) return '' - // If using the new JSON Schema format, don't add additional instructions - // This is necessary because providers now handle the schema directly if (responseFormat.schema || (responseFormat.type === 'object' && responseFormat.properties)) { return '' } - // Handle legacy format with fields array if (!responseFormat.fields) return '' function generateFieldStructure(field: any): string { @@ -335,10 +339,8 @@ Each metric should be an object containing 'score' (number) and 'reasoning' (str } export function extractAndParseJSON(content: string): any { - // First clean up the string const trimmed = content.trim() - // Find the first '{' and last '}' const firstBrace = trimmed.indexOf('{') const lastBrace = trimmed.lastIndexOf('}') @@ -346,17 +348,15 @@ export function extractAndParseJSON(content: string): any { throw new Error('No JSON object found in content') } - // Extract just the JSON part const jsonStr = trimmed.slice(firstBrace, lastBrace + 1) try { return JSON.parse(jsonStr) } catch (_error) { - // If parsing fails, try to clean up common issues const cleaned = jsonStr - .replace(/\n/g, ' ') // Remove newlines - .replace(/\s+/g, ' ') // Normalize whitespace - .replace(/,\s*([}\]])/g, '$1') // Remove trailing commas + .replace(/\n/g, ' ') + .replace(/\s+/g, ' ') + .replace(/,\s*([}\]])/g, '$1') try { return JSON.parse(cleaned) @@ -386,10 +386,10 @@ export function transformCustomTool(customTool: any): ProviderToolConfig { } return { - id: `custom_${customTool.id}`, // Prefix with 'custom_' to identify custom tools + id: `custom_${customTool.id}`, name: schema.function.name, description: schema.function.description || '', - params: {}, // This will be derived from parameters + params: {}, parameters: { type: schema.function.parameters.type, properties: schema.function.parameters.properties, @@ -402,10 +402,8 @@ export function transformCustomTool(customTool: any): ProviderToolConfig { * Gets all available custom tools as provider tool configs */ export function getCustomTools(): ProviderToolConfig[] { - // Get custom tools from the store const customTools = useCustomToolsStore.getState().getAllTools() - // Transform each custom tool into a provider tool config return customTools.map(transformCustomTool) } @@ -427,20 +425,16 @@ export async function transformBlockTool( ): Promise { const { selectedOperation, getAllBlocks, getTool, getToolAsync } = options - // Get the block definition const blockDef = getAllBlocks().find((b: any) => b.type === block.type) if (!blockDef) { logger.warn(`Block definition not found for type: ${block.type}`) return null } - // If the block has multiple operations, use the selected one or the first one let toolId: string | null = null if ((blockDef.tools?.access?.length || 0) > 1) { - // If we have an operation dropdown in the block and a selected operation if (selectedOperation && blockDef.tools?.config?.tool) { - // Use the block's tool selection function to get the right tool try { toolId = blockDef.tools.config.tool({ ...block.params, @@ -455,11 +449,9 @@ export async function transformBlockTool( return null } } else { - // Default to first tool if no operation specified toolId = blockDef.tools.access[0] } } else { - // Single tool case toolId = blockDef.tools?.access?.[0] || null } @@ -468,14 +460,11 @@ export async function transformBlockTool( return null } - // Get the tool config - check if it's a custom 
tool that needs async fetching let toolConfig: any if (toolId.startsWith('custom_') && getToolAsync) { - // Use the async version for custom tools toolConfig = await getToolAsync(toolId) } else { - // Use the synchronous version for built-in tools toolConfig = getTool(toolId) } @@ -484,16 +473,12 @@ export async function transformBlockTool( return null } - // Import the new tool parameter utilities const { createLLMToolSchema } = await import('@/tools/params') - // Get user-provided parameters from the block const userProvidedParams = block.params || {} - // Create LLM schema that excludes user-provided parameters const llmSchema = await createLLMToolSchema(toolConfig, userProvidedParams) - // Return formatted tool config return { id: toolConfig.id, name: toolConfig.name, @@ -521,15 +506,12 @@ export function calculateCost( inputMultiplier?: number, outputMultiplier?: number ) { - // First check if it's an embedding model let pricing = getEmbeddingModelPricing(model) - // If not found, check chat models if (!pricing) { pricing = getModelPricingFromDefinitions(model) } - // If no pricing found, return default pricing if (!pricing) { const defaultPricing = { input: 1.0, @@ -545,8 +527,6 @@ export function calculateCost( } } - // Calculate costs in USD - // Convert from "per million tokens" to "per token" by dividing by 1,000,000 const inputCost = promptTokens * (useCachedInput && pricing.cachedInput @@ -559,7 +539,7 @@ export function calculateCost( const finalTotalCost = finalInputCost + finalOutputCost return { - input: Number.parseFloat(finalInputCost.toFixed(8)), // Use 8 decimal places for small costs + input: Number.parseFloat(finalInputCost.toFixed(8)), output: Number.parseFloat(finalOutputCost.toFixed(8)), total: Number.parseFloat(finalTotalCost.toFixed(8)), pricing, @@ -997,6 +977,22 @@ export function supportsToolUsageControl(provider: string): boolean { return supportsToolUsageControlFromDefinitions(provider) } +/** + * Get reasoning effort values for a specific model + * Returns the valid options for that model, or null if the model doesn't support reasoning effort + */ +export function getReasoningEffortValuesForModel(model: string): string[] | null { + return getReasoningEffortValuesForModelFromDefinitions(model) +} + +/** + * Get verbosity values for a specific model + * Returns the valid options for that model, or null if the model doesn't support verbosity + */ +export function getVerbosityValuesForModel(model: string): string[] | null { + return getVerbosityValuesForModelFromDefinitions(model) +} + /** * Prepare tool execution parameters, separating tool parameters from system parameters */ diff --git a/apps/sim/providers/vertex/index.ts b/apps/sim/providers/vertex/index.ts new file mode 100644 index 000000000..0a25d304d --- /dev/null +++ b/apps/sim/providers/vertex/index.ts @@ -0,0 +1,899 @@ +import { env } from '@/lib/core/config/env' +import { createLogger } from '@/lib/logs/console/logger' +import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { + cleanSchemaForGemini, + convertToGeminiFormat, + extractFunctionCall, + extractTextContent, +} from '@/providers/google/utils' +import { getProviderDefaultModel, getProviderModels } from '@/providers/models' +import type { + ProviderConfig, + ProviderRequest, + ProviderResponse, + TimeSegment, +} from '@/providers/types' +import { + prepareToolExecution, + prepareToolsWithUsageControl, + trackForcedToolUsage, +} from '@/providers/utils' +import { 
buildVertexEndpoint, createReadableStreamFromVertexStream } from '@/providers/vertex/utils' +import { executeTool } from '@/tools' + +const logger = createLogger('VertexProvider') + +/** + * Vertex AI provider configuration + */ +export const vertexProvider: ProviderConfig = { + id: 'vertex', + name: 'Vertex AI', + description: "Google's Vertex AI platform for Gemini models", + version: '1.0.0', + models: getProviderModels('vertex'), + defaultModel: getProviderDefaultModel('vertex'), + + executeRequest: async ( + request: ProviderRequest + ): Promise => { + const vertexProject = env.VERTEX_PROJECT || request.vertexProject + const vertexLocation = env.VERTEX_LOCATION || request.vertexLocation || 'us-central1' + + if (!vertexProject) { + throw new Error( + 'Vertex AI project is required. Please provide it via VERTEX_PROJECT environment variable or vertexProject parameter.' + ) + } + + if (!request.apiKey) { + throw new Error( + 'Access token is required for Vertex AI. Run `gcloud auth print-access-token` to get one, or use a service account.' + ) + } + + logger.info('Preparing Vertex AI request', { + model: request.model || 'vertex/gemini-2.5-pro', + hasSystemPrompt: !!request.systemPrompt, + hasMessages: !!request.messages?.length, + hasTools: !!request.tools?.length, + toolCount: request.tools?.length || 0, + hasResponseFormat: !!request.responseFormat, + streaming: !!request.stream, + project: vertexProject, + location: vertexLocation, + }) + + const providerStartTime = Date.now() + const providerStartTimeISO = new Date(providerStartTime).toISOString() + + try { + const { contents, tools, systemInstruction } = convertToGeminiFormat(request) + + const requestedModel = (request.model || 'vertex/gemini-2.5-pro').replace('vertex/', '') + + const payload: any = { + contents, + generationConfig: {}, + } + + if (request.temperature !== undefined && request.temperature !== null) { + payload.generationConfig.temperature = request.temperature + } + + if (request.maxTokens !== undefined) { + payload.generationConfig.maxOutputTokens = request.maxTokens + } + + if (systemInstruction) { + payload.systemInstruction = systemInstruction + } + + if (request.responseFormat && !tools?.length) { + const responseFormatSchema = request.responseFormat.schema || request.responseFormat + const cleanSchema = cleanSchemaForGemini(responseFormatSchema) + + payload.generationConfig.responseMimeType = 'application/json' + payload.generationConfig.responseSchema = cleanSchema + + logger.info('Using Vertex AI native structured output format', { + hasSchema: !!cleanSchema, + mimeType: 'application/json', + }) + } else if (request.responseFormat && tools?.length) { + logger.warn( + 'Vertex AI does not support structured output (responseFormat) with function calling (tools). Structured output will be ignored.' 
+ ) + } + + let preparedTools: ReturnType | null = null + + if (tools?.length) { + preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'google') + const { tools: filteredTools, toolConfig } = preparedTools + + if (filteredTools?.length) { + payload.tools = [ + { + functionDeclarations: filteredTools, + }, + ] + + if (toolConfig) { + payload.toolConfig = toolConfig + } + + logger.info('Vertex AI request with tools:', { + toolCount: filteredTools.length, + model: requestedModel, + tools: filteredTools.map((t) => t.name), + hasToolConfig: !!toolConfig, + toolConfig: toolConfig, + }) + } + } + + const initialCallTime = Date.now() + const shouldStream = !!(request.stream && !tools?.length) + + const endpoint = buildVertexEndpoint( + vertexProject, + vertexLocation, + requestedModel, + shouldStream + ) + + if (request.stream && tools?.length) { + logger.info('Streaming disabled for initial request due to tools presence', { + toolCount: tools.length, + willStreamAfterTools: true, + }) + } + + const response = await fetch(endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${request.apiKey}`, + }, + body: JSON.stringify(payload), + }) + + if (!response.ok) { + const responseText = await response.text() + logger.error('Vertex AI API error details:', { + status: response.status, + statusText: response.statusText, + responseBody: responseText, + }) + throw new Error(`Vertex AI API error: ${response.status} ${response.statusText}`) + } + + const firstResponseTime = Date.now() - initialCallTime + + if (shouldStream) { + logger.info('Handling Vertex AI streaming response') + + const streamingResult: StreamingExecution = { + stream: null as any, + execution: { + success: true, + output: { + content: '', + model: request.model, + tokens: { + prompt: 0, + completion: 0, + total: 0, + }, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: firstResponseTime, + modelTime: firstResponseTime, + toolsTime: 0, + firstResponseTime, + iterations: 1, + timeSegments: [ + { + type: 'model', + name: 'Initial streaming response', + startTime: initialCallTime, + endTime: initialCallTime + firstResponseTime, + duration: firstResponseTime, + }, + ], + }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: firstResponseTime, + }, + isStreaming: true, + }, + } + + streamingResult.stream = createReadableStreamFromVertexStream( + response, + (content, usage) => { + streamingResult.execution.output.content = content + + const streamEndTime = Date.now() + const streamEndTimeISO = new Date(streamEndTime).toISOString() + + if (streamingResult.execution.output.providerTiming) { + streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO + streamingResult.execution.output.providerTiming.duration = + streamEndTime - providerStartTime + + if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) { + streamingResult.execution.output.providerTiming.timeSegments[0].endTime = + streamEndTime + streamingResult.execution.output.providerTiming.timeSegments[0].duration = + streamEndTime - providerStartTime + } + } + + if (usage) { + streamingResult.execution.output.tokens = { + prompt: usage.promptTokenCount || 0, + completion: usage.candidatesTokenCount || 0, + total: + usage.totalTokenCount || + (usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0), + } + } + } + ) + + return streamingResult + } + + let 
geminiResponse = await response.json() + + if (payload.generationConfig?.responseSchema) { + const candidate = geminiResponse.candidates?.[0] + if (candidate?.content?.parts?.[0]?.text) { + const text = candidate.content.parts[0].text + try { + JSON.parse(text) + logger.info('Successfully received structured JSON output') + } catch (_e) { + logger.warn('Failed to parse structured output as JSON') + } + } + } + + let content = '' + let tokens = { + prompt: 0, + completion: 0, + total: 0, + } + const toolCalls = [] + const toolResults = [] + let iterationCount = 0 + + const originalToolConfig = preparedTools?.toolConfig + const forcedTools = preparedTools?.forcedTools || [] + let usedForcedTools: string[] = [] + let hasUsedForcedTool = false + let currentToolConfig = originalToolConfig + + const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => { + if (currentToolConfig && forcedTools.length > 0) { + const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }] + const result = trackForcedToolUsage( + toolCallsForTracking, + currentToolConfig, + logger, + 'google', + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = result.hasUsedForcedTool + usedForcedTools = result.usedForcedTools + + if (result.nextToolConfig) { + currentToolConfig = result.nextToolConfig + logger.info('Updated tool config for next iteration', { + hasNextToolConfig: !!currentToolConfig, + usedForcedTools: usedForcedTools, + }) + } + } + } + + let modelTime = firstResponseTime + let toolsTime = 0 + + const timeSegments: TimeSegment[] = [ + { + type: 'model', + name: 'Initial response', + startTime: initialCallTime, + endTime: initialCallTime + firstResponseTime, + duration: firstResponseTime, + }, + ] + + try { + const candidate = geminiResponse.candidates?.[0] + + if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') { + logger.warn( + 'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided', + { + finishReason: candidate.finishReason, + hasContent: !!candidate?.content, + hasParts: !!candidate?.content?.parts, + } + ) + content = extractTextContent(candidate) + } + + const functionCall = extractFunctionCall(candidate) + + if (functionCall) { + logger.info(`Received function call from Vertex AI: ${functionCall.name}`) + + while (iterationCount < MAX_TOOL_ITERATIONS) { + const latestResponse = geminiResponse.candidates?.[0] + const latestFunctionCall = extractFunctionCall(latestResponse) + + if (!latestFunctionCall) { + content = extractTextContent(latestResponse) + break + } + + logger.info( + `Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` + ) + + const toolsStartTime = Date.now() + + try { + const toolName = latestFunctionCall.name + const toolArgs = latestFunctionCall.args || {} + + const tool = request.tools?.find((t) => t.id === toolName) + if (!tool) { + logger.warn(`Tool ${toolName} not found in registry, skipping`) + break + } + + const toolCallStartTime = Date.now() + + const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) + const result = await executeTool(toolName, executionParams, true) + const toolCallEndTime = Date.now() + const toolCallDuration = toolCallEndTime - toolCallStartTime + + timeSegments.push({ + type: 'tool', + name: toolName, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallDuration, + }) + + let resultContent: any + if (result.success) { + toolResults.push(result.output) 
+ resultContent = result.output + } else { + resultContent = { + error: true, + message: result.error || 'Tool execution failed', + tool: toolName, + } + } + + toolCalls.push({ + name: toolName, + arguments: toolParams, + startTime: new Date(toolCallStartTime).toISOString(), + endTime: new Date(toolCallEndTime).toISOString(), + duration: toolCallDuration, + result: resultContent, + success: result.success, + }) + + const simplifiedMessages = [ + ...(contents.filter((m) => m.role === 'user').length > 0 + ? [contents.filter((m) => m.role === 'user')[0]] + : [contents[0]]), + { + role: 'model', + parts: [ + { + functionCall: { + name: latestFunctionCall.name, + args: latestFunctionCall.args, + }, + }, + ], + }, + { + role: 'user', + parts: [ + { + text: `Function ${latestFunctionCall.name} result: ${JSON.stringify(resultContent)}`, + }, + ], + }, + ] + + const thisToolsTime = Date.now() - toolsStartTime + toolsTime += thisToolsTime + + checkForForcedToolUsage(latestFunctionCall) + + const nextModelStartTime = Date.now() + + try { + if (request.stream) { + const streamingPayload = { + ...payload, + contents: simplifiedMessages, + } + + const allForcedToolsUsed = + forcedTools.length > 0 && usedForcedTools.length === forcedTools.length + + if (allForcedToolsUsed && request.responseFormat) { + streamingPayload.tools = undefined + streamingPayload.toolConfig = undefined + + const responseFormatSchema = + request.responseFormat.schema || request.responseFormat + const cleanSchema = cleanSchemaForGemini(responseFormatSchema) + + if (!streamingPayload.generationConfig) { + streamingPayload.generationConfig = {} + } + streamingPayload.generationConfig.responseMimeType = 'application/json' + streamingPayload.generationConfig.responseSchema = cleanSchema + + logger.info('Using structured output for final response after tool execution') + } else { + if (currentToolConfig) { + streamingPayload.toolConfig = currentToolConfig + } else { + streamingPayload.toolConfig = { functionCallingConfig: { mode: 'AUTO' } } + } + } + + const checkPayload = { + ...streamingPayload, + } + checkPayload.stream = undefined + + const checkEndpoint = buildVertexEndpoint( + vertexProject, + vertexLocation, + requestedModel, + false + ) + + const checkResponse = await fetch(checkEndpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${request.apiKey}`, + }, + body: JSON.stringify(checkPayload), + }) + + if (!checkResponse.ok) { + const errorBody = await checkResponse.text() + logger.error('Error in Vertex AI check request:', { + status: checkResponse.status, + statusText: checkResponse.statusText, + responseBody: errorBody, + }) + throw new Error( + `Vertex AI API check error: ${checkResponse.status} ${checkResponse.statusText}` + ) + } + + const checkResult = await checkResponse.json() + const checkCandidate = checkResult.candidates?.[0] + const checkFunctionCall = extractFunctionCall(checkCandidate) + + if (checkFunctionCall) { + logger.info( + 'Function call detected in follow-up, handling in non-streaming mode', + { + functionName: checkFunctionCall.name, + } + ) + + geminiResponse = checkResult + + if (checkResult.usageMetadata) { + tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0 + tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0 + tokens.total += + (checkResult.usageMetadata.promptTokenCount || 0) + + (checkResult.usageMetadata.candidatesTokenCount || 0) + } + + const nextModelEndTime = Date.now() + const thisModelTime = 
nextModelEndTime - nextModelStartTime + modelTime += thisModelTime + + timeSegments.push({ + type: 'model', + name: `Model response (iteration ${iterationCount + 1})`, + startTime: nextModelStartTime, + endTime: nextModelEndTime, + duration: thisModelTime, + }) + + iterationCount++ + continue + } + + logger.info('No function call detected, proceeding with streaming response') + + // Apply structured output for the final response if responseFormat is specified + // This works regardless of whether tools were forced or auto + if (request.responseFormat) { + streamingPayload.tools = undefined + streamingPayload.toolConfig = undefined + + const responseFormatSchema = + request.responseFormat.schema || request.responseFormat + const cleanSchema = cleanSchemaForGemini(responseFormatSchema) + + if (!streamingPayload.generationConfig) { + streamingPayload.generationConfig = {} + } + streamingPayload.generationConfig.responseMimeType = 'application/json' + streamingPayload.generationConfig.responseSchema = cleanSchema + + logger.info( + 'Using structured output for final streaming response after tool execution' + ) + } + + const streamEndpoint = buildVertexEndpoint( + vertexProject, + vertexLocation, + requestedModel, + true + ) + + const streamingResponse = await fetch(streamEndpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${request.apiKey}`, + }, + body: JSON.stringify(streamingPayload), + }) + + if (!streamingResponse.ok) { + const errorBody = await streamingResponse.text() + logger.error('Error in Vertex AI streaming follow-up request:', { + status: streamingResponse.status, + statusText: streamingResponse.statusText, + responseBody: errorBody, + }) + throw new Error( + `Vertex AI API streaming error: ${streamingResponse.status} ${streamingResponse.statusText}` + ) + } + + const nextModelEndTime = Date.now() + const thisModelTime = nextModelEndTime - nextModelStartTime + modelTime += thisModelTime + + timeSegments.push({ + type: 'model', + name: 'Final streaming response after tool calls', + startTime: nextModelStartTime, + endTime: nextModelEndTime, + duration: thisModelTime, + }) + + const streamingExecution: StreamingExecution = { + stream: null as any, + execution: { + success: true, + output: { + content: '', + model: request.model, + tokens, + toolCalls: + toolCalls.length > 0 + ? 
{ + list: toolCalls, + count: toolCalls.length, + } + : undefined, + toolResults, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + modelTime, + toolsTime, + firstResponseTime, + iterations: iterationCount + 1, + timeSegments, + }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + isStreaming: true, + }, + } + + streamingExecution.stream = createReadableStreamFromVertexStream( + streamingResponse, + (content, usage) => { + streamingExecution.execution.output.content = content + + const streamEndTime = Date.now() + const streamEndTimeISO = new Date(streamEndTime).toISOString() + + if (streamingExecution.execution.output.providerTiming) { + streamingExecution.execution.output.providerTiming.endTime = + streamEndTimeISO + streamingExecution.execution.output.providerTiming.duration = + streamEndTime - providerStartTime + } + + if (usage) { + const existingTokens = streamingExecution.execution.output.tokens || { + prompt: 0, + completion: 0, + total: 0, + } + streamingExecution.execution.output.tokens = { + prompt: (existingTokens.prompt || 0) + (usage.promptTokenCount || 0), + completion: + (existingTokens.completion || 0) + (usage.candidatesTokenCount || 0), + total: + (existingTokens.total || 0) + + (usage.totalTokenCount || + (usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0)), + } + } + } + ) + + return streamingExecution + } + + const nextPayload = { + ...payload, + contents: simplifiedMessages, + } + + const allForcedToolsUsed = + forcedTools.length > 0 && usedForcedTools.length === forcedTools.length + + if (allForcedToolsUsed && request.responseFormat) { + nextPayload.tools = undefined + nextPayload.toolConfig = undefined + + const responseFormatSchema = + request.responseFormat.schema || request.responseFormat + const cleanSchema = cleanSchemaForGemini(responseFormatSchema) + + if (!nextPayload.generationConfig) { + nextPayload.generationConfig = {} + } + nextPayload.generationConfig.responseMimeType = 'application/json' + nextPayload.generationConfig.responseSchema = cleanSchema + + logger.info( + 'Using structured output for final non-streaming response after tool execution' + ) + } else { + if (currentToolConfig) { + nextPayload.toolConfig = currentToolConfig + } + } + + const nextEndpoint = buildVertexEndpoint( + vertexProject, + vertexLocation, + requestedModel, + false + ) + + const nextResponse = await fetch(nextEndpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${request.apiKey}`, + }, + body: JSON.stringify(nextPayload), + }) + + if (!nextResponse.ok) { + const errorBody = await nextResponse.text() + logger.error('Error in Vertex AI follow-up request:', { + status: nextResponse.status, + statusText: nextResponse.statusText, + responseBody: errorBody, + iterationCount, + }) + break + } + + geminiResponse = await nextResponse.json() + + const nextModelEndTime = Date.now() + const thisModelTime = nextModelEndTime - nextModelStartTime + + timeSegments.push({ + type: 'model', + name: `Model response (iteration ${iterationCount + 1})`, + startTime: nextModelStartTime, + endTime: nextModelEndTime, + duration: thisModelTime, + }) + + modelTime += thisModelTime + + const nextCandidate = geminiResponse.candidates?.[0] + const nextFunctionCall = extractFunctionCall(nextCandidate) + + if (!nextFunctionCall) { + // If 
responseFormat is specified, make one final request with structured output + if (request.responseFormat) { + const finalPayload = { + ...payload, + contents: nextPayload.contents, + tools: undefined, + toolConfig: undefined, + } + + const responseFormatSchema = + request.responseFormat.schema || request.responseFormat + const cleanSchema = cleanSchemaForGemini(responseFormatSchema) + + if (!finalPayload.generationConfig) { + finalPayload.generationConfig = {} + } + finalPayload.generationConfig.responseMimeType = 'application/json' + finalPayload.generationConfig.responseSchema = cleanSchema + + logger.info('Making final request with structured output after tool execution') + + const finalEndpoint = buildVertexEndpoint( + vertexProject, + vertexLocation, + requestedModel, + false + ) + + const finalResponse = await fetch(finalEndpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${request.apiKey}`, + }, + body: JSON.stringify(finalPayload), + }) + + if (finalResponse.ok) { + const finalResult = await finalResponse.json() + const finalCandidate = finalResult.candidates?.[0] + content = extractTextContent(finalCandidate) + + if (finalResult.usageMetadata) { + tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0 + tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0 + tokens.total += + (finalResult.usageMetadata.promptTokenCount || 0) + + (finalResult.usageMetadata.candidatesTokenCount || 0) + } + } else { + logger.warn( + 'Failed to get structured output, falling back to regular response' + ) + content = extractTextContent(nextCandidate) + } + } else { + content = extractTextContent(nextCandidate) + } + break + } + + iterationCount++ + } catch (error) { + logger.error('Error in Vertex AI follow-up request:', { + error: error instanceof Error ? error.message : String(error), + iterationCount, + }) + break + } + } catch (error) { + logger.error('Error processing function call:', { + error: error instanceof Error ? error.message : String(error), + functionName: latestFunctionCall?.name || 'unknown', + }) + break + } + } + } else { + content = extractTextContent(candidate) + } + } catch (error) { + logger.error('Error processing Vertex AI response:', { + error: error instanceof Error ? error.message : String(error), + iterationCount, + }) + + if (!content && toolCalls.length > 0) { + content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.` + } + } + + const providerEndTime = Date.now() + const providerEndTimeISO = new Date(providerEndTime).toISOString() + const totalDuration = providerEndTime - providerStartTime + + if (geminiResponse.usageMetadata) { + tokens = { + prompt: geminiResponse.usageMetadata.promptTokenCount || 0, + completion: geminiResponse.usageMetadata.candidatesTokenCount || 0, + total: + (geminiResponse.usageMetadata.promptTokenCount || 0) + + (geminiResponse.usageMetadata.candidatesTokenCount || 0), + } + } + + return { + content, + model: request.model, + tokens, + toolCalls: toolCalls.length > 0 ? toolCalls : undefined, + toolResults: toolResults.length > 0 ? 
toolResults : undefined, + timing: { + startTime: providerStartTimeISO, + endTime: providerEndTimeISO, + duration: totalDuration, + modelTime: modelTime, + toolsTime: toolsTime, + firstResponseTime: firstResponseTime, + iterations: iterationCount + 1, + timeSegments: timeSegments, + }, + } + } catch (error) { + const providerEndTime = Date.now() + const providerEndTimeISO = new Date(providerEndTime).toISOString() + const totalDuration = providerEndTime - providerStartTime + + logger.error('Error in Vertex AI request:', { + error: error instanceof Error ? error.message : String(error), + duration: totalDuration, + }) + + const enhancedError = new Error(error instanceof Error ? error.message : String(error)) + // @ts-ignore - Adding timing property to the error + enhancedError.timing = { + startTime: providerStartTimeISO, + endTime: providerEndTimeISO, + duration: totalDuration, + } + + throw enhancedError + } + }, +} diff --git a/apps/sim/providers/vertex/utils.ts b/apps/sim/providers/vertex/utils.ts new file mode 100644 index 000000000..70ac83e32 --- /dev/null +++ b/apps/sim/providers/vertex/utils.ts @@ -0,0 +1,233 @@ +import { createLogger } from '@/lib/logs/console/logger' +import { extractFunctionCall, extractTextContent } from '@/providers/google/utils' + +const logger = createLogger('VertexUtils') + +/** + * Creates a ReadableStream from Vertex AI's Gemini stream response + */ +export function createReadableStreamFromVertexStream( + response: Response, + onComplete?: ( + content: string, + usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number } + ) => void +): ReadableStream { + const reader = response.body?.getReader() + if (!reader) { + throw new Error('Failed to get reader from response body') + } + + return new ReadableStream({ + async start(controller) { + try { + let buffer = '' + let fullContent = '' + let usageData: { + promptTokenCount?: number + candidatesTokenCount?: number + totalTokenCount?: number + } | null = null + + while (true) { + const { done, value } = await reader.read() + if (done) { + if (buffer.trim()) { + try { + const data = JSON.parse(buffer.trim()) + if (data.usageMetadata) { + usageData = data.usageMetadata + } + const candidate = data.candidates?.[0] + if (candidate?.content?.parts) { + const functionCall = extractFunctionCall(candidate) + if (functionCall) { + logger.debug( + 'Function call detected in final buffer, ending stream to execute tool', + { + functionName: functionCall.name, + } + ) + if (onComplete) onComplete(fullContent, usageData || undefined) + controller.close() + return + } + const content = extractTextContent(candidate) + if (content) { + fullContent += content + controller.enqueue(new TextEncoder().encode(content)) + } + } + } catch (e) { + if (buffer.trim().startsWith('[')) { + try { + const dataArray = JSON.parse(buffer.trim()) + if (Array.isArray(dataArray)) { + for (const item of dataArray) { + if (item.usageMetadata) { + usageData = item.usageMetadata + } + const candidate = item.candidates?.[0] + if (candidate?.content?.parts) { + const functionCall = extractFunctionCall(candidate) + if (functionCall) { + logger.debug( + 'Function call detected in array item, ending stream to execute tool', + { + functionName: functionCall.name, + } + ) + if (onComplete) onComplete(fullContent, usageData || undefined) + controller.close() + return + } + const content = extractTextContent(candidate) + if (content) { + fullContent += content + controller.enqueue(new TextEncoder().encode(content)) + } + } + } 
+ } + } catch (arrayError) { + // Buffer is not valid JSON array + } + } + } + } + if (onComplete) onComplete(fullContent, usageData || undefined) + controller.close() + break + } + + const text = new TextDecoder().decode(value) + buffer += text + + let searchIndex = 0 + while (searchIndex < buffer.length) { + const openBrace = buffer.indexOf('{', searchIndex) + if (openBrace === -1) break + + let braceCount = 0 + let inString = false + let escaped = false + let closeBrace = -1 + + for (let i = openBrace; i < buffer.length; i++) { + const char = buffer[i] + + if (!inString) { + if (char === '"' && !escaped) { + inString = true + } else if (char === '{') { + braceCount++ + } else if (char === '}') { + braceCount-- + if (braceCount === 0) { + closeBrace = i + break + } + } + } else { + if (char === '"' && !escaped) { + inString = false + } + } + + escaped = char === '\\' && !escaped + } + + if (closeBrace !== -1) { + const jsonStr = buffer.substring(openBrace, closeBrace + 1) + + try { + const data = JSON.parse(jsonStr) + + if (data.usageMetadata) { + usageData = data.usageMetadata + } + + const candidate = data.candidates?.[0] + + if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') { + logger.warn( + 'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided', + { + finishReason: candidate.finishReason, + hasContent: !!candidate?.content, + hasParts: !!candidate?.content?.parts, + } + ) + const textContent = extractTextContent(candidate) + if (textContent) { + fullContent += textContent + controller.enqueue(new TextEncoder().encode(textContent)) + } + if (onComplete) onComplete(fullContent, usageData || undefined) + controller.close() + return + } + + if (candidate?.content?.parts) { + const functionCall = extractFunctionCall(candidate) + if (functionCall) { + logger.debug( + 'Function call detected in stream, ending stream to execute tool', + { + functionName: functionCall.name, + } + ) + if (onComplete) onComplete(fullContent, usageData || undefined) + controller.close() + return + } + const content = extractTextContent(candidate) + if (content) { + fullContent += content + controller.enqueue(new TextEncoder().encode(content)) + } + } + } catch (e) { + logger.error('Error parsing JSON from stream', { + error: e instanceof Error ? e.message : String(e), + jsonPreview: jsonStr.substring(0, 200), + }) + } + + buffer = buffer.substring(closeBrace + 1) + searchIndex = 0 + } else { + break + } + } + } + } catch (e) { + logger.error('Error reading Vertex AI stream', { + error: e instanceof Error ? e.message : String(e), + }) + controller.error(e) + } + }, + async cancel() { + await reader.cancel() + }, + }) +} + +/** + * Build Vertex AI endpoint URL + */ +export function buildVertexEndpoint( + project: string, + location: string, + model: string, + isStreaming: boolean +): string { + const action = isStreaming ? 
'streamGenerateContent' : 'generateContent' + + if (location === 'global') { + return `https://aiplatform.googleapis.com/v1/projects/${project}/locations/global/publishers/google/models/${model}:${action}` + } + + return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/google/models/${model}:${action}` +} diff --git a/apps/sim/providers/vllm/index.ts b/apps/sim/providers/vllm/index.ts index bd6805be7..14acdc0e4 100644 --- a/apps/sim/providers/vllm/index.ts +++ b/apps/sim/providers/vllm/index.ts @@ -2,6 +2,7 @@ import OpenAI from 'openai' import { env } from '@/lib/core/config/env' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import type { ProviderConfig, @@ -14,50 +15,13 @@ import { prepareToolsWithUsageControl, trackForcedToolUsage, } from '@/providers/utils' +import { createReadableStreamFromVLLMStream } from '@/providers/vllm/utils' import { useProvidersStore } from '@/stores/providers/store' import { executeTool } from '@/tools' const logger = createLogger('VLLMProvider') const VLLM_VERSION = '1.0.0' -/** - * Helper function to convert a vLLM stream to a standard ReadableStream - * and collect completion metrics - */ -function createReadableStreamFromVLLMStream( - vllmStream: any, - onComplete?: (content: string, usage?: any) => void -): ReadableStream { - let fullContent = '' - let usageData: any = null - - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of vllmStream) { - if (chunk.usage) { - usageData = chunk.usage - } - - const content = chunk.choices[0]?.delta?.content || '' - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - - if (onComplete) { - onComplete(fullContent, usageData) - } - - controller.close() - } catch (error) { - controller.error(error) - } - }, - }) -} - export const vllmProvider: ProviderConfig = { id: 'vllm', name: 'vLLM', @@ -341,7 +305,6 @@ export const vllmProvider: ProviderConfig = { const toolResults = [] const currentMessages = [...allMessages] let iterationCount = 0 - const MAX_ITERATIONS = 10 let modelTime = firstResponseTime let toolsTime = 0 @@ -360,14 +323,14 @@ export const vllmProvider: ProviderConfig = { checkForForcedToolUsage(currentResponse, originalToolChoice) - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls if (!toolCallsInResponse || toolCallsInResponse.length === 0) { break } logger.info( - `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})` + `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` ) const toolsStartTime = Date.now() diff --git a/apps/sim/providers/vllm/utils.ts b/apps/sim/providers/vllm/utils.ts new file mode 100644 index 000000000..56afadf0d --- /dev/null +++ b/apps/sim/providers/vllm/utils.ts @@ -0,0 +1,37 @@ +/** + * Helper function to convert a vLLM stream to a standard ReadableStream + * and collect completion metrics + */ +export function createReadableStreamFromVLLMStream( + vllmStream: any, + onComplete?: (content: string, usage?: any) => void +): ReadableStream { + let fullContent = '' + let usageData: any = null + + return new 
ReadableStream({ + async start(controller) { + try { + for await (const chunk of vllmStream) { + if (chunk.usage) { + usageData = chunk.usage + } + + const content = chunk.choices[0]?.delta?.content || '' + if (content) { + fullContent += content + controller.enqueue(new TextEncoder().encode(content)) + } + } + + if (onComplete) { + onComplete(fullContent, usageData) + } + + controller.close() + } catch (error) { + controller.error(error) + } + }, + }) +} diff --git a/apps/sim/providers/xai/index.ts b/apps/sim/providers/xai/index.ts index cfa73baf2..f1faa6480 100644 --- a/apps/sim/providers/xai/index.ts +++ b/apps/sim/providers/xai/index.ts @@ -1,6 +1,7 @@ import OpenAI from 'openai' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' import type { ProviderConfig, @@ -8,37 +9,16 @@ import type { ProviderResponse, TimeSegment, } from '@/providers/types' +import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils' import { - prepareToolExecution, - prepareToolsWithUsageControl, - trackForcedToolUsage, -} from '@/providers/utils' + checkForForcedToolUsage, + createReadableStreamFromXAIStream, + createResponseFormatPayload, +} from '@/providers/xai/utils' import { executeTool } from '@/tools' const logger = createLogger('XAIProvider') -/** - * Helper to wrap XAI (OpenAI-compatible) streaming into a browser-friendly - * ReadableStream of raw assistant text chunks. - */ -function createReadableStreamFromXAIStream(xaiStream: any): ReadableStream { - return new ReadableStream({ - async start(controller) { - try { - for await (const chunk of xaiStream) { - const content = chunk.choices[0]?.delta?.content || '' - if (content) { - controller.enqueue(new TextEncoder().encode(content)) - } - } - controller.close() - } catch (err) { - controller.error(err) - } - }, - }) -} - export const xAIProvider: ProviderConfig = { id: 'xai', name: 'xAI', @@ -115,27 +95,6 @@ export const xAIProvider: ProviderConfig = { if (request.temperature !== undefined) basePayload.temperature = request.temperature if (request.maxTokens !== undefined) basePayload.max_tokens = request.maxTokens - // Function to create response format configuration - const createResponseFormatPayload = (messages: any[] = allMessages) => { - const payload = { - ...basePayload, - messages, - } - - if (request.responseFormat) { - payload.response_format = { - type: 'json_schema', - json_schema: { - name: request.responseFormat.name || 'structured_response', - schema: request.responseFormat.schema || request.responseFormat, - strict: request.responseFormat.strict !== false, - }, - } - } - - return payload - } - // Handle tools and tool usage control let preparedTools: ReturnType | null = null @@ -154,7 +113,7 @@ export const xAIProvider: ProviderConfig = { // Use response format payload if needed, otherwise use base payload const streamingPayload = request.responseFormat - ? createResponseFormatPayload() + ? 
createResponseFormatPayload(basePayload, allMessages, request.responseFormat) : { ...basePayload, stream: true } if (!request.responseFormat) { @@ -243,7 +202,11 @@ export const xAIProvider: ProviderConfig = { originalToolChoice = toolChoice } else if (request.responseFormat) { // Only add response format if there are no tools - const responseFormatPayload = createResponseFormatPayload() + const responseFormatPayload = createResponseFormatPayload( + basePayload, + allMessages, + request.responseFormat + ) Object.assign(initialPayload, responseFormatPayload) } @@ -260,7 +223,6 @@ export const xAIProvider: ProviderConfig = { const toolResults = [] const currentMessages = [...allMessages] let iterationCount = 0 - const MAX_ITERATIONS = 10 // Track if a forced tool has been used let hasUsedForcedTool = false @@ -280,33 +242,20 @@ export const xAIProvider: ProviderConfig = { }, ] - // Helper function to check for forced tool usage in responses - const checkForForcedToolUsage = ( - response: any, - toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any } - ) => { - if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) { - const toolCallsResponse = response.choices[0].message.tool_calls - const result = trackForcedToolUsage( - toolCallsResponse, - toolChoice, - logger, - 'xai', - forcedTools, - usedForcedTools - ) - hasUsedForcedTool = result.hasUsedForcedTool - usedForcedTools = result.usedForcedTools - } - } - // Check if a forced tool was used in the first response if (originalToolChoice) { - checkForForcedToolUsage(currentResponse, originalToolChoice) + const result = checkForForcedToolUsage( + currentResponse, + originalToolChoice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = result.hasUsedForcedTool + usedForcedTools = result.usedForcedTools } try { - while (iterationCount < MAX_ITERATIONS) { + while (iterationCount < MAX_TOOL_ITERATIONS) { // Check for tool calls const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls if (!toolCallsInResponse || toolCallsInResponse.length === 0) { @@ -432,7 +381,12 @@ export const xAIProvider: ProviderConfig = { } else { // All forced tools have been used, check if we need response format for final response if (request.responseFormat) { - nextPayload = createResponseFormatPayload(currentMessages) + nextPayload = createResponseFormatPayload( + basePayload, + allMessages, + request.responseFormat, + currentMessages + ) } else { nextPayload = { ...basePayload, @@ -446,7 +400,12 @@ export const xAIProvider: ProviderConfig = { // Normal tool processing - check if this might be the final response if (request.responseFormat) { // Use response format for what might be the final response - nextPayload = createResponseFormatPayload(currentMessages) + nextPayload = createResponseFormatPayload( + basePayload, + allMessages, + request.responseFormat, + currentMessages + ) } else { nextPayload = { ...basePayload, @@ -464,7 +423,14 @@ export const xAIProvider: ProviderConfig = { // Check if any forced tools were used in this response if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') { - checkForForcedToolUsage(currentResponse, nextPayload.tool_choice) + const result = checkForForcedToolUsage( + currentResponse, + nextPayload.tool_choice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = result.hasUsedForcedTool + usedForcedTools = result.usedForcedTools } const nextModelEndTime = Date.now() @@ -509,7 +475,12 @@ export const xAIProvider: 
ProviderConfig = { if (request.responseFormat) { // Use response format, no tools finalStreamingPayload = { - ...createResponseFormatPayload(currentMessages), + ...createResponseFormatPayload( + basePayload, + allMessages, + request.responseFormat, + currentMessages + ), stream: true, } } else { diff --git a/apps/sim/providers/xai/utils.ts b/apps/sim/providers/xai/utils.ts new file mode 100644 index 000000000..c5ee067e5 --- /dev/null +++ b/apps/sim/providers/xai/utils.ts @@ -0,0 +1,83 @@ +import { createLogger } from '@/lib/logs/console/logger' +import { trackForcedToolUsage } from '@/providers/utils' + +const logger = createLogger('XAIProvider') + +/** + * Helper to wrap XAI (OpenAI-compatible) streaming into a browser-friendly + * ReadableStream of raw assistant text chunks. + */ +export function createReadableStreamFromXAIStream(xaiStream: any): ReadableStream { + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of xaiStream) { + const content = chunk.choices[0]?.delta?.content || '' + if (content) { + controller.enqueue(new TextEncoder().encode(content)) + } + } + controller.close() + } catch (err) { + controller.error(err) + } + }, + }) +} + +/** + * Creates a response format payload for XAI API requests. + */ +export function createResponseFormatPayload( + basePayload: any, + allMessages: any[], + responseFormat: any, + currentMessages?: any[] +) { + const payload = { + ...basePayload, + messages: currentMessages || allMessages, + } + + if (responseFormat) { + payload.response_format = { + type: 'json_schema', + json_schema: { + name: responseFormat.name || 'structured_response', + schema: responseFormat.schema || responseFormat, + strict: responseFormat.strict !== false, + }, + } + } + + return payload +} + +/** + * Helper function to check for forced tool usage in responses. 
+ */ +export function checkForForcedToolUsage( + response: any, + toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }, + forcedTools: string[], + usedForcedTools: string[] +): { hasUsedForcedTool: boolean; usedForcedTools: string[] } { + let hasUsedForcedTool = false + let updatedUsedForcedTools = usedForcedTools + + if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) { + const toolCallsResponse = response.choices[0].message.tool_calls + const result = trackForcedToolUsage( + toolCallsResponse, + toolChoice, + logger, + 'xai', + forcedTools, + updatedUsedForcedTools + ) + hasUsedForcedTool = result.hasUsedForcedTool + updatedUsedForcedTools = result.usedForcedTools + } + + return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools } +} diff --git a/apps/sim/tools/llm/chat.ts b/apps/sim/tools/llm/chat.ts index 536400734..7af74232d 100644 --- a/apps/sim/tools/llm/chat.ts +++ b/apps/sim/tools/llm/chat.ts @@ -13,6 +13,8 @@ interface LLMChatParams { maxTokens?: number azureEndpoint?: string azureApiVersion?: string + vertexProject?: string + vertexLocation?: string } interface LLMChatResponse extends ToolResponse { @@ -77,6 +79,18 @@ export const llmChatTool: ToolConfig = { visibility: 'hidden', description: 'Azure OpenAI API version', }, + vertexProject: { + type: 'string', + required: false, + visibility: 'hidden', + description: 'Google Cloud project ID for Vertex AI', + }, + vertexLocation: { + type: 'string', + required: false, + visibility: 'hidden', + description: 'Google Cloud location for Vertex AI (defaults to us-central1)', + }, }, request: { @@ -98,6 +112,8 @@ export const llmChatTool: ToolConfig = { maxTokens: params.maxTokens, azureEndpoint: params.azureEndpoint, azureApiVersion: params.azureApiVersion, + vertexProject: params.vertexProject, + vertexLocation: params.vertexLocation, } }, }, diff --git a/bun.lock b/bun.lock index c5863930c..e13beed62 100644 --- a/bun.lock +++ b/bun.lock @@ -1,5 +1,6 @@ { "lockfileVersion": 1, + "configVersion": 0, "workspaces": { "": { "name": "simstudio", @@ -266,12 +267,12 @@ "sharp", ], "overrides": { - "react": "19.2.1", - "react-dom": "19.2.1", - "next": "16.1.0-canary.21", "@next/env": "16.1.0-canary.21", "drizzle-orm": "^0.44.5", + "next": "16.1.0-canary.21", "postgres": "^3.4.5", + "react": "19.2.1", + "react-dom": "19.2.1", }, "packages": { "@adobe/css-tools": ["@adobe/css-tools@4.4.4", "", {}, "sha512-Elp+iwUx5rN5+Y8xLt5/GRoG20WGoDCQ/1Fb+1LiGtvwbDavuSk0jhD/eZdckHAuzcDzccnkv+rEjyWfRx18gg=="],