feat(vertex): added vertex to list of supported providers (#2430)

* feat(vertex): added vertex to list of supported providers

* added utils files for each provider, consolidated gemini utils, added dynamic verbosity and reasoning fetcher
This commit is contained in:
Waleed
2025-12-17 14:57:58 -08:00
committed by GitHub
parent 1ae3b47f5c
commit 7b5405e968
48 changed files with 2767 additions and 935 deletions

View File

@@ -4,7 +4,7 @@
"private": true,
"license": "Apache-2.0",
"scripts": {
"dev": "next dev --port 3001",
"dev": "next dev --port 7322",
"build": "fumadocs-mdx && NODE_OPTIONS='--max-old-space-size=8192' next build",
"start": "next start",
"postinstall": "fumadocs-mdx",

View File

@@ -303,6 +303,14 @@ export async function POST(req: NextRequest) {
apiVersion: 'preview',
endpoint: env.AZURE_OPENAI_ENDPOINT,
}
} else if (providerEnv === 'vertex') {
providerConfig = {
provider: 'vertex',
model: modelToUse,
apiKey: env.COPILOT_API_KEY,
vertexProject: env.VERTEX_PROJECT,
vertexLocation: env.VERTEX_LOCATION,
}
} else {
providerConfig = {
provider: providerEnv,

View File

@@ -66,6 +66,14 @@ export async function POST(req: NextRequest) {
apiVersion: env.AZURE_OPENAI_API_VERSION,
endpoint: env.AZURE_OPENAI_ENDPOINT,
}
} else if (providerEnv === 'vertex') {
providerConfig = {
provider: 'vertex',
model: modelToUse,
apiKey: env.COPILOT_API_KEY,
vertexProject: env.VERTEX_PROJECT,
vertexLocation: env.VERTEX_LOCATION,
}
} else {
providerConfig = {
provider: providerEnv,

View File

@@ -35,6 +35,8 @@ export async function POST(request: NextRequest) {
apiKey,
azureEndpoint,
azureApiVersion,
vertexProject,
vertexLocation,
responseFormat,
workflowId,
workspaceId,
@@ -58,6 +60,8 @@ export async function POST(request: NextRequest) {
hasApiKey: !!apiKey,
hasAzureEndpoint: !!azureEndpoint,
hasAzureApiVersion: !!azureApiVersion,
hasVertexProject: !!vertexProject,
hasVertexLocation: !!vertexLocation,
hasResponseFormat: !!responseFormat,
workflowId,
stream: !!stream,
@@ -104,6 +108,8 @@ export async function POST(request: NextRequest) {
apiKey: finalApiKey,
azureEndpoint,
azureApiVersion,
vertexProject,
vertexLocation,
responseFormat,
workflowId,
workspaceId,

View File

@@ -8,6 +8,8 @@ import {
getHostedModels,
getMaxTemperature,
getProviderIcon,
getReasoningEffortValuesForModel,
getVerbosityValuesForModel,
MODELS_WITH_REASONING_EFFORT,
MODELS_WITH_VERBOSITY,
providers,
@@ -114,12 +116,47 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
type: 'dropdown',
placeholder: 'Select reasoning effort...',
options: [
{ label: 'none', id: 'none' },
{ label: 'minimal', id: 'minimal' },
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
],
dependsOn: ['model'],
fetchOptions: async (blockId: string) => {
const { useSubBlockStore } = await import('@/stores/workflows/subblock/store')
const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store')
const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId
if (!activeWorkflowId) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}
const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId]
const blockValues = workflowValues?.[blockId]
const modelValue = blockValues?.model as string
if (!modelValue) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}
const validOptions = getReasoningEffortValuesForModel(modelValue)
if (!validOptions) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}
return validOptions.map((opt) => ({ label: opt, id: opt }))
},
value: () => 'medium',
condition: {
field: 'model',
@@ -136,6 +173,43 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
],
dependsOn: ['model'],
fetchOptions: async (blockId: string) => {
const { useSubBlockStore } = await import('@/stores/workflows/subblock/store')
const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store')
const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId
if (!activeWorkflowId) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}
const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId]
const blockValues = workflowValues?.[blockId]
const modelValue = blockValues?.model as string
if (!modelValue) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}
const validOptions = getVerbosityValuesForModel(modelValue)
if (!validOptions) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}
return validOptions.map((opt) => ({ label: opt, id: opt }))
},
value: () => 'medium',
condition: {
field: 'model',
@@ -166,6 +240,28 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
value: providers['azure-openai'].models,
},
},
{
id: 'vertexProject',
title: 'Vertex AI Project',
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'vertexLocation',
title: 'Vertex AI Location',
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'tools',
title: 'Tools',
@@ -465,6 +561,8 @@ Example 3 (Array Input):
apiKey: { type: 'string', description: 'Provider API key' },
azureEndpoint: { type: 'string', description: 'Azure OpenAI endpoint URL' },
azureApiVersion: { type: 'string', description: 'Azure API version' },
vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' },
vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' },
responseFormat: {
type: 'json',
description: 'JSON response format schema',

View File

@@ -239,6 +239,28 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
value: providers['azure-openai'].models,
},
},
{
id: 'vertexProject',
title: 'Vertex AI Project',
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'vertexLocation',
title: 'Vertex AI Location',
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'temperature',
title: 'Temperature',
@@ -356,6 +378,14 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
apiKey: { type: 'string' as ParamType, description: 'Provider API key' },
azureEndpoint: { type: 'string' as ParamType, description: 'Azure OpenAI endpoint URL' },
azureApiVersion: { type: 'string' as ParamType, description: 'Azure API version' },
vertexProject: {
type: 'string' as ParamType,
description: 'Google Cloud project ID for Vertex AI',
},
vertexLocation: {
type: 'string' as ParamType,
description: 'Google Cloud location for Vertex AI',
},
temperature: {
type: 'number' as ParamType,
description: 'Response randomness level (low for consistent evaluation)',

View File

@@ -188,6 +188,28 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
value: providers['azure-openai'].models,
},
},
{
id: 'vertexProject',
title: 'Vertex AI Project',
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'vertexLocation',
title: 'Vertex AI Location',
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'temperature',
title: 'Temperature',
@@ -235,6 +257,8 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
apiKey: { type: 'string', description: 'Provider API key' },
azureEndpoint: { type: 'string', description: 'Azure OpenAI endpoint URL' },
azureApiVersion: { type: 'string', description: 'Azure API version' },
vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' },
vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' },
temperature: {
type: 'number',
description: 'Response randomness level (low for consistent routing)',

View File

@@ -99,6 +99,28 @@ export const TranslateBlock: BlockConfig = {
value: providers['azure-openai'].models,
},
},
{
id: 'vertexProject',
title: 'Vertex AI Project',
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'vertexLocation',
title: 'Vertex AI Location',
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'systemPrompt',
title: 'System Prompt',
@@ -120,6 +142,8 @@ export const TranslateBlock: BlockConfig = {
apiKey: params.apiKey,
azureEndpoint: params.azureEndpoint,
azureApiVersion: params.azureApiVersion,
vertexProject: params.vertexProject,
vertexLocation: params.vertexLocation,
}),
},
},
@@ -129,6 +153,8 @@ export const TranslateBlock: BlockConfig = {
apiKey: { type: 'string', description: 'Provider API key' },
azureEndpoint: { type: 'string', description: 'Azure OpenAI endpoint URL' },
azureApiVersion: { type: 'string', description: 'Azure API version' },
vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' },
vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' },
systemPrompt: { type: 'string', description: 'Translation instructions' },
},
outputs: {

View File

@@ -2452,6 +2452,56 @@ export const GeminiIcon = (props: SVGProps<SVGSVGElement>) => (
</svg>
)
export const VertexIcon = (props: SVGProps<SVGSVGElement>) => (
<svg
{...props}
id='standard_product_icon'
xmlns='http://www.w3.org/2000/svg'
version='1.1'
viewBox='0 0 512 512'
>
<g id='bounding_box'>
<rect width='512' height='512' fill='none' />
</g>
<g id='art'>
<path
d='M128,244.99c-8.84,0-16-7.16-16-16v-95.97c0-8.84,7.16-16,16-16s16,7.16,16,16v95.97c0,8.84-7.16,16-16,16Z'
fill='#ea4335'
/>
<path
d='M256,458c-2.98,0-5.97-.83-8.59-2.5l-186-122c-7.46-4.74-9.65-14.63-4.91-22.09,4.75-7.46,14.64-9.65,22.09-4.91l177.41,116.53,177.41-116.53c7.45-4.74,17.34-2.55,22.09,4.91,4.74,7.46,2.55,17.34-4.91,22.09l-186,122c-2.62,1.67-5.61,2.5-8.59,2.5Z'
fill='#fbbc04'
/>
<path
d='M256,388.03c-8.84,0-16-7.16-16-16v-73.06c0-8.84,7.16-16,16-16s16,7.16,16,16v73.06c0,8.84-7.16,16-16,16Z'
fill='#34a853'
/>
<circle cx='128' cy='70' r='16' fill='#ea4335' />
<circle cx='128' cy='292' r='16' fill='#ea4335' />
<path
d='M384.23,308.01c-8.82,0-15.98-7.14-16-15.97l-.23-94.01c-.02-8.84,7.13-16.02,15.97-16.03h.04c8.82,0,15.98,7.14,16,15.97l.23,94.01c.02,8.84-7.13,16.02-15.97,16.03h-.04Z'
fill='#4285f4'
/>
<circle cx='384' cy='70' r='16' fill='#4285f4' />
<circle cx='384' cy='134' r='16' fill='#4285f4' />
<path
d='M320,220.36c-8.84,0-16-7.16-16-16v-103.02c0-8.84,7.16-16,16-16s16,7.16,16,16v103.02c0,8.84-7.16,16-16,16Z'
fill='#fbbc04'
/>
<circle cx='256' cy='171' r='16' fill='#34a853' />
<circle cx='256' cy='235' r='16' fill='#34a853' />
<circle cx='320' cy='265' r='16' fill='#fbbc04' />
<circle cx='320' cy='329' r='16' fill='#fbbc04' />
<path
d='M192,217.36c-8.84,0-16-7.16-16-16v-100.02c0-8.84,7.16-16,16-16s16,7.16,16,16v100.02c0,8.84-7.16,16-16,16Z'
fill='#fbbc04'
/>
<circle cx='192' cy='265' r='16' fill='#fbbc04' />
<circle cx='192' cy='329' r='16' fill='#fbbc04' />
</g>
</svg>
)
export const CerebrasIcon = (props: SVGProps<SVGSVGElement>) => (
<svg
{...props}

View File

@@ -493,7 +493,7 @@ export class AgentBlockHandler implements BlockHandler {
const discoveredTools = await this.discoverMcpToolsForServer(ctx, serverId)
return { serverId, tools, discoveredTools, error: null as Error | null }
} catch (error) {
logger.error(`Failed to discover tools from server ${serverId}:`, error)
logger.error(`Failed to discover tools from server ${serverId}:`)
return { serverId, tools, discoveredTools: [] as any[], error: error as Error }
}
})
@@ -883,6 +883,8 @@ export class AgentBlockHandler implements BlockHandler {
apiKey: inputs.apiKey,
azureEndpoint: inputs.azureEndpoint,
azureApiVersion: inputs.azureApiVersion,
vertexProject: inputs.vertexProject,
vertexLocation: inputs.vertexLocation,
responseFormat,
workflowId: ctx.workflowId,
workspaceId: ctx.workspaceId,
@@ -975,6 +977,8 @@ export class AgentBlockHandler implements BlockHandler {
apiKey: finalApiKey,
azureEndpoint: providerRequest.azureEndpoint,
azureApiVersion: providerRequest.azureApiVersion,
vertexProject: providerRequest.vertexProject,
vertexLocation: providerRequest.vertexLocation,
responseFormat: providerRequest.responseFormat,
workflowId: providerRequest.workflowId,
workspaceId: providerRequest.workspaceId,

View File

@@ -19,6 +19,8 @@ export interface AgentInputs {
apiKey?: string
azureEndpoint?: string
azureApiVersion?: string
vertexProject?: string
vertexLocation?: string
reasoningEffort?: string
verbosity?: string
}

View File

@@ -148,7 +148,14 @@ export type CopilotProviderConfig =
endpoint?: string
}
| {
provider: Exclude<ProviderId, 'azure-openai'>
provider: 'vertex'
model: string
apiKey?: string
vertexProject?: string
vertexLocation?: string
}
| {
provider: Exclude<ProviderId, 'azure-openai' | 'vertex'>
model?: string
apiKey?: string
}

View File

@@ -98,6 +98,10 @@ export const env = createEnv({
OCR_AZURE_MODEL_NAME: z.string().optional(), // Azure Mistral OCR model name for document processing
OCR_AZURE_API_KEY: z.string().min(1).optional(), // Azure Mistral OCR API key
// Vertex AI Configuration
VERTEX_PROJECT: z.string().optional(), // Google Cloud project ID for Vertex AI
VERTEX_LOCATION: z.string().optional(), // Google Cloud location/region for Vertex AI (defaults to us-central1)
// Monitoring & Analytics
TELEMETRY_ENDPOINT: z.string().url().optional(), // Custom telemetry/analytics endpoint
COST_MULTIPLIER: z.number().optional(), // Multiplier for cost calculations

View File

@@ -404,15 +404,11 @@ class McpService {
failedCount++
const errorMessage =
result.reason instanceof Error ? result.reason.message : 'Unknown error'
logger.warn(
`[${requestId}] Failed to discover tools from server ${server.name}:`,
result.reason
)
logger.warn(`[${requestId}] Failed to discover tools from server ${server.name}:`)
statusUpdates.push(this.updateServerStatus(server.id!, workspaceId, false, errorMessage))
}
})
// Update server statuses in parallel (don't block on this)
Promise.allSettled(statusUpdates).catch((err) => {
logger.error(`[${requestId}] Error updating server statuses:`, err)
})

View File

@@ -8,7 +8,7 @@
"node": ">=20.0.0"
},
"scripts": {
"dev": "next dev --port 3000",
"dev": "next dev --port 7321",
"dev:webpack": "next dev --webpack",
"dev:sockets": "bun run socket-server/index.ts",
"dev:full": "concurrently -n \"App,Realtime\" -c \"cyan,magenta\" \"bun run dev\" \"bun run dev:sockets\"",

View File

@@ -1,35 +1,24 @@
import Anthropic from '@anthropic-ai/sdk'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import {
checkForForcedToolUsage,
createReadableStreamFromAnthropicStream,
generateToolUseId,
} from '@/providers/anthropic/utils'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
ProviderRequest,
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
import { executeTool } from '@/tools'
import { getProviderDefaultModel, getProviderModels } from '../models'
import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types'
import { prepareToolExecution, prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils'
const logger = createLogger('AnthropicProvider')
/**
* Helper to wrap Anthropic streaming into a browser-friendly ReadableStream
*/
function createReadableStreamFromAnthropicStream(
anthropicStream: AsyncIterable<any>
): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const event of anthropicStream) {
if (event.type === 'content_block_delta' && event.delta?.text) {
controller.enqueue(new TextEncoder().encode(event.delta.text))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
}
export const anthropicProvider: ProviderConfig = {
id: 'anthropic',
name: 'Anthropic',
@@ -47,11 +36,6 @@ export const anthropicProvider: ProviderConfig = {
const anthropic = new Anthropic({ apiKey: request.apiKey })
// Helper function to generate a simple unique ID for tool uses
const generateToolUseId = (toolName: string) => {
return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}`
}
// Transform messages to Anthropic format
const messages: any[] = []
@@ -373,7 +357,6 @@ ${fieldDescriptions}
const toolResults = []
const currentMessages = [...messages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
// Track if a forced tool has been used
let hasUsedForcedTool = false
@@ -393,47 +376,20 @@ ${fieldDescriptions}
},
]
// Helper function to check for forced tool usage in Anthropic responses
const checkForForcedToolUsage = (response: any, toolChoice: any) => {
if (
typeof toolChoice === 'object' &&
toolChoice !== null &&
Array.isArray(response.content)
) {
const toolUses = response.content.filter((item: any) => item.type === 'tool_use')
if (toolUses.length > 0) {
// Convert Anthropic tool_use format to a format trackForcedToolUsage can understand
const adaptedToolCalls = toolUses.map((tool: any) => ({
name: tool.name,
}))
// Convert Anthropic tool_choice format to match OpenAI format for tracking
const adaptedToolChoice =
toolChoice.type === 'tool' ? { function: { name: toolChoice.name } } : toolChoice
const result = trackForcedToolUsage(
adaptedToolCalls,
adaptedToolChoice,
logger,
'anthropic',
forcedTools,
usedForcedTools
)
// Make the behavior consistent with the initial check
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
return result
}
}
return null
// Check if a forced tool was used in the first response
const firstCheckResult = checkForForcedToolUsage(
currentResponse,
originalToolChoice,
forcedTools,
usedForcedTools
)
if (firstCheckResult) {
hasUsedForcedTool = firstCheckResult.hasUsedForcedTool
usedForcedTools = firstCheckResult.usedForcedTools
}
// Check if a forced tool was used in the first response
checkForForcedToolUsage(currentResponse, originalToolChoice)
try {
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolUses = currentResponse.content.filter((item) => item.type === 'tool_use')
if (!toolUses || toolUses.length === 0) {
@@ -576,7 +532,16 @@ ${fieldDescriptions}
currentResponse = await anthropic.messages.create(nextPayload)
// Check if any forced tools were used in this response
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
const nextCheckResult = checkForForcedToolUsage(
currentResponse,
nextPayload.tool_choice,
forcedTools,
usedForcedTools
)
if (nextCheckResult) {
hasUsedForcedTool = nextCheckResult.hasUsedForcedTool
usedForcedTools = nextCheckResult.usedForcedTools
}
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
@@ -727,7 +692,6 @@ ${fieldDescriptions}
const toolResults = []
const currentMessages = [...messages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
// Track if a forced tool has been used
let hasUsedForcedTool = false
@@ -747,47 +711,20 @@ ${fieldDescriptions}
},
]
// Helper function to check for forced tool usage in Anthropic responses
const checkForForcedToolUsage = (response: any, toolChoice: any) => {
if (
typeof toolChoice === 'object' &&
toolChoice !== null &&
Array.isArray(response.content)
) {
const toolUses = response.content.filter((item: any) => item.type === 'tool_use')
if (toolUses.length > 0) {
// Convert Anthropic tool_use format to a format trackForcedToolUsage can understand
const adaptedToolCalls = toolUses.map((tool: any) => ({
name: tool.name,
}))
// Convert Anthropic tool_choice format to match OpenAI format for tracking
const adaptedToolChoice =
toolChoice.type === 'tool' ? { function: { name: toolChoice.name } } : toolChoice
const result = trackForcedToolUsage(
adaptedToolCalls,
adaptedToolChoice,
logger,
'anthropic',
forcedTools,
usedForcedTools
)
// Make the behavior consistent with the initial check
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
return result
}
}
return null
// Check if a forced tool was used in the first response
const firstCheckResult = checkForForcedToolUsage(
currentResponse,
originalToolChoice,
forcedTools,
usedForcedTools
)
if (firstCheckResult) {
hasUsedForcedTool = firstCheckResult.hasUsedForcedTool
usedForcedTools = firstCheckResult.usedForcedTools
}
// Check if a forced tool was used in the first response
checkForForcedToolUsage(currentResponse, originalToolChoice)
try {
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolUses = currentResponse.content.filter((item) => item.type === 'tool_use')
if (!toolUses || toolUses.length === 0) {
@@ -926,7 +863,16 @@ ${fieldDescriptions}
currentResponse = await anthropic.messages.create(nextPayload)
// Check if any forced tools were used in this response
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
const nextCheckResult = checkForForcedToolUsage(
currentResponse,
nextPayload.tool_choice,
forcedTools,
usedForcedTools
)
if (nextCheckResult) {
hasUsedForcedTool = nextCheckResult.hasUsedForcedTool
usedForcedTools = nextCheckResult.usedForcedTools
}
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime

View File

@@ -0,0 +1,70 @@
import { createLogger } from '@/lib/logs/console/logger'
import { trackForcedToolUsage } from '@/providers/utils'
const logger = createLogger('AnthropicUtils')
/**
* Helper to wrap Anthropic streaming into a browser-friendly ReadableStream
*/
export function createReadableStreamFromAnthropicStream(
anthropicStream: AsyncIterable<any>
): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const event of anthropicStream) {
if (event.type === 'content_block_delta' && event.delta?.text) {
controller.enqueue(new TextEncoder().encode(event.delta.text))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
}
/**
* Helper function to generate a simple unique ID for tool uses
*/
export function generateToolUseId(toolName: string): string {
return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}`
}
/**
* Helper function to check for forced tool usage in Anthropic responses
*/
export function checkForForcedToolUsage(
response: any,
toolChoice: any,
forcedTools: string[],
usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } | null {
if (typeof toolChoice === 'object' && toolChoice !== null && Array.isArray(response.content)) {
const toolUses = response.content.filter((item: any) => item.type === 'tool_use')
if (toolUses.length > 0) {
// Convert Anthropic tool_use format to a format trackForcedToolUsage can understand
const adaptedToolCalls = toolUses.map((tool: any) => ({
name: tool.name,
}))
// Convert Anthropic tool_choice format to match OpenAI format for tracking
const adaptedToolChoice =
toolChoice.type === 'tool' ? { function: { name: toolChoice.name } } : toolChoice
const result = trackForcedToolUsage(
adaptedToolCalls,
adaptedToolChoice,
logger,
'anthropic',
forcedTools,
usedForcedTools
)
return result
}
}
return null
}

View File

@@ -2,6 +2,11 @@ import { AzureOpenAI } from 'openai'
import { env } from '@/lib/core/config/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import {
checkForForcedToolUsage,
createReadableStreamFromAzureOpenAIStream,
} from '@/providers/azure-openai/utils'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
@@ -9,55 +14,11 @@ import type {
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import {
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
} from '@/providers/utils'
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
import { executeTool } from '@/tools'
const logger = createLogger('AzureOpenAIProvider')
/**
* Helper function to convert an Azure OpenAI stream to a standard ReadableStream
* and collect completion metrics
*/
function createReadableStreamFromAzureOpenAIStream(
azureOpenAIStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of azureOpenAIStream) {
// Check for usage data in the final chunk
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
// Once stream is complete, call the completion callback with the final content and usage
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
/**
* Azure OpenAI provider configuration
*/
@@ -303,26 +264,6 @@ export const azureOpenAIProvider: ProviderConfig = {
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
// Helper function to check for forced tool usage in responses
const checkForForcedToolUsage = (
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }
) => {
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'azure-openai',
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
}
}
let currentResponse = await azureOpenAI.chat.completions.create(payload)
const firstResponseTime = Date.now() - initialCallTime
@@ -337,7 +278,6 @@ export const azureOpenAIProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
// Track time spent in model vs tools
let modelTime = firstResponseTime
@@ -358,9 +298,17 @@ export const azureOpenAIProvider: ProviderConfig = {
]
// Check if a forced tool was used in the first response
checkForForcedToolUsage(currentResponse, originalToolChoice)
const firstCheckResult = checkForForcedToolUsage(
currentResponse,
originalToolChoice,
logger,
forcedTools,
usedForcedTools
)
hasUsedForcedTool = firstCheckResult.hasUsedForcedTool
usedForcedTools = firstCheckResult.usedForcedTools
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
@@ -368,7 +316,7 @@ export const azureOpenAIProvider: ProviderConfig = {
}
logger.info(
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
// Track time for tool calls in this batch
@@ -491,7 +439,15 @@ export const azureOpenAIProvider: ProviderConfig = {
currentResponse = await azureOpenAI.chat.completions.create(nextPayload)
// Check if any forced tools were used in this response
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
const nextCheckResult = checkForForcedToolUsage(
currentResponse,
nextPayload.tool_choice,
logger,
forcedTools,
usedForcedTools
)
hasUsedForcedTool = nextCheckResult.hasUsedForcedTool
usedForcedTools = nextCheckResult.usedForcedTools
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime

View File

@@ -0,0 +1,70 @@
import type { Logger } from '@/lib/logs/console/logger'
import { trackForcedToolUsage } from '@/providers/utils'
/**
* Helper function to convert an Azure OpenAI stream to a standard ReadableStream
* and collect completion metrics
*/
export function createReadableStreamFromAzureOpenAIStream(
azureOpenAIStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of azureOpenAIStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
/**
* Helper function to check for forced tool usage in responses
*/
export function checkForForcedToolUsage(
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
logger: Logger,
forcedTools: string[],
usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
let hasUsedForcedTool = false
let updatedUsedForcedTools = [...usedForcedTools]
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'azure-openai',
forcedTools,
updatedUsedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
updatedUsedForcedTools = result.usedForcedTools
}
return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
}

View File

@@ -1,6 +1,9 @@
import { Cerebras } from '@cerebras/cerebras_cloud_sdk'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import type { CerebrasResponse } from '@/providers/cerebras/types'
import { createReadableStreamFromCerebrasStream } from '@/providers/cerebras/utils'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
@@ -14,35 +17,9 @@ import {
trackForcedToolUsage,
} from '@/providers/utils'
import { executeTool } from '@/tools'
import type { CerebrasResponse } from './types'
const logger = createLogger('CerebrasProvider')
/**
* Helper to convert a Cerebras streaming response (async iterable) into a ReadableStream.
* Enqueues only the model's text delta chunks as UTF-8 encoded bytes.
*/
function createReadableStreamFromCerebrasStream(
cerebrasStream: AsyncIterable<any>
): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of cerebrasStream) {
// Expecting delta content similar to OpenAI: chunk.choices[0]?.delta?.content
const content = chunk.choices?.[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
export const cerebrasProvider: ProviderConfig = {
id: 'cerebras',
name: 'Cerebras',
@@ -223,7 +200,6 @@ export const cerebrasProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
// Track time spent in model vs tools
let modelTime = firstResponseTime
@@ -246,7 +222,7 @@ export const cerebrasProvider: ProviderConfig = {
const toolCallSignatures = new Set()
try {
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls

View File

@@ -0,0 +1,23 @@
/**
* Helper to convert a Cerebras streaming response (async iterable) into a ReadableStream.
* Enqueues only the model's text delta chunks as UTF-8 encoded bytes.
*/
export function createReadableStreamFromCerebrasStream(
cerebrasStream: AsyncIterable<any>
): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of cerebrasStream) {
const content = chunk.choices?.[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}

View File

@@ -1,6 +1,8 @@
import OpenAI from 'openai'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { createReadableStreamFromDeepseekStream } from '@/providers/deepseek/utils'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
@@ -17,28 +19,6 @@ import { executeTool } from '@/tools'
const logger = createLogger('DeepseekProvider')
/**
* Helper function to convert a DeepSeek (OpenAI-compatible) stream to a ReadableStream
* of text chunks that can be consumed by the browser.
*/
function createReadableStreamFromDeepseekStream(deepseekStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of deepseekStream) {
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
export const deepseekProvider: ProviderConfig = {
id: 'deepseek',
name: 'Deepseek',
@@ -231,7 +211,6 @@ export const deepseekProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
// Track if a forced tool has been used
let hasUsedForcedTool = false
@@ -270,7 +249,7 @@ export const deepseekProvider: ProviderConfig = {
}
try {
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {

View File

@@ -0,0 +1,21 @@
/**
* Helper function to convert a DeepSeek (OpenAI-compatible) stream to a ReadableStream
* of text chunks that can be consumed by the browser.
*/
export function createReadableStreamFromDeepseekStream(deepseekStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of deepseekStream) {
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}

View File

@@ -1,5 +1,12 @@
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import {
cleanSchemaForGemini,
convertToGeminiFormat,
extractFunctionCall,
extractTextContent,
} from '@/providers/google/utils'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
@@ -19,7 +26,13 @@ const logger = createLogger('GoogleProvider')
/**
* Creates a ReadableStream from Google's Gemini stream response
*/
function createReadableStreamFromGeminiStream(response: Response): ReadableStream<Uint8Array> {
function createReadableStreamFromGeminiStream(
response: Response,
onComplete?: (
content: string,
usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
) => void
): ReadableStream<Uint8Array> {
const reader = response.body?.getReader()
if (!reader) {
throw new Error('Failed to get reader from response body')
@@ -29,18 +42,24 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
async start(controller) {
try {
let buffer = ''
let fullContent = ''
let usageData: {
promptTokenCount?: number
candidatesTokenCount?: number
totalTokenCount?: number
} | null = null
while (true) {
const { done, value } = await reader.read()
if (done) {
// Try to parse any remaining buffer as complete JSON
if (buffer.trim()) {
// Processing final buffer
try {
const data = JSON.parse(buffer.trim())
if (data.usageMetadata) {
usageData = data.usageMetadata
}
const candidate = data.candidates?.[0]
if (candidate?.content?.parts) {
// Check if this is a function call
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
@@ -49,26 +68,27 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
functionName: functionCall.name,
}
)
// Function calls should not be streamed - end the stream early
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
} catch (e) {
// Final buffer not valid JSON, checking if it contains JSON array
// Try parsing as JSON array if it starts with [
if (buffer.trim().startsWith('[')) {
try {
const dataArray = JSON.parse(buffer.trim())
if (Array.isArray(dataArray)) {
for (const item of dataArray) {
if (item.usageMetadata) {
usageData = item.usageMetadata
}
const candidate = item.candidates?.[0]
if (candidate?.content?.parts) {
// Check if this is a function call
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
@@ -77,11 +97,13 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
@@ -93,6 +115,7 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
}
}
}
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
break
}
@@ -100,14 +123,11 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
const text = new TextDecoder().decode(value)
buffer += text
// Try to find complete JSON objects in buffer
// Look for patterns like: {...}\n{...} or just a single {...}
let searchIndex = 0
while (searchIndex < buffer.length) {
const openBrace = buffer.indexOf('{', searchIndex)
if (openBrace === -1) break
// Try to find the matching closing brace
let braceCount = 0
let inString = false
let escaped = false
@@ -138,28 +158,34 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
}
if (closeBrace !== -1) {
// Found a complete JSON object
const jsonStr = buffer.substring(openBrace, closeBrace + 1)
try {
const data = JSON.parse(jsonStr)
// JSON parsed successfully from stream
if (data.usageMetadata) {
usageData = data.usageMetadata
}
const candidate = data.candidates?.[0]
// Handle specific finish reasons
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn('Gemini returned UNEXPECTED_TOOL_CALL in streaming mode', {
finishReason: candidate.finishReason,
hasContent: !!candidate?.content,
hasParts: !!candidate?.content?.parts,
})
// This indicates a configuration issue - tools might be improperly configured for streaming
continue
const textContent = extractTextContent(candidate)
if (textContent) {
fullContent += textContent
controller.enqueue(new TextEncoder().encode(textContent))
}
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
if (candidate?.content?.parts) {
// Check if this is a function call
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
@@ -168,13 +194,13 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
functionName: functionCall.name,
}
)
// Function calls should not be streamed - we need to end the stream
// and let the non-streaming tool execution flow handle this
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
@@ -185,7 +211,6 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
})
}
// Remove processed JSON from buffer and continue searching
buffer = buffer.substring(closeBrace + 1)
searchIndex = 0
} else {
@@ -232,45 +257,36 @@ export const googleProvider: ProviderConfig = {
streaming: !!request.stream,
})
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
// Convert messages to Gemini format
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
const requestedModel = request.model || 'gemini-2.5-pro'
// Build request payload
const payload: any = {
contents,
generationConfig: {},
}
// Add temperature if specified
if (request.temperature !== undefined && request.temperature !== null) {
payload.generationConfig.temperature = request.temperature
}
// Add max tokens if specified
if (request.maxTokens !== undefined) {
payload.generationConfig.maxOutputTokens = request.maxTokens
}
// Add system instruction if provided
if (systemInstruction) {
payload.systemInstruction = systemInstruction
}
// Add structured output format if requested (but not when tools are present)
if (request.responseFormat && !tools?.length) {
const responseFormatSchema = request.responseFormat.schema || request.responseFormat
// Clean the schema using our helper function
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
// Use Gemini's native structured output approach
payload.generationConfig.responseMimeType = 'application/json'
payload.generationConfig.responseSchema = cleanSchema
@@ -284,7 +300,6 @@ export const googleProvider: ProviderConfig = {
)
}
// Handle tools and tool usage control
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
if (tools?.length) {
@@ -298,7 +313,6 @@ export const googleProvider: ProviderConfig = {
},
]
// Add Google-specific tool configuration
if (toolConfig) {
payload.toolConfig = toolConfig
}
@@ -313,14 +327,10 @@ export const googleProvider: ProviderConfig = {
}
}
// Make the API request
const initialCallTime = Date.now()
// Disable streaming for initial requests when tools are present to avoid function calls in streams
// Only enable streaming for the final response after tool execution
const shouldStream = request.stream && !tools?.length
// Use streamGenerateContent for streaming requests
const endpoint = shouldStream
? `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:streamGenerateContent?key=${request.apiKey}`
: `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:generateContent?key=${request.apiKey}`
@@ -352,16 +362,11 @@ export const googleProvider: ProviderConfig = {
const firstResponseTime = Date.now() - initialCallTime
// Handle streaming response
if (shouldStream) {
logger.info('Handling Google Gemini streaming response')
// Create a ReadableStream from the Google Gemini stream
const stream = createReadableStreamFromGeminiStream(response)
// Create an object that combines the stream with execution metadata
const streamingExecution: StreamingExecution = {
stream,
const streamingResult: StreamingExecution = {
stream: null as any,
execution: {
success: true,
output: {
@@ -389,7 +394,6 @@ export const googleProvider: ProviderConfig = {
duration: firstResponseTime,
},
],
// Cost will be calculated in logger
},
},
logs: [],
@@ -402,18 +406,49 @@ export const googleProvider: ProviderConfig = {
},
}
return streamingExecution
streamingResult.stream = createReadableStreamFromGeminiStream(
response,
(content, usage) => {
streamingResult.execution.output.content = content
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
streamEndTime
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
streamEndTime - providerStartTime
}
}
if (usage) {
streamingResult.execution.output.tokens = {
prompt: usage.promptTokenCount || 0,
completion: usage.candidatesTokenCount || 0,
total:
usage.totalTokenCount ||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0),
}
}
}
)
return streamingResult
}
let geminiResponse = await response.json()
// Check structured output format
if (payload.generationConfig?.responseSchema) {
const candidate = geminiResponse.candidates?.[0]
if (candidate?.content?.parts?.[0]?.text) {
const text = candidate.content.parts[0].text
try {
// Validate JSON structure
JSON.parse(text)
logger.info('Successfully received structured JSON output')
} catch (_e) {
@@ -422,7 +457,6 @@ export const googleProvider: ProviderConfig = {
}
}
// Initialize response tracking variables
let content = ''
let tokens = {
prompt: 0,
@@ -432,16 +466,13 @@ export const googleProvider: ProviderConfig = {
const toolCalls = []
const toolResults = []
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
// Track forced tools and their usage (similar to OpenAI pattern)
const originalToolConfig = preparedTools?.toolConfig
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
let hasUsedForcedTool = false
let currentToolConfig = originalToolConfig
// Helper function to check for forced tool usage in responses
const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => {
if (currentToolConfig && forcedTools.length > 0) {
const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }]
@@ -466,11 +497,9 @@ export const googleProvider: ProviderConfig = {
}
}
// Track time spent in model vs tools
let modelTime = firstResponseTime
let toolsTime = 0
// Track each model and tool call segment with timestamps
const timeSegments: TimeSegment[] = [
{
type: 'model',
@@ -482,46 +511,50 @@ export const googleProvider: ProviderConfig = {
]
try {
// Extract content or function calls from initial response
const candidate = geminiResponse.candidates?.[0]
// Check if response contains function calls
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn(
'Gemini returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
{
finishReason: candidate.finishReason,
hasContent: !!candidate?.content,
hasParts: !!candidate?.content?.parts,
}
)
content = extractTextContent(candidate)
}
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.info(`Received function call from Gemini: ${functionCall.name}`)
// Process function calls in a loop
while (iterationCount < MAX_ITERATIONS) {
// Get the latest function calls
while (iterationCount < MAX_TOOL_ITERATIONS) {
const latestResponse = geminiResponse.candidates?.[0]
const latestFunctionCall = extractFunctionCall(latestResponse)
if (!latestFunctionCall) {
// No more function calls - extract final text content
content = extractTextContent(latestResponse)
break
}
logger.info(
`Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
`Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
// Track time for tool calls
const toolsStartTime = Date.now()
try {
const toolName = latestFunctionCall.name
const toolArgs = latestFunctionCall.args || {}
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) {
logger.warn(`Tool ${toolName} not found in registry, skipping`)
break
}
// Execute the tool
const toolCallStartTime = Date.now()
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
@@ -529,7 +562,6 @@ export const googleProvider: ProviderConfig = {
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
// Add to time segments for both success and failure
timeSegments.push({
type: 'tool',
name: toolName,
@@ -538,13 +570,11 @@ export const googleProvider: ProviderConfig = {
duration: toolCallDuration,
})
// Prepare result content for the LLM
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
// Include error information so LLM can respond appropriately
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
@@ -562,14 +592,10 @@ export const googleProvider: ProviderConfig = {
success: result.success,
})
// Prepare for next request with simplified messages
// Use simple format: original query + most recent function call + result
const simplifiedMessages = [
// Original user request - find the first user request
...(contents.filter((m) => m.role === 'user').length > 0
? [contents.filter((m) => m.role === 'user')[0]]
: [contents[0]]),
// Function call from model
{
role: 'model',
parts: [
@@ -581,7 +607,6 @@ export const googleProvider: ProviderConfig = {
},
],
},
// Function response - but use USER role since Gemini only accepts user or model
{
role: 'user',
parts: [
@@ -592,35 +617,27 @@ export const googleProvider: ProviderConfig = {
},
]
// Calculate tool call time
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
// Check for forced tool usage and update configuration
checkForForcedToolUsage(latestFunctionCall)
// Make the next request with updated messages
const nextModelStartTime = Date.now()
try {
// Check if we should stream the final response after tool calls
if (request.stream) {
// Create a payload for the streaming response after tool calls
const streamingPayload = {
...payload,
contents: simplifiedMessages,
}
// Check if we should remove tools and enable structured output for final response
const allForcedToolsUsed =
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
if (allForcedToolsUsed && request.responseFormat) {
// All forced tools have been used, we can now remove tools and enable structured output
streamingPayload.tools = undefined
streamingPayload.toolConfig = undefined
// Add structured output format for final response
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
@@ -633,7 +650,6 @@ export const googleProvider: ProviderConfig = {
logger.info('Using structured output for final response after tool execution')
} else {
// Use updated tool configuration if available, otherwise default to AUTO
if (currentToolConfig) {
streamingPayload.toolConfig = currentToolConfig
} else {
@@ -641,11 +657,8 @@ export const googleProvider: ProviderConfig = {
}
}
// Check if we should handle this as a potential forced tool call
// First make a non-streaming request to see if we get a function call
const checkPayload = {
...streamingPayload,
// Remove stream property to get non-streaming response
}
checkPayload.stream = undefined
@@ -677,7 +690,6 @@ export const googleProvider: ProviderConfig = {
const checkFunctionCall = extractFunctionCall(checkCandidate)
if (checkFunctionCall) {
// We have a function call - handle it in non-streaming mode
logger.info(
'Function call detected in follow-up, handling in non-streaming mode',
{
@@ -685,10 +697,8 @@ export const googleProvider: ProviderConfig = {
}
)
// Update geminiResponse to continue the tool execution loop
geminiResponse = checkResult
// Update token counts if available
if (checkResult.usageMetadata) {
tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0
tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0
@@ -697,12 +707,10 @@ export const googleProvider: ProviderConfig = {
(checkResult.usageMetadata.candidatesTokenCount || 0)
}
// Calculate timing for this model call
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
modelTime += thisModelTime
// Add to time segments
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
@@ -711,14 +719,32 @@ export const googleProvider: ProviderConfig = {
duration: thisModelTime,
})
// Continue the loop to handle the function call
iterationCount++
continue
}
// No function call - proceed with streaming
logger.info('No function call detected, proceeding with streaming response')
// Make the streaming request with streamGenerateContent endpoint
// Apply structured output for the final response if responseFormat is specified
// This works regardless of whether tools were forced or auto
if (request.responseFormat) {
streamingPayload.tools = undefined
streamingPayload.toolConfig = undefined
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!streamingPayload.generationConfig) {
streamingPayload.generationConfig = {}
}
streamingPayload.generationConfig.responseMimeType = 'application/json'
streamingPayload.generationConfig.responseSchema = cleanSchema
logger.info(
'Using structured output for final streaming response after tool execution'
)
}
const streamingResponse = await fetch(
`https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:streamGenerateContent?key=${request.apiKey}`,
{
@@ -742,15 +768,10 @@ export const googleProvider: ProviderConfig = {
)
}
// Create a stream from the response
const stream = createReadableStreamFromGeminiStream(streamingResponse)
// Calculate timing information
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
modelTime += thisModelTime
// Add to time segments
timeSegments.push({
type: 'model',
name: 'Final streaming response after tool calls',
@@ -759,9 +780,8 @@ export const googleProvider: ProviderConfig = {
duration: thisModelTime,
})
// Return a streaming execution with tool call information
const streamingExecution: StreamingExecution = {
stream,
stream: null as any,
execution: {
success: true,
output: {
@@ -786,7 +806,6 @@ export const googleProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments,
},
// Cost will be calculated in logger
},
logs: [],
metadata: {
@@ -798,25 +817,55 @@ export const googleProvider: ProviderConfig = {
},
}
streamingExecution.stream = createReadableStreamFromGeminiStream(
streamingResponse,
(content, usage) => {
streamingExecution.execution.output.content = content
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
if (streamingExecution.execution.output.providerTiming) {
streamingExecution.execution.output.providerTiming.endTime =
streamEndTimeISO
streamingExecution.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
}
if (usage) {
const existingTokens = streamingExecution.execution.output.tokens || {
prompt: 0,
completion: 0,
total: 0,
}
streamingExecution.execution.output.tokens = {
prompt: (existingTokens.prompt || 0) + (usage.promptTokenCount || 0),
completion:
(existingTokens.completion || 0) + (usage.candidatesTokenCount || 0),
total:
(existingTokens.total || 0) +
(usage.totalTokenCount ||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0)),
}
}
}
)
return streamingExecution
}
// Make the next request for non-streaming response
const nextPayload = {
...payload,
contents: simplifiedMessages,
}
// Check if we should remove tools and enable structured output for final response
const allForcedToolsUsed =
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
if (allForcedToolsUsed && request.responseFormat) {
// All forced tools have been used, we can now remove tools and enable structured output
nextPayload.tools = undefined
nextPayload.toolConfig = undefined
// Add structured output format for final response
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
@@ -831,7 +880,6 @@ export const googleProvider: ProviderConfig = {
'Using structured output for final non-streaming response after tool execution'
)
} else {
// Add updated tool configuration if available
if (currentToolConfig) {
nextPayload.toolConfig = currentToolConfig
}
@@ -864,7 +912,6 @@ export const googleProvider: ProviderConfig = {
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
// Add to time segments
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
@@ -873,15 +920,65 @@ export const googleProvider: ProviderConfig = {
duration: thisModelTime,
})
// Add to model time
modelTime += thisModelTime
// Check if we need to continue or break
const nextCandidate = geminiResponse.candidates?.[0]
const nextFunctionCall = extractFunctionCall(nextCandidate)
if (!nextFunctionCall) {
content = extractTextContent(nextCandidate)
// If responseFormat is specified, make one final request with structured output
if (request.responseFormat) {
const finalPayload = {
...payload,
contents: nextPayload.contents,
tools: undefined,
toolConfig: undefined,
}
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!finalPayload.generationConfig) {
finalPayload.generationConfig = {}
}
finalPayload.generationConfig.responseMimeType = 'application/json'
finalPayload.generationConfig.responseSchema = cleanSchema
logger.info('Making final request with structured output after tool execution')
const finalResponse = await fetch(
`https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:generateContent?key=${request.apiKey}`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(finalPayload),
}
)
if (finalResponse.ok) {
const finalResult = await finalResponse.json()
const finalCandidate = finalResult.candidates?.[0]
content = extractTextContent(finalCandidate)
if (finalResult.usageMetadata) {
tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0
tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0
tokens.total +=
(finalResult.usageMetadata.promptTokenCount || 0) +
(finalResult.usageMetadata.candidatesTokenCount || 0)
}
} else {
logger.warn(
'Failed to get structured output, falling back to regular response'
)
content = extractTextContent(nextCandidate)
}
} else {
content = extractTextContent(nextCandidate)
}
break
}
@@ -902,7 +999,6 @@ export const googleProvider: ProviderConfig = {
}
}
} else {
// Regular text response
content = extractTextContent(candidate)
}
} catch (error) {
@@ -911,18 +1007,15 @@ export const googleProvider: ProviderConfig = {
iterationCount,
})
// Don't rethrow, so we can still return partial results
if (!content && toolCalls.length > 0) {
content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.`
}
}
// Calculate overall timing
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
// Extract token usage if available
if (geminiResponse.usageMetadata) {
tokens = {
prompt: geminiResponse.usageMetadata.promptTokenCount || 0,
@@ -949,10 +1042,8 @@ export const googleProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
// Cost will be calculated in logger
}
} catch (error) {
// Include timing information even for errors
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -962,7 +1053,6 @@ export const googleProvider: ProviderConfig = {
duration: totalDuration,
})
// Create a new error with timing information
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
enhancedError.timing = {
@@ -975,200 +1065,3 @@ export const googleProvider: ProviderConfig = {
}
},
}
/**
* Helper function to remove additionalProperties from a schema object
* and perform a deep copy of the schema to avoid modifying the original
*/
function cleanSchemaForGemini(schema: any): any {
// Handle base cases
if (schema === null || schema === undefined) return schema
if (typeof schema !== 'object') return schema
if (Array.isArray(schema)) {
return schema.map((item) => cleanSchemaForGemini(item))
}
// Create a new object for the deep copy
const cleanedSchema: any = {}
// Process each property in the schema
for (const key in schema) {
// Skip additionalProperties
if (key === 'additionalProperties') continue
// Deep copy nested objects
cleanedSchema[key] = cleanSchemaForGemini(schema[key])
}
return cleanedSchema
}
/**
* Helper function to extract content from a Gemini response, handling structured output
*/
function extractTextContent(candidate: any): string {
if (!candidate?.content?.parts) return ''
// Check for JSON response (typically from structured output)
if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) {
const text = candidate.content.parts[0].text
if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) {
try {
JSON.parse(text) // Validate JSON
return text // Return valid JSON as-is
} catch (_e) {
/* Not valid JSON, continue with normal extraction */
}
}
}
// Standard text extraction
return candidate.content.parts
.filter((part: any) => part.text)
.map((part: any) => part.text)
.join('\n')
}
/**
* Helper function to extract a function call from a Gemini response
*/
function extractFunctionCall(candidate: any): { name: string; args: any } | null {
if (!candidate?.content?.parts) return null
// Check for functionCall in parts
for (const part of candidate.content.parts) {
if (part.functionCall) {
const args = part.functionCall.args || {}
// Parse string args if they look like JSON
if (
typeof part.functionCall.args === 'string' &&
part.functionCall.args.trim().startsWith('{')
) {
try {
return { name: part.functionCall.name, args: JSON.parse(part.functionCall.args) }
} catch (_e) {
return { name: part.functionCall.name, args: part.functionCall.args }
}
}
return { name: part.functionCall.name, args }
}
}
// Check for alternative function_call format
if (candidate.content.function_call) {
const args =
typeof candidate.content.function_call.arguments === 'string'
? JSON.parse(candidate.content.function_call.arguments || '{}')
: candidate.content.function_call.arguments || {}
return { name: candidate.content.function_call.name, args }
}
return null
}
/**
* Convert OpenAI-style request format to Gemini format
*/
function convertToGeminiFormat(request: ProviderRequest): {
contents: any[]
tools: any[] | undefined
systemInstruction: any | undefined
} {
const contents = []
let systemInstruction
// Handle system prompt
if (request.systemPrompt) {
systemInstruction = { parts: [{ text: request.systemPrompt }] }
}
// Add context as user message if present
if (request.context) {
contents.push({ role: 'user', parts: [{ text: request.context }] })
}
// Process messages
if (request.messages && request.messages.length > 0) {
for (const message of request.messages) {
if (message.role === 'system') {
// Add to system instruction
if (!systemInstruction) {
systemInstruction = { parts: [{ text: message.content }] }
} else {
// Append to existing system instruction
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}`
}
} else if (message.role === 'user' || message.role === 'assistant') {
// Convert to Gemini role format
const geminiRole = message.role === 'user' ? 'user' : 'model'
// Add text content
if (message.content) {
contents.push({ role: geminiRole, parts: [{ text: message.content }] })
}
// Handle tool calls
if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) {
const functionCalls = message.tool_calls.map((toolCall) => ({
functionCall: {
name: toolCall.function?.name,
args: JSON.parse(toolCall.function?.arguments || '{}'),
},
}))
contents.push({ role: 'model', parts: functionCalls })
}
} else if (message.role === 'tool') {
// Convert tool response (Gemini only accepts user/model roles)
contents.push({
role: 'user',
parts: [{ text: `Function result: ${message.content}` }],
})
}
}
}
// Convert tools to Gemini function declarations
const tools = request.tools?.map((tool) => {
const toolParameters = { ...(tool.parameters || {}) }
// Process schema properties
if (toolParameters.properties) {
const properties = { ...toolParameters.properties }
const required = toolParameters.required ? [...toolParameters.required] : []
// Remove defaults and optional parameters
for (const key in properties) {
const prop = properties[key] as any
if (prop.default !== undefined) {
const { default: _, ...cleanProp } = prop
properties[key] = cleanProp
}
}
// Build Gemini-compatible parameters schema
const parameters = {
type: toolParameters.type || 'object',
properties,
...(required.length > 0 ? { required } : {}),
}
// Clean schema for Gemini
return {
name: tool.id,
description: tool.description || `Execute the ${tool.id} function`,
parameters: cleanSchemaForGemini(parameters),
}
}
// Simple schema case
return {
name: tool.id,
description: tool.description || `Execute the ${tool.id} function`,
parameters: cleanSchemaForGemini(toolParameters),
}
})
return { contents, tools, systemInstruction }
}

View File

@@ -0,0 +1,171 @@
import type { ProviderRequest } from '@/providers/types'
/**
* Removes additionalProperties from a schema object (not supported by Gemini)
*/
export function cleanSchemaForGemini(schema: any): any {
if (schema === null || schema === undefined) return schema
if (typeof schema !== 'object') return schema
if (Array.isArray(schema)) {
return schema.map((item) => cleanSchemaForGemini(item))
}
const cleanedSchema: any = {}
for (const key in schema) {
if (key === 'additionalProperties') continue
cleanedSchema[key] = cleanSchemaForGemini(schema[key])
}
return cleanedSchema
}
/**
* Extracts text content from a Gemini response candidate, handling structured output
*/
export function extractTextContent(candidate: any): string {
if (!candidate?.content?.parts) return ''
if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) {
const text = candidate.content.parts[0].text
if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) {
try {
JSON.parse(text)
return text
} catch (_e) {
/* Not valid JSON, continue with normal extraction */
}
}
}
return candidate.content.parts
.filter((part: any) => part.text)
.map((part: any) => part.text)
.join('\n')
}
/**
* Extracts a function call from a Gemini response candidate
*/
export function extractFunctionCall(candidate: any): { name: string; args: any } | null {
if (!candidate?.content?.parts) return null
for (const part of candidate.content.parts) {
if (part.functionCall) {
const args = part.functionCall.args || {}
if (
typeof part.functionCall.args === 'string' &&
part.functionCall.args.trim().startsWith('{')
) {
try {
return { name: part.functionCall.name, args: JSON.parse(part.functionCall.args) }
} catch (_e) {
return { name: part.functionCall.name, args: part.functionCall.args }
}
}
return { name: part.functionCall.name, args }
}
}
if (candidate.content.function_call) {
const args =
typeof candidate.content.function_call.arguments === 'string'
? JSON.parse(candidate.content.function_call.arguments || '{}')
: candidate.content.function_call.arguments || {}
return { name: candidate.content.function_call.name, args }
}
return null
}
/**
* Converts OpenAI-style request format to Gemini format
*/
export function convertToGeminiFormat(request: ProviderRequest): {
contents: any[]
tools: any[] | undefined
systemInstruction: any | undefined
} {
const contents: any[] = []
let systemInstruction
if (request.systemPrompt) {
systemInstruction = { parts: [{ text: request.systemPrompt }] }
}
if (request.context) {
contents.push({ role: 'user', parts: [{ text: request.context }] })
}
if (request.messages && request.messages.length > 0) {
for (const message of request.messages) {
if (message.role === 'system') {
if (!systemInstruction) {
systemInstruction = { parts: [{ text: message.content }] }
} else {
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}`
}
} else if (message.role === 'user' || message.role === 'assistant') {
const geminiRole = message.role === 'user' ? 'user' : 'model'
if (message.content) {
contents.push({ role: geminiRole, parts: [{ text: message.content }] })
}
if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) {
const functionCalls = message.tool_calls.map((toolCall) => ({
functionCall: {
name: toolCall.function?.name,
args: JSON.parse(toolCall.function?.arguments || '{}'),
},
}))
contents.push({ role: 'model', parts: functionCalls })
}
} else if (message.role === 'tool') {
contents.push({
role: 'user',
parts: [{ text: `Function result: ${message.content}` }],
})
}
}
}
const tools = request.tools?.map((tool) => {
const toolParameters = { ...(tool.parameters || {}) }
if (toolParameters.properties) {
const properties = { ...toolParameters.properties }
const required = toolParameters.required ? [...toolParameters.required] : []
for (const key in properties) {
const prop = properties[key] as any
if (prop.default !== undefined) {
const { default: _, ...cleanProp } = prop
properties[key] = cleanProp
}
}
const parameters = {
type: toolParameters.type || 'object',
properties,
...(required.length > 0 ? { required } : {}),
}
return {
name: tool.id,
description: tool.description || `Execute the ${tool.id} function`,
parameters: cleanSchemaForGemini(parameters),
}
}
return {
name: tool.id,
description: tool.description || `Execute the ${tool.id} function`,
parameters: cleanSchemaForGemini(toolParameters),
}
})
return { contents, tools, systemInstruction }
}

View File

@@ -1,6 +1,8 @@
import { Groq } from 'groq-sdk'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { createReadableStreamFromGroqStream } from '@/providers/groq/utils'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
@@ -17,27 +19,6 @@ import { executeTool } from '@/tools'
const logger = createLogger('GroqProvider')
/**
* Helper to wrap Groq streaming into a browser-friendly ReadableStream
* of raw assistant text chunks.
*/
function createReadableStreamFromGroqStream(groqStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of groqStream) {
if (chunk.choices[0]?.delta?.content) {
controller.enqueue(new TextEncoder().encode(chunk.choices[0].delta.content))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
}
export const groqProvider: ProviderConfig = {
id: 'groq',
name: 'Groq',
@@ -225,7 +206,6 @@ export const groqProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
// Track time spent in model vs tools
let modelTime = firstResponseTime
@@ -243,7 +223,7 @@ export const groqProvider: ProviderConfig = {
]
try {
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {

View File

@@ -0,0 +1,23 @@
/**
* Helper to wrap Groq streaming into a browser-friendly ReadableStream
* of raw assistant text chunks.
*
* @param groqStream - The Groq streaming response
* @returns A ReadableStream that emits text chunks
*/
export function createReadableStreamFromGroqStream(groqStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of groqStream) {
if (chunk.choices[0]?.delta?.content) {
controller.enqueue(new TextEncoder().encode(chunk.choices[0].delta.content))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
}

View File

@@ -12,6 +12,12 @@ import {
const logger = createLogger('Providers')
/**
* Maximum number of iterations for tool call loops to prevent infinite loops.
* Used across all providers that support tool/function calling.
*/
export const MAX_TOOL_ITERATIONS = 20
function sanitizeRequest(request: ProviderRequest): ProviderRequest {
const sanitizedRequest = { ...request }
@@ -44,7 +50,6 @@ export async function executeProviderRequest(
}
const sanitizedRequest = sanitizeRequest(request)
// If responseFormat is provided, modify the system prompt to enforce structured output
if (sanitizedRequest.responseFormat) {
if (
typeof sanitizedRequest.responseFormat === 'string' &&
@@ -53,12 +58,10 @@ export async function executeProviderRequest(
logger.info('Empty response format provided, ignoring it')
sanitizedRequest.responseFormat = undefined
} else {
// Generate structured output instructions
const structuredOutputInstructions = generateStructuredOutputInstructions(
sanitizedRequest.responseFormat
)
// Only add additional instructions if they're not empty
if (structuredOutputInstructions.trim()) {
const originalPrompt = sanitizedRequest.systemPrompt || ''
sanitizedRequest.systemPrompt =
@@ -69,10 +72,8 @@ export async function executeProviderRequest(
}
}
// Execute the request using the provider's implementation
const response = await provider.executeRequest(sanitizedRequest)
// If we received a StreamingExecution or ReadableStream, just pass it through
if (isStreamingExecution(response)) {
logger.info('Provider returned StreamingExecution')
return response

View File

@@ -1,6 +1,8 @@
import OpenAI from 'openai'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { createReadableStreamFromMistralStream } from '@/providers/mistral/utils'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
@@ -17,40 +19,6 @@ import { executeTool } from '@/tools'
const logger = createLogger('MistralProvider')
function createReadableStreamFromMistralStream(
mistralStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of mistralStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
/**
* Mistral AI provider configuration
*/
@@ -288,7 +256,6 @@ export const mistralProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10
let modelTime = firstResponseTime
let toolsTime = 0
@@ -307,14 +274,14 @@ export const mistralProvider: ProviderConfig = {
checkForForcedToolUsage(currentResponse, originalToolChoice)
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
logger.info(
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
const toolsStartTime = Date.now()

View File

@@ -0,0 +1,39 @@
/**
* Creates a ReadableStream from a Mistral AI streaming response
* @param mistralStream - The Mistral AI stream object
* @param onComplete - Optional callback when streaming completes
* @returns A ReadableStream that yields text chunks
*/
export function createReadableStreamFromMistralStream(
mistralStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of mistralStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}

View File

@@ -19,6 +19,7 @@ import {
OllamaIcon,
OpenAIIcon,
OpenRouterIcon,
VertexIcon,
VllmIcon,
xAIIcon,
} from '@/components/icons'
@@ -130,7 +131,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
reasoningEffort: {
values: ['none', 'low', 'medium', 'high'],
values: ['none', 'minimal', 'low', 'medium', 'high', 'xhigh'],
},
verbosity: {
values: ['low', 'medium', 'high'],
@@ -283,7 +284,11 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
output: 60,
updatedAt: '2025-06-17',
},
capabilities: {},
capabilities: {
reasoningEffort: {
values: ['low', 'medium', 'high'],
},
},
contextWindow: 200000,
},
{
@@ -294,7 +299,11 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
output: 8,
updatedAt: '2025-06-17',
},
capabilities: {},
capabilities: {
reasoningEffort: {
values: ['low', 'medium', 'high'],
},
},
contextWindow: 128000,
},
{
@@ -305,7 +314,11 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
output: 4.4,
updatedAt: '2025-06-17',
},
capabilities: {},
capabilities: {
reasoningEffort: {
values: ['low', 'medium', 'high'],
},
},
contextWindow: 128000,
},
{
@@ -383,7 +396,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
reasoningEffort: {
values: ['none', 'low', 'medium', 'high'],
values: ['none', 'minimal', 'low', 'medium', 'high', 'xhigh'],
},
verbosity: {
values: ['low', 'medium', 'high'],
@@ -536,7 +549,11 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
output: 40,
updatedAt: '2025-06-15',
},
capabilities: {},
capabilities: {
reasoningEffort: {
values: ['low', 'medium', 'high'],
},
},
contextWindow: 128000,
},
{
@@ -547,7 +564,11 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
output: 4.4,
updatedAt: '2025-06-15',
},
capabilities: {},
capabilities: {
reasoningEffort: {
values: ['low', 'medium', 'high'],
},
},
contextWindow: 128000,
},
{
@@ -708,9 +729,22 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
id: 'gemini-3-pro-preview',
pricing: {
input: 2.0,
cachedInput: 1.0,
cachedInput: 0.2,
output: 12.0,
updatedAt: '2025-11-18',
updatedAt: '2025-12-17',
},
capabilities: {
temperature: { min: 0, max: 2 },
},
contextWindow: 1000000,
},
{
id: 'gemini-3-flash-preview',
pricing: {
input: 0.5,
cachedInput: 0.05,
output: 3.0,
updatedAt: '2025-12-17',
},
capabilities: {
temperature: { min: 0, max: 2 },
@@ -756,6 +790,132 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
contextWindow: 1048576,
},
{
id: 'gemini-2.0-flash',
pricing: {
input: 0.1,
output: 0.4,
updatedAt: '2025-12-17',
},
capabilities: {
temperature: { min: 0, max: 2 },
},
contextWindow: 1000000,
},
{
id: 'gemini-2.0-flash-lite',
pricing: {
input: 0.075,
output: 0.3,
updatedAt: '2025-12-17',
},
capabilities: {
temperature: { min: 0, max: 2 },
},
contextWindow: 1000000,
},
],
},
vertex: {
id: 'vertex',
name: 'Vertex AI',
description: "Google's Vertex AI platform for Gemini models",
defaultModel: 'vertex/gemini-2.5-pro',
modelPatterns: [/^vertex\//],
icon: VertexIcon,
capabilities: {
toolUsageControl: true,
},
models: [
{
id: 'vertex/gemini-3-pro-preview',
pricing: {
input: 2.0,
cachedInput: 0.2,
output: 12.0,
updatedAt: '2025-12-17',
},
capabilities: {
temperature: { min: 0, max: 2 },
},
contextWindow: 1000000,
},
{
id: 'vertex/gemini-3-flash-preview',
pricing: {
input: 0.5,
cachedInput: 0.05,
output: 3.0,
updatedAt: '2025-12-17',
},
capabilities: {
temperature: { min: 0, max: 2 },
},
contextWindow: 1000000,
},
{
id: 'vertex/gemini-2.5-pro',
pricing: {
input: 1.25,
cachedInput: 0.125,
output: 10.0,
updatedAt: '2025-12-02',
},
capabilities: {
temperature: { min: 0, max: 2 },
},
contextWindow: 1048576,
},
{
id: 'vertex/gemini-2.5-flash',
pricing: {
input: 0.3,
cachedInput: 0.03,
output: 2.5,
updatedAt: '2025-12-02',
},
capabilities: {
temperature: { min: 0, max: 2 },
},
contextWindow: 1048576,
},
{
id: 'vertex/gemini-2.5-flash-lite',
pricing: {
input: 0.1,
cachedInput: 0.01,
output: 0.4,
updatedAt: '2025-12-02',
},
capabilities: {
temperature: { min: 0, max: 2 },
},
contextWindow: 1048576,
},
{
id: 'vertex/gemini-2.0-flash',
pricing: {
input: 0.1,
output: 0.4,
updatedAt: '2025-12-17',
},
capabilities: {
temperature: { min: 0, max: 2 },
},
contextWindow: 1000000,
},
{
id: 'vertex/gemini-2.0-flash-lite',
pricing: {
input: 0.075,
output: 0.3,
updatedAt: '2025-12-17',
},
capabilities: {
temperature: { min: 0, max: 2 },
},
contextWindow: 1000000,
},
],
},
deepseek: {
@@ -1708,6 +1868,20 @@ export function getModelsWithReasoningEffort(): string[] {
return models
}
/**
* Get the reasoning effort values for a specific model
* Returns the valid options for that model, or null if the model doesn't support reasoning effort
*/
export function getReasoningEffortValuesForModel(modelId: string): string[] | null {
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase())
if (model?.capabilities.reasoningEffort) {
return model.capabilities.reasoningEffort.values
}
}
return null
}
/**
* Get all models that support verbosity
*/
@@ -1722,3 +1896,17 @@ export function getModelsWithVerbosity(): string[] {
}
return models
}
/**
* Get the verbosity values for a specific model
* Returns the valid options for that model, or null if the model doesn't support verbosity
*/
export function getVerbosityValuesForModel(modelId: string): string[] | null {
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase())
if (model?.capabilities.verbosity) {
return model.capabilities.verbosity.values
}
}
return null
}

View File

@@ -2,7 +2,9 @@ import OpenAI from 'openai'
import { env } from '@/lib/core/config/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import type { ModelsObject } from '@/providers/ollama/types'
import { createReadableStreamFromOllamaStream } from '@/providers/ollama/utils'
import type {
ProviderConfig,
ProviderRequest,
@@ -16,46 +18,6 @@ import { executeTool } from '@/tools'
const logger = createLogger('OllamaProvider')
const OLLAMA_HOST = env.OLLAMA_URL || 'http://localhost:11434'
/**
* Helper function to convert an Ollama stream to a standard ReadableStream
* and collect completion metrics
*/
function createReadableStreamFromOllamaStream(
ollamaStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of ollamaStream) {
// Check for usage data in the final chunk
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
// Once stream is complete, call the completion callback with the final content and usage
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
export const ollamaProvider: ProviderConfig = {
id: 'ollama',
name: 'Ollama',
@@ -334,7 +296,6 @@ export const ollamaProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
// Track time spent in model vs tools
let modelTime = firstResponseTime
@@ -351,7 +312,7 @@ export const ollamaProvider: ProviderConfig = {
},
]
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
@@ -359,7 +320,7 @@ export const ollamaProvider: ProviderConfig = {
}
logger.info(
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
// Track time for tool calls in this batch

View File

@@ -0,0 +1,37 @@
/**
* Helper function to convert an Ollama stream to a standard ReadableStream
* and collect completion metrics
*/
export function createReadableStreamFromOllamaStream(
ollamaStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of ollamaStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}

View File

@@ -1,7 +1,9 @@
import OpenAI from 'openai'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import { createReadableStreamFromOpenAIStream } from '@/providers/openai/utils'
import type {
ProviderConfig,
ProviderRequest,
@@ -17,46 +19,6 @@ import { executeTool } from '@/tools'
const logger = createLogger('OpenAIProvider')
/**
* Helper function to convert an OpenAI stream to a standard ReadableStream
* and collect completion metrics
*/
function createReadableStreamFromOpenAIStream(
openaiStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of openaiStream) {
// Check for usage data in the final chunk
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
// Once stream is complete, call the completion callback with the final content and usage
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
/**
* OpenAI provider configuration
*/
@@ -319,7 +281,6 @@ export const openaiProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10 // Prevent infinite loops
// Track time spent in model vs tools
let modelTime = firstResponseTime
@@ -342,7 +303,7 @@ export const openaiProvider: ProviderConfig = {
// Check if a forced tool was used in the first response
checkForForcedToolUsage(currentResponse, originalToolChoice)
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
@@ -350,7 +311,7 @@ export const openaiProvider: ProviderConfig = {
}
logger.info(
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
// Track time for tool calls in this batch

View File

@@ -0,0 +1,37 @@
/**
* Helper function to convert an OpenAI stream to a standard ReadableStream
* and collect completion metrics
*/
export function createReadableStreamFromOpenAIStream(
openaiStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of openaiStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}

View File

@@ -1,56 +1,23 @@
import OpenAI from 'openai'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import {
checkForForcedToolUsage,
createReadableStreamFromOpenAIStream,
} from '@/providers/openrouter/utils'
import type {
ProviderConfig,
ProviderRequest,
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import {
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
} from '@/providers/utils'
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
import { executeTool } from '@/tools'
const logger = createLogger('OpenRouterProvider')
function createReadableStreamFromOpenAIStream(
openaiStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of openaiStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
export const openRouterProvider: ProviderConfig = {
id: 'openrouter',
name: 'OpenRouter',
@@ -227,7 +194,6 @@ export const openRouterProvider: ProviderConfig = {
const toolResults = [] as any[]
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10
let modelTime = firstResponseTime
let toolsTime = 0
let hasUsedForcedTool = false
@@ -241,28 +207,16 @@ export const openRouterProvider: ProviderConfig = {
},
]
const checkForForcedToolUsage = (
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }
) => {
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'openrouter',
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
}
}
const forcedToolResult = checkForForcedToolUsage(
currentResponse,
originalToolChoice,
forcedTools,
usedForcedTools
)
hasUsedForcedTool = forcedToolResult.hasUsedForcedTool
usedForcedTools = forcedToolResult.usedForcedTools
checkForForcedToolUsage(currentResponse, originalToolChoice)
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
@@ -359,7 +313,14 @@ export const openRouterProvider: ProviderConfig = {
const nextModelStartTime = Date.now()
currentResponse = await client.chat.completions.create(nextPayload)
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
const nextForcedToolResult = checkForForcedToolUsage(
currentResponse,
nextPayload.tool_choice,
forcedTools,
usedForcedTools
)
hasUsedForcedTool = nextForcedToolResult.hasUsedForcedTool
usedForcedTools = nextForcedToolResult.usedForcedTools
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
timeSegments.push({

View File

@@ -0,0 +1,78 @@
import { createLogger } from '@/lib/logs/console/logger'
import { trackForcedToolUsage } from '@/providers/utils'
const logger = createLogger('OpenRouterProvider')
/**
* Creates a ReadableStream from an OpenAI-compatible stream response
* @param openaiStream - The OpenAI stream to convert
* @param onComplete - Optional callback when streaming is complete with content and usage data
* @returns ReadableStream that emits text chunks
*/
export function createReadableStreamFromOpenAIStream(
openaiStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of openaiStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
/**
* Checks if a forced tool was used in the response and updates tracking
* @param response - The API response containing tool calls
* @param toolChoice - The tool choice configuration (string or object)
* @param forcedTools - Array of forced tool names
* @param usedForcedTools - Array of already used forced tools
* @returns Object with hasUsedForcedTool flag and updated usedForcedTools array
*/
export function checkForForcedToolUsage(
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
forcedTools: string[],
usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
let hasUsedForcedTool = false
let updatedUsedForcedTools = usedForcedTools
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'openrouter',
forcedTools,
updatedUsedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
updatedUsedForcedTools = result.usedForcedTools
}
return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
}

View File

@@ -5,6 +5,7 @@ export type ProviderId =
| 'azure-openai'
| 'anthropic'
| 'google'
| 'vertex'
| 'deepseek'
| 'xai'
| 'cerebras'
@@ -163,6 +164,9 @@ export interface ProviderRequest {
// Azure OpenAI specific parameters
azureEndpoint?: string
azureApiVersion?: string
// Vertex AI specific parameters
vertexProject?: string
vertexLocation?: string
// GPT-5 specific parameters
reasoningEffort?: string
verbosity?: string

View File

@@ -383,6 +383,17 @@ describe('Model Capabilities', () => {
expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/gpt-5-mini')
expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/gpt-5-nano')
// Should contain gpt-5.2 models
expect(MODELS_WITH_REASONING_EFFORT).toContain('gpt-5.2')
expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/gpt-5.2')
// Should contain o-series reasoning models (reasoning_effort added Dec 17, 2024)
expect(MODELS_WITH_REASONING_EFFORT).toContain('o1')
expect(MODELS_WITH_REASONING_EFFORT).toContain('o3')
expect(MODELS_WITH_REASONING_EFFORT).toContain('o4-mini')
expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/o3')
expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/o4-mini')
// Should NOT contain non-reasoning GPT-5 models
expect(MODELS_WITH_REASONING_EFFORT).not.toContain('gpt-5-chat-latest')
expect(MODELS_WITH_REASONING_EFFORT).not.toContain('azure/gpt-5-chat-latest')
@@ -390,7 +401,6 @@ describe('Model Capabilities', () => {
// Should NOT contain other models
expect(MODELS_WITH_REASONING_EFFORT).not.toContain('gpt-4o')
expect(MODELS_WITH_REASONING_EFFORT).not.toContain('claude-sonnet-4-0')
expect(MODELS_WITH_REASONING_EFFORT).not.toContain('o1')
})
it.concurrent('should have correct models in MODELS_WITH_VERBOSITY', () => {
@@ -409,19 +419,37 @@ describe('Model Capabilities', () => {
expect(MODELS_WITH_VERBOSITY).toContain('azure/gpt-5-mini')
expect(MODELS_WITH_VERBOSITY).toContain('azure/gpt-5-nano')
// Should contain gpt-5.2 models
expect(MODELS_WITH_VERBOSITY).toContain('gpt-5.2')
expect(MODELS_WITH_VERBOSITY).toContain('azure/gpt-5.2')
// Should NOT contain non-reasoning GPT-5 models
expect(MODELS_WITH_VERBOSITY).not.toContain('gpt-5-chat-latest')
expect(MODELS_WITH_VERBOSITY).not.toContain('azure/gpt-5-chat-latest')
// Should NOT contain o-series models (they support reasoning_effort but not verbosity)
expect(MODELS_WITH_VERBOSITY).not.toContain('o1')
expect(MODELS_WITH_VERBOSITY).not.toContain('o3')
expect(MODELS_WITH_VERBOSITY).not.toContain('o4-mini')
// Should NOT contain other models
expect(MODELS_WITH_VERBOSITY).not.toContain('gpt-4o')
expect(MODELS_WITH_VERBOSITY).not.toContain('claude-sonnet-4-0')
expect(MODELS_WITH_VERBOSITY).not.toContain('o1')
})
it.concurrent('should have same models in both reasoning effort and verbosity arrays', () => {
// GPT-5 models that support reasoning effort should also support verbosity and vice versa
expect(MODELS_WITH_REASONING_EFFORT.sort()).toEqual(MODELS_WITH_VERBOSITY.sort())
it.concurrent('should have GPT-5 models in both reasoning effort and verbosity arrays', () => {
// GPT-5 series models support both reasoning effort and verbosity
const gpt5ModelsWithReasoningEffort = MODELS_WITH_REASONING_EFFORT.filter(
(m) => m.includes('gpt-5') && !m.includes('chat-latest')
)
const gpt5ModelsWithVerbosity = MODELS_WITH_VERBOSITY.filter(
(m) => m.includes('gpt-5') && !m.includes('chat-latest')
)
expect(gpt5ModelsWithReasoningEffort.sort()).toEqual(gpt5ModelsWithVerbosity.sort())
// o-series models have reasoning effort but NOT verbosity
expect(MODELS_WITH_REASONING_EFFORT).toContain('o1')
expect(MODELS_WITH_VERBOSITY).not.toContain('o1')
})
})
})

View File

@@ -21,6 +21,8 @@ import {
getModelsWithVerbosity,
getProviderModels as getProviderModelsFromDefinitions,
getProvidersWithToolUsageControl,
getReasoningEffortValuesForModel as getReasoningEffortValuesForModelFromDefinitions,
getVerbosityValuesForModel as getVerbosityValuesForModelFromDefinitions,
PROVIDER_DEFINITIONS,
supportsTemperature as supportsTemperatureFromDefinitions,
supportsToolUsageControl as supportsToolUsageControlFromDefinitions,
@@ -30,6 +32,7 @@ import { ollamaProvider } from '@/providers/ollama'
import { openaiProvider } from '@/providers/openai'
import { openRouterProvider } from '@/providers/openrouter'
import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types'
import { vertexProvider } from '@/providers/vertex'
import { vllmProvider } from '@/providers/vllm'
import { xAIProvider } from '@/providers/xai'
import { useCustomToolsStore } from '@/stores/custom-tools/store'
@@ -67,6 +70,11 @@ export const providers: Record<
models: getProviderModelsFromDefinitions('google'),
modelPatterns: PROVIDER_DEFINITIONS.google.modelPatterns,
},
vertex: {
...vertexProvider,
models: getProviderModelsFromDefinitions('vertex'),
modelPatterns: PROVIDER_DEFINITIONS.vertex.modelPatterns,
},
deepseek: {
...deepseekProvider,
models: getProviderModelsFromDefinitions('deepseek'),
@@ -274,16 +282,12 @@ export function getProviderIcon(model: string): React.ComponentType<{ className?
}
export function generateStructuredOutputInstructions(responseFormat: any): string {
// Handle null/undefined input
if (!responseFormat) return ''
// If using the new JSON Schema format, don't add additional instructions
// This is necessary because providers now handle the schema directly
if (responseFormat.schema || (responseFormat.type === 'object' && responseFormat.properties)) {
return ''
}
// Handle legacy format with fields array
if (!responseFormat.fields) return ''
function generateFieldStructure(field: any): string {
@@ -335,10 +339,8 @@ Each metric should be an object containing 'score' (number) and 'reasoning' (str
}
export function extractAndParseJSON(content: string): any {
// First clean up the string
const trimmed = content.trim()
// Find the first '{' and last '}'
const firstBrace = trimmed.indexOf('{')
const lastBrace = trimmed.lastIndexOf('}')
@@ -346,17 +348,15 @@ export function extractAndParseJSON(content: string): any {
throw new Error('No JSON object found in content')
}
// Extract just the JSON part
const jsonStr = trimmed.slice(firstBrace, lastBrace + 1)
try {
return JSON.parse(jsonStr)
} catch (_error) {
// If parsing fails, try to clean up common issues
const cleaned = jsonStr
.replace(/\n/g, ' ') // Remove newlines
.replace(/\s+/g, ' ') // Normalize whitespace
.replace(/,\s*([}\]])/g, '$1') // Remove trailing commas
.replace(/\n/g, ' ')
.replace(/\s+/g, ' ')
.replace(/,\s*([}\]])/g, '$1')
try {
return JSON.parse(cleaned)
@@ -386,10 +386,10 @@ export function transformCustomTool(customTool: any): ProviderToolConfig {
}
return {
id: `custom_${customTool.id}`, // Prefix with 'custom_' to identify custom tools
id: `custom_${customTool.id}`,
name: schema.function.name,
description: schema.function.description || '',
params: {}, // This will be derived from parameters
params: {},
parameters: {
type: schema.function.parameters.type,
properties: schema.function.parameters.properties,
@@ -402,10 +402,8 @@ export function transformCustomTool(customTool: any): ProviderToolConfig {
* Gets all available custom tools as provider tool configs
*/
export function getCustomTools(): ProviderToolConfig[] {
// Get custom tools from the store
const customTools = useCustomToolsStore.getState().getAllTools()
// Transform each custom tool into a provider tool config
return customTools.map(transformCustomTool)
}
@@ -427,20 +425,16 @@ export async function transformBlockTool(
): Promise<ProviderToolConfig | null> {
const { selectedOperation, getAllBlocks, getTool, getToolAsync } = options
// Get the block definition
const blockDef = getAllBlocks().find((b: any) => b.type === block.type)
if (!blockDef) {
logger.warn(`Block definition not found for type: ${block.type}`)
return null
}
// If the block has multiple operations, use the selected one or the first one
let toolId: string | null = null
if ((blockDef.tools?.access?.length || 0) > 1) {
// If we have an operation dropdown in the block and a selected operation
if (selectedOperation && blockDef.tools?.config?.tool) {
// Use the block's tool selection function to get the right tool
try {
toolId = blockDef.tools.config.tool({
...block.params,
@@ -455,11 +449,9 @@ export async function transformBlockTool(
return null
}
} else {
// Default to first tool if no operation specified
toolId = blockDef.tools.access[0]
}
} else {
// Single tool case
toolId = blockDef.tools?.access?.[0] || null
}
@@ -468,14 +460,11 @@ export async function transformBlockTool(
return null
}
// Get the tool config - check if it's a custom tool that needs async fetching
let toolConfig: any
if (toolId.startsWith('custom_') && getToolAsync) {
// Use the async version for custom tools
toolConfig = await getToolAsync(toolId)
} else {
// Use the synchronous version for built-in tools
toolConfig = getTool(toolId)
}
@@ -484,16 +473,12 @@ export async function transformBlockTool(
return null
}
// Import the new tool parameter utilities
const { createLLMToolSchema } = await import('@/tools/params')
// Get user-provided parameters from the block
const userProvidedParams = block.params || {}
// Create LLM schema that excludes user-provided parameters
const llmSchema = await createLLMToolSchema(toolConfig, userProvidedParams)
// Return formatted tool config
return {
id: toolConfig.id,
name: toolConfig.name,
@@ -521,15 +506,12 @@ export function calculateCost(
inputMultiplier?: number,
outputMultiplier?: number
) {
// First check if it's an embedding model
let pricing = getEmbeddingModelPricing(model)
// If not found, check chat models
if (!pricing) {
pricing = getModelPricingFromDefinitions(model)
}
// If no pricing found, return default pricing
if (!pricing) {
const defaultPricing = {
input: 1.0,
@@ -545,8 +527,6 @@ export function calculateCost(
}
}
// Calculate costs in USD
// Convert from "per million tokens" to "per token" by dividing by 1,000,000
const inputCost =
promptTokens *
(useCachedInput && pricing.cachedInput
@@ -559,7 +539,7 @@ export function calculateCost(
const finalTotalCost = finalInputCost + finalOutputCost
return {
input: Number.parseFloat(finalInputCost.toFixed(8)), // Use 8 decimal places for small costs
input: Number.parseFloat(finalInputCost.toFixed(8)),
output: Number.parseFloat(finalOutputCost.toFixed(8)),
total: Number.parseFloat(finalTotalCost.toFixed(8)),
pricing,
@@ -997,6 +977,22 @@ export function supportsToolUsageControl(provider: string): boolean {
return supportsToolUsageControlFromDefinitions(provider)
}
/**
* Get reasoning effort values for a specific model
* Returns the valid options for that model, or null if the model doesn't support reasoning effort
*/
export function getReasoningEffortValuesForModel(model: string): string[] | null {
return getReasoningEffortValuesForModelFromDefinitions(model)
}
/**
* Get verbosity values for a specific model
* Returns the valid options for that model, or null if the model doesn't support verbosity
*/
export function getVerbosityValuesForModel(model: string): string[] | null {
return getVerbosityValuesForModelFromDefinitions(model)
}
/**
* Prepare tool execution parameters, separating tool parameters from system parameters
*/

View File

@@ -0,0 +1,899 @@
import { env } from '@/lib/core/config/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import {
cleanSchemaForGemini,
convertToGeminiFormat,
extractFunctionCall,
extractTextContent,
} from '@/providers/google/utils'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
ProviderRequest,
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import {
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
} from '@/providers/utils'
import { buildVertexEndpoint, createReadableStreamFromVertexStream } from '@/providers/vertex/utils'
import { executeTool } from '@/tools'
const logger = createLogger('VertexProvider')
/**
* Vertex AI provider configuration
*/
export const vertexProvider: ProviderConfig = {
id: 'vertex',
name: 'Vertex AI',
description: "Google's Vertex AI platform for Gemini models",
version: '1.0.0',
models: getProviderModels('vertex'),
defaultModel: getProviderDefaultModel('vertex'),
executeRequest: async (
request: ProviderRequest
): Promise<ProviderResponse | StreamingExecution> => {
const vertexProject = env.VERTEX_PROJECT || request.vertexProject
const vertexLocation = env.VERTEX_LOCATION || request.vertexLocation || 'us-central1'
if (!vertexProject) {
throw new Error(
'Vertex AI project is required. Please provide it via VERTEX_PROJECT environment variable or vertexProject parameter.'
)
}
if (!request.apiKey) {
throw new Error(
'Access token is required for Vertex AI. Run `gcloud auth print-access-token` to get one, or use a service account.'
)
}
logger.info('Preparing Vertex AI request', {
model: request.model || 'vertex/gemini-2.5-pro',
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
toolCount: request.tools?.length || 0,
hasResponseFormat: !!request.responseFormat,
streaming: !!request.stream,
project: vertexProject,
location: vertexLocation,
})
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
const requestedModel = (request.model || 'vertex/gemini-2.5-pro').replace('vertex/', '')
const payload: any = {
contents,
generationConfig: {},
}
if (request.temperature !== undefined && request.temperature !== null) {
payload.generationConfig.temperature = request.temperature
}
if (request.maxTokens !== undefined) {
payload.generationConfig.maxOutputTokens = request.maxTokens
}
if (systemInstruction) {
payload.systemInstruction = systemInstruction
}
if (request.responseFormat && !tools?.length) {
const responseFormatSchema = request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
payload.generationConfig.responseMimeType = 'application/json'
payload.generationConfig.responseSchema = cleanSchema
logger.info('Using Vertex AI native structured output format', {
hasSchema: !!cleanSchema,
mimeType: 'application/json',
})
} else if (request.responseFormat && tools?.length) {
logger.warn(
'Vertex AI does not support structured output (responseFormat) with function calling (tools). Structured output will be ignored.'
)
}
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
if (tools?.length) {
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'google')
const { tools: filteredTools, toolConfig } = preparedTools
if (filteredTools?.length) {
payload.tools = [
{
functionDeclarations: filteredTools,
},
]
if (toolConfig) {
payload.toolConfig = toolConfig
}
logger.info('Vertex AI request with tools:', {
toolCount: filteredTools.length,
model: requestedModel,
tools: filteredTools.map((t) => t.name),
hasToolConfig: !!toolConfig,
toolConfig: toolConfig,
})
}
}
const initialCallTime = Date.now()
const shouldStream = !!(request.stream && !tools?.length)
const endpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
shouldStream
)
if (request.stream && tools?.length) {
logger.info('Streaming disabled for initial request due to tools presence', {
toolCount: tools.length,
willStreamAfterTools: true,
})
}
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(payload),
})
if (!response.ok) {
const responseText = await response.text()
logger.error('Vertex AI API error details:', {
status: response.status,
statusText: response.statusText,
responseBody: responseText,
})
throw new Error(`Vertex AI API error: ${response.status} ${response.statusText}`)
}
const firstResponseTime = Date.now() - initialCallTime
if (shouldStream) {
logger.info('Handling Vertex AI streaming response')
const streamingResult: StreamingExecution = {
stream: null as any,
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: {
prompt: 0,
completion: 0,
total: 0,
},
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: firstResponseTime,
modelTime: firstResponseTime,
toolsTime: 0,
firstResponseTime,
iterations: 1,
timeSegments: [
{
type: 'model',
name: 'Initial streaming response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
],
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: firstResponseTime,
},
isStreaming: true,
},
}
streamingResult.stream = createReadableStreamFromVertexStream(
response,
(content, usage) => {
streamingResult.execution.output.content = content
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
streamEndTime
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
streamEndTime - providerStartTime
}
}
if (usage) {
streamingResult.execution.output.tokens = {
prompt: usage.promptTokenCount || 0,
completion: usage.candidatesTokenCount || 0,
total:
usage.totalTokenCount ||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0),
}
}
}
)
return streamingResult
}
let geminiResponse = await response.json()
if (payload.generationConfig?.responseSchema) {
const candidate = geminiResponse.candidates?.[0]
if (candidate?.content?.parts?.[0]?.text) {
const text = candidate.content.parts[0].text
try {
JSON.parse(text)
logger.info('Successfully received structured JSON output')
} catch (_e) {
logger.warn('Failed to parse structured output as JSON')
}
}
}
let content = ''
let tokens = {
prompt: 0,
completion: 0,
total: 0,
}
const toolCalls = []
const toolResults = []
let iterationCount = 0
const originalToolConfig = preparedTools?.toolConfig
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
let hasUsedForcedTool = false
let currentToolConfig = originalToolConfig
const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => {
if (currentToolConfig && forcedTools.length > 0) {
const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }]
const result = trackForcedToolUsage(
toolCallsForTracking,
currentToolConfig,
logger,
'google',
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
if (result.nextToolConfig) {
currentToolConfig = result.nextToolConfig
logger.info('Updated tool config for next iteration', {
hasNextToolConfig: !!currentToolConfig,
usedForcedTools: usedForcedTools,
})
}
}
}
let modelTime = firstResponseTime
let toolsTime = 0
const timeSegments: TimeSegment[] = [
{
type: 'model',
name: 'Initial response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
]
try {
const candidate = geminiResponse.candidates?.[0]
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn(
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
{
finishReason: candidate.finishReason,
hasContent: !!candidate?.content,
hasParts: !!candidate?.content?.parts,
}
)
content = extractTextContent(candidate)
}
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.info(`Received function call from Vertex AI: ${functionCall.name}`)
while (iterationCount < MAX_TOOL_ITERATIONS) {
const latestResponse = geminiResponse.candidates?.[0]
const latestFunctionCall = extractFunctionCall(latestResponse)
if (!latestFunctionCall) {
content = extractTextContent(latestResponse)
break
}
logger.info(
`Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
const toolsStartTime = Date.now()
try {
const toolName = latestFunctionCall.name
const toolArgs = latestFunctionCall.args || {}
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) {
logger.warn(`Tool ${toolName} not found in registry, skipping`)
break
}
const toolCallStartTime = Date.now()
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
timeSegments.push({
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
})
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
tool: toolName,
}
}
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
result: resultContent,
success: result.success,
})
const simplifiedMessages = [
...(contents.filter((m) => m.role === 'user').length > 0
? [contents.filter((m) => m.role === 'user')[0]]
: [contents[0]]),
{
role: 'model',
parts: [
{
functionCall: {
name: latestFunctionCall.name,
args: latestFunctionCall.args,
},
},
],
},
{
role: 'user',
parts: [
{
text: `Function ${latestFunctionCall.name} result: ${JSON.stringify(resultContent)}`,
},
],
},
]
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
checkForForcedToolUsage(latestFunctionCall)
const nextModelStartTime = Date.now()
try {
if (request.stream) {
const streamingPayload = {
...payload,
contents: simplifiedMessages,
}
const allForcedToolsUsed =
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
if (allForcedToolsUsed && request.responseFormat) {
streamingPayload.tools = undefined
streamingPayload.toolConfig = undefined
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!streamingPayload.generationConfig) {
streamingPayload.generationConfig = {}
}
streamingPayload.generationConfig.responseMimeType = 'application/json'
streamingPayload.generationConfig.responseSchema = cleanSchema
logger.info('Using structured output for final response after tool execution')
} else {
if (currentToolConfig) {
streamingPayload.toolConfig = currentToolConfig
} else {
streamingPayload.toolConfig = { functionCallingConfig: { mode: 'AUTO' } }
}
}
const checkPayload = {
...streamingPayload,
}
checkPayload.stream = undefined
const checkEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
false
)
const checkResponse = await fetch(checkEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(checkPayload),
})
if (!checkResponse.ok) {
const errorBody = await checkResponse.text()
logger.error('Error in Vertex AI check request:', {
status: checkResponse.status,
statusText: checkResponse.statusText,
responseBody: errorBody,
})
throw new Error(
`Vertex AI API check error: ${checkResponse.status} ${checkResponse.statusText}`
)
}
const checkResult = await checkResponse.json()
const checkCandidate = checkResult.candidates?.[0]
const checkFunctionCall = extractFunctionCall(checkCandidate)
if (checkFunctionCall) {
logger.info(
'Function call detected in follow-up, handling in non-streaming mode',
{
functionName: checkFunctionCall.name,
}
)
geminiResponse = checkResult
if (checkResult.usageMetadata) {
tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0
tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0
tokens.total +=
(checkResult.usageMetadata.promptTokenCount || 0) +
(checkResult.usageMetadata.candidatesTokenCount || 0)
}
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
modelTime += thisModelTime
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})
iterationCount++
continue
}
logger.info('No function call detected, proceeding with streaming response')
// Apply structured output for the final response if responseFormat is specified
// This works regardless of whether tools were forced or auto
if (request.responseFormat) {
streamingPayload.tools = undefined
streamingPayload.toolConfig = undefined
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!streamingPayload.generationConfig) {
streamingPayload.generationConfig = {}
}
streamingPayload.generationConfig.responseMimeType = 'application/json'
streamingPayload.generationConfig.responseSchema = cleanSchema
logger.info(
'Using structured output for final streaming response after tool execution'
)
}
const streamEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
true
)
const streamingResponse = await fetch(streamEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(streamingPayload),
})
if (!streamingResponse.ok) {
const errorBody = await streamingResponse.text()
logger.error('Error in Vertex AI streaming follow-up request:', {
status: streamingResponse.status,
statusText: streamingResponse.statusText,
responseBody: errorBody,
})
throw new Error(
`Vertex AI API streaming error: ${streamingResponse.status} ${streamingResponse.statusText}`
)
}
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
modelTime += thisModelTime
timeSegments.push({
type: 'model',
name: 'Final streaming response after tool calls',
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})
const streamingExecution: StreamingExecution = {
stream: null as any,
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens,
toolCalls:
toolCalls.length > 0
? {
list: toolCalls,
count: toolCalls.length,
}
: undefined,
toolResults,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
modelTime,
toolsTime,
firstResponseTime,
iterations: iterationCount + 1,
timeSegments,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
isStreaming: true,
},
}
streamingExecution.stream = createReadableStreamFromVertexStream(
streamingResponse,
(content, usage) => {
streamingExecution.execution.output.content = content
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
if (streamingExecution.execution.output.providerTiming) {
streamingExecution.execution.output.providerTiming.endTime =
streamEndTimeISO
streamingExecution.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
}
if (usage) {
const existingTokens = streamingExecution.execution.output.tokens || {
prompt: 0,
completion: 0,
total: 0,
}
streamingExecution.execution.output.tokens = {
prompt: (existingTokens.prompt || 0) + (usage.promptTokenCount || 0),
completion:
(existingTokens.completion || 0) + (usage.candidatesTokenCount || 0),
total:
(existingTokens.total || 0) +
(usage.totalTokenCount ||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0)),
}
}
}
)
return streamingExecution
}
const nextPayload = {
...payload,
contents: simplifiedMessages,
}
const allForcedToolsUsed =
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
if (allForcedToolsUsed && request.responseFormat) {
nextPayload.tools = undefined
nextPayload.toolConfig = undefined
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!nextPayload.generationConfig) {
nextPayload.generationConfig = {}
}
nextPayload.generationConfig.responseMimeType = 'application/json'
nextPayload.generationConfig.responseSchema = cleanSchema
logger.info(
'Using structured output for final non-streaming response after tool execution'
)
} else {
if (currentToolConfig) {
nextPayload.toolConfig = currentToolConfig
}
}
const nextEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
false
)
const nextResponse = await fetch(nextEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(nextPayload),
})
if (!nextResponse.ok) {
const errorBody = await nextResponse.text()
logger.error('Error in Vertex AI follow-up request:', {
status: nextResponse.status,
statusText: nextResponse.statusText,
responseBody: errorBody,
iterationCount,
})
break
}
geminiResponse = await nextResponse.json()
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})
modelTime += thisModelTime
const nextCandidate = geminiResponse.candidates?.[0]
const nextFunctionCall = extractFunctionCall(nextCandidate)
if (!nextFunctionCall) {
// If responseFormat is specified, make one final request with structured output
if (request.responseFormat) {
const finalPayload = {
...payload,
contents: nextPayload.contents,
tools: undefined,
toolConfig: undefined,
}
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!finalPayload.generationConfig) {
finalPayload.generationConfig = {}
}
finalPayload.generationConfig.responseMimeType = 'application/json'
finalPayload.generationConfig.responseSchema = cleanSchema
logger.info('Making final request with structured output after tool execution')
const finalEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
false
)
const finalResponse = await fetch(finalEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(finalPayload),
})
if (finalResponse.ok) {
const finalResult = await finalResponse.json()
const finalCandidate = finalResult.candidates?.[0]
content = extractTextContent(finalCandidate)
if (finalResult.usageMetadata) {
tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0
tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0
tokens.total +=
(finalResult.usageMetadata.promptTokenCount || 0) +
(finalResult.usageMetadata.candidatesTokenCount || 0)
}
} else {
logger.warn(
'Failed to get structured output, falling back to regular response'
)
content = extractTextContent(nextCandidate)
}
} else {
content = extractTextContent(nextCandidate)
}
break
}
iterationCount++
} catch (error) {
logger.error('Error in Vertex AI follow-up request:', {
error: error instanceof Error ? error.message : String(error),
iterationCount,
})
break
}
} catch (error) {
logger.error('Error processing function call:', {
error: error instanceof Error ? error.message : String(error),
functionName: latestFunctionCall?.name || 'unknown',
})
break
}
}
} else {
content = extractTextContent(candidate)
}
} catch (error) {
logger.error('Error processing Vertex AI response:', {
error: error instanceof Error ? error.message : String(error),
iterationCount,
})
if (!content && toolCalls.length > 0) {
content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.`
}
}
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
if (geminiResponse.usageMetadata) {
tokens = {
prompt: geminiResponse.usageMetadata.promptTokenCount || 0,
completion: geminiResponse.usageMetadata.candidatesTokenCount || 0,
total:
(geminiResponse.usageMetadata.promptTokenCount || 0) +
(geminiResponse.usageMetadata.candidatesTokenCount || 0),
}
}
return {
content,
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
}
} catch (error) {
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
logger.error('Error in Vertex AI request:', {
error: error instanceof Error ? error.message : String(error),
duration: totalDuration,
})
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
}
throw enhancedError
}
},
}

View File

@@ -0,0 +1,233 @@
import { createLogger } from '@/lib/logs/console/logger'
import { extractFunctionCall, extractTextContent } from '@/providers/google/utils'
const logger = createLogger('VertexUtils')
/**
* Creates a ReadableStream from Vertex AI's Gemini stream response
*/
export function createReadableStreamFromVertexStream(
response: Response,
onComplete?: (
content: string,
usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
) => void
): ReadableStream<Uint8Array> {
const reader = response.body?.getReader()
if (!reader) {
throw new Error('Failed to get reader from response body')
}
return new ReadableStream({
async start(controller) {
try {
let buffer = ''
let fullContent = ''
let usageData: {
promptTokenCount?: number
candidatesTokenCount?: number
totalTokenCount?: number
} | null = null
while (true) {
const { done, value } = await reader.read()
if (done) {
if (buffer.trim()) {
try {
const data = JSON.parse(buffer.trim())
if (data.usageMetadata) {
usageData = data.usageMetadata
}
const candidate = data.candidates?.[0]
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in final buffer, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
} catch (e) {
if (buffer.trim().startsWith('[')) {
try {
const dataArray = JSON.parse(buffer.trim())
if (Array.isArray(dataArray)) {
for (const item of dataArray) {
if (item.usageMetadata) {
usageData = item.usageMetadata
}
const candidate = item.candidates?.[0]
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in array item, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
}
}
} catch (arrayError) {
// Buffer is not valid JSON array
}
}
}
}
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
break
}
const text = new TextDecoder().decode(value)
buffer += text
let searchIndex = 0
while (searchIndex < buffer.length) {
const openBrace = buffer.indexOf('{', searchIndex)
if (openBrace === -1) break
let braceCount = 0
let inString = false
let escaped = false
let closeBrace = -1
for (let i = openBrace; i < buffer.length; i++) {
const char = buffer[i]
if (!inString) {
if (char === '"' && !escaped) {
inString = true
} else if (char === '{') {
braceCount++
} else if (char === '}') {
braceCount--
if (braceCount === 0) {
closeBrace = i
break
}
}
} else {
if (char === '"' && !escaped) {
inString = false
}
}
escaped = char === '\\' && !escaped
}
if (closeBrace !== -1) {
const jsonStr = buffer.substring(openBrace, closeBrace + 1)
try {
const data = JSON.parse(jsonStr)
if (data.usageMetadata) {
usageData = data.usageMetadata
}
const candidate = data.candidates?.[0]
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn(
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
{
finishReason: candidate.finishReason,
hasContent: !!candidate?.content,
hasParts: !!candidate?.content?.parts,
}
)
const textContent = extractTextContent(candidate)
if (textContent) {
fullContent += textContent
controller.enqueue(new TextEncoder().encode(textContent))
}
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in stream, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
} catch (e) {
logger.error('Error parsing JSON from stream', {
error: e instanceof Error ? e.message : String(e),
jsonPreview: jsonStr.substring(0, 200),
})
}
buffer = buffer.substring(closeBrace + 1)
searchIndex = 0
} else {
break
}
}
}
} catch (e) {
logger.error('Error reading Vertex AI stream', {
error: e instanceof Error ? e.message : String(e),
})
controller.error(e)
}
},
async cancel() {
await reader.cancel()
},
})
}
/**
* Build Vertex AI endpoint URL
*/
export function buildVertexEndpoint(
project: string,
location: string,
model: string,
isStreaming: boolean
): string {
const action = isStreaming ? 'streamGenerateContent' : 'generateContent'
if (location === 'global') {
return `https://aiplatform.googleapis.com/v1/projects/${project}/locations/global/publishers/google/models/${model}:${action}`
}
return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/google/models/${model}:${action}`
}

View File

@@ -2,6 +2,7 @@ import OpenAI from 'openai'
import { env } from '@/lib/core/config/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
@@ -14,50 +15,13 @@ import {
prepareToolsWithUsageControl,
trackForcedToolUsage,
} from '@/providers/utils'
import { createReadableStreamFromVLLMStream } from '@/providers/vllm/utils'
import { useProvidersStore } from '@/stores/providers/store'
import { executeTool } from '@/tools'
const logger = createLogger('VLLMProvider')
const VLLM_VERSION = '1.0.0'
/**
* Helper function to convert a vLLM stream to a standard ReadableStream
* and collect completion metrics
*/
function createReadableStreamFromVLLMStream(
vllmStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of vllmStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
export const vllmProvider: ProviderConfig = {
id: 'vllm',
name: 'vLLM',
@@ -341,7 +305,6 @@ export const vllmProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10
let modelTime = firstResponseTime
let toolsTime = 0
@@ -360,14 +323,14 @@ export const vllmProvider: ProviderConfig = {
checkForForcedToolUsage(currentResponse, originalToolChoice)
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
logger.info(
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
const toolsStartTime = Date.now()

View File

@@ -0,0 +1,37 @@
/**
* Helper function to convert a vLLM stream to a standard ReadableStream
* and collect completion metrics
*/
export function createReadableStreamFromVLLMStream(
vllmStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of vllmStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}

View File

@@ -1,6 +1,7 @@
import OpenAI from 'openai'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
@@ -8,37 +9,16 @@ import type {
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
import {
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
} from '@/providers/utils'
checkForForcedToolUsage,
createReadableStreamFromXAIStream,
createResponseFormatPayload,
} from '@/providers/xai/utils'
import { executeTool } from '@/tools'
const logger = createLogger('XAIProvider')
/**
* Helper to wrap XAI (OpenAI-compatible) streaming into a browser-friendly
* ReadableStream of raw assistant text chunks.
*/
function createReadableStreamFromXAIStream(xaiStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of xaiStream) {
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
}
export const xAIProvider: ProviderConfig = {
id: 'xai',
name: 'xAI',
@@ -115,27 +95,6 @@ export const xAIProvider: ProviderConfig = {
if (request.temperature !== undefined) basePayload.temperature = request.temperature
if (request.maxTokens !== undefined) basePayload.max_tokens = request.maxTokens
// Function to create response format configuration
const createResponseFormatPayload = (messages: any[] = allMessages) => {
const payload = {
...basePayload,
messages,
}
if (request.responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: request.responseFormat.name || 'structured_response',
schema: request.responseFormat.schema || request.responseFormat,
strict: request.responseFormat.strict !== false,
},
}
}
return payload
}
// Handle tools and tool usage control
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
@@ -154,7 +113,7 @@ export const xAIProvider: ProviderConfig = {
// Use response format payload if needed, otherwise use base payload
const streamingPayload = request.responseFormat
? createResponseFormatPayload()
? createResponseFormatPayload(basePayload, allMessages, request.responseFormat)
: { ...basePayload, stream: true }
if (!request.responseFormat) {
@@ -243,7 +202,11 @@ export const xAIProvider: ProviderConfig = {
originalToolChoice = toolChoice
} else if (request.responseFormat) {
// Only add response format if there are no tools
const responseFormatPayload = createResponseFormatPayload()
const responseFormatPayload = createResponseFormatPayload(
basePayload,
allMessages,
request.responseFormat
)
Object.assign(initialPayload, responseFormatPayload)
}
@@ -260,7 +223,6 @@ export const xAIProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10
// Track if a forced tool has been used
let hasUsedForcedTool = false
@@ -280,33 +242,20 @@ export const xAIProvider: ProviderConfig = {
},
]
// Helper function to check for forced tool usage in responses
const checkForForcedToolUsage = (
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }
) => {
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'xai',
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
}
}
// Check if a forced tool was used in the first response
if (originalToolChoice) {
checkForForcedToolUsage(currentResponse, originalToolChoice)
const result = checkForForcedToolUsage(
currentResponse,
originalToolChoice,
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
}
try {
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
@@ -432,7 +381,12 @@ export const xAIProvider: ProviderConfig = {
} else {
// All forced tools have been used, check if we need response format for final response
if (request.responseFormat) {
nextPayload = createResponseFormatPayload(currentMessages)
nextPayload = createResponseFormatPayload(
basePayload,
allMessages,
request.responseFormat,
currentMessages
)
} else {
nextPayload = {
...basePayload,
@@ -446,7 +400,12 @@ export const xAIProvider: ProviderConfig = {
// Normal tool processing - check if this might be the final response
if (request.responseFormat) {
// Use response format for what might be the final response
nextPayload = createResponseFormatPayload(currentMessages)
nextPayload = createResponseFormatPayload(
basePayload,
allMessages,
request.responseFormat,
currentMessages
)
} else {
nextPayload = {
...basePayload,
@@ -464,7 +423,14 @@ export const xAIProvider: ProviderConfig = {
// Check if any forced tools were used in this response
if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') {
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
const result = checkForForcedToolUsage(
currentResponse,
nextPayload.tool_choice,
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
}
const nextModelEndTime = Date.now()
@@ -509,7 +475,12 @@ export const xAIProvider: ProviderConfig = {
if (request.responseFormat) {
// Use response format, no tools
finalStreamingPayload = {
...createResponseFormatPayload(currentMessages),
...createResponseFormatPayload(
basePayload,
allMessages,
request.responseFormat,
currentMessages
),
stream: true,
}
} else {

View File

@@ -0,0 +1,83 @@
import { createLogger } from '@/lib/logs/console/logger'
import { trackForcedToolUsage } from '@/providers/utils'
const logger = createLogger('XAIProvider')
/**
* Helper to wrap XAI (OpenAI-compatible) streaming into a browser-friendly
* ReadableStream of raw assistant text chunks.
*/
export function createReadableStreamFromXAIStream(xaiStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of xaiStream) {
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
}
/**
* Creates a response format payload for XAI API requests.
*/
export function createResponseFormatPayload(
basePayload: any,
allMessages: any[],
responseFormat: any,
currentMessages?: any[]
) {
const payload = {
...basePayload,
messages: currentMessages || allMessages,
}
if (responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: responseFormat.name || 'structured_response',
schema: responseFormat.schema || responseFormat,
strict: responseFormat.strict !== false,
},
}
}
return payload
}
/**
* Helper function to check for forced tool usage in responses.
*/
export function checkForForcedToolUsage(
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
forcedTools: string[],
usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
let hasUsedForcedTool = false
let updatedUsedForcedTools = usedForcedTools
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'xai',
forcedTools,
updatedUsedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
updatedUsedForcedTools = result.usedForcedTools
}
return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
}

View File

@@ -13,6 +13,8 @@ interface LLMChatParams {
maxTokens?: number
azureEndpoint?: string
azureApiVersion?: string
vertexProject?: string
vertexLocation?: string
}
interface LLMChatResponse extends ToolResponse {
@@ -77,6 +79,18 @@ export const llmChatTool: ToolConfig<LLMChatParams, LLMChatResponse> = {
visibility: 'hidden',
description: 'Azure OpenAI API version',
},
vertexProject: {
type: 'string',
required: false,
visibility: 'hidden',
description: 'Google Cloud project ID for Vertex AI',
},
vertexLocation: {
type: 'string',
required: false,
visibility: 'hidden',
description: 'Google Cloud location for Vertex AI (defaults to us-central1)',
},
},
request: {
@@ -98,6 +112,8 @@ export const llmChatTool: ToolConfig<LLMChatParams, LLMChatResponse> = {
maxTokens: params.maxTokens,
azureEndpoint: params.azureEndpoint,
azureApiVersion: params.azureApiVersion,
vertexProject: params.vertexProject,
vertexLocation: params.vertexLocation,
}
},
},

View File

@@ -1,5 +1,6 @@
{
"lockfileVersion": 1,
"configVersion": 0,
"workspaces": {
"": {
"name": "simstudio",
@@ -266,12 +267,12 @@
"sharp",
],
"overrides": {
"react": "19.2.1",
"react-dom": "19.2.1",
"next": "16.1.0-canary.21",
"@next/env": "16.1.0-canary.21",
"drizzle-orm": "^0.44.5",
"next": "16.1.0-canary.21",
"postgres": "^3.4.5",
"react": "19.2.1",
"react-dom": "19.2.1",
},
"packages": {
"@adobe/css-tools": ["@adobe/css-tools@4.4.4", "", {}, "sha512-Elp+iwUx5rN5+Y8xLt5/GRoG20WGoDCQ/1Fb+1LiGtvwbDavuSk0jhD/eZdckHAuzcDzccnkv+rEjyWfRx18gg=="],