Mirror of https://github.com/simstudioai/sim.git (synced 2026-01-08 14:43:54 -05:00)
feat(vertex): added vertex to list of supported providers (#2430)
* feat(vertex): added vertex to list of supported providers
* added utils files for each provider, consolidated gemini utils, added dynamic verbosity and reasoning fetcher
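In brief, the change teaches the copilot and provider routes to recognize 'vertex' as a provider and to thread its Google Cloud project and location through the request pipeline. A minimal sketch of the selection logic, condensed from the route hunks below (the surrounding route code and any names other than those visible in the diff are assumptions for illustration):

// Sketch only: condensed from the copilot route diff below, not the exact source.
let providerConfig
if (providerEnv === 'vertex') {
  providerConfig = {
    provider: 'vertex',
    model: modelToUse,
    apiKey: env.COPILOT_API_KEY,
    vertexProject: env.VERTEX_PROJECT, // GCP project ID (optional env var added by this commit)
    vertexLocation: env.VERTEX_LOCATION, // GCP region, e.g. 'us-central1'
  }
} else {
  providerConfig = { provider: providerEnv, model: modelToUse, apiKey: env.COPILOT_API_KEY }
}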
@@ -4,7 +4,7 @@
"private": true,
"license": "Apache-2.0",
"scripts": {
"dev": "next dev --port 3001",
"dev": "next dev --port 7322",
"build": "fumadocs-mdx && NODE_OPTIONS='--max-old-space-size=8192' next build",
"start": "next start",
"postinstall": "fumadocs-mdx",

@@ -303,6 +303,14 @@ export async function POST(req: NextRequest) {
apiVersion: 'preview',
endpoint: env.AZURE_OPENAI_ENDPOINT,
}
} else if (providerEnv === 'vertex') {
providerConfig = {
provider: 'vertex',
model: modelToUse,
apiKey: env.COPILOT_API_KEY,
vertexProject: env.VERTEX_PROJECT,
vertexLocation: env.VERTEX_LOCATION,
}
} else {
providerConfig = {
provider: providerEnv,

@@ -66,6 +66,14 @@ export async function POST(req: NextRequest) {
apiVersion: env.AZURE_OPENAI_API_VERSION,
endpoint: env.AZURE_OPENAI_ENDPOINT,
}
} else if (providerEnv === 'vertex') {
providerConfig = {
provider: 'vertex',
model: modelToUse,
apiKey: env.COPILOT_API_KEY,
vertexProject: env.VERTEX_PROJECT,
vertexLocation: env.VERTEX_LOCATION,
}
} else {
providerConfig = {
provider: providerEnv,

@@ -35,6 +35,8 @@ export async function POST(request: NextRequest) {
apiKey,
azureEndpoint,
azureApiVersion,
vertexProject,
vertexLocation,
responseFormat,
workflowId,
workspaceId,
@@ -58,6 +60,8 @@ export async function POST(request: NextRequest) {
hasApiKey: !!apiKey,
hasAzureEndpoint: !!azureEndpoint,
hasAzureApiVersion: !!azureApiVersion,
hasVertexProject: !!vertexProject,
hasVertexLocation: !!vertexLocation,
hasResponseFormat: !!responseFormat,
workflowId,
stream: !!stream,
@@ -104,6 +108,8 @@ export async function POST(request: NextRequest) {
apiKey: finalApiKey,
azureEndpoint,
azureApiVersion,
vertexProject,
vertexLocation,
responseFormat,
workflowId,
workspaceId,

@@ -8,6 +8,8 @@ import {
getHostedModels,
getMaxTemperature,
getProviderIcon,
getReasoningEffortValuesForModel,
getVerbosityValuesForModel,
MODELS_WITH_REASONING_EFFORT,
MODELS_WITH_VERBOSITY,
providers,
@@ -114,12 +116,47 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
type: 'dropdown',
placeholder: 'Select reasoning effort...',
options: [
{ label: 'none', id: 'none' },
{ label: 'minimal', id: 'minimal' },
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
],
dependsOn: ['model'],
fetchOptions: async (blockId: string) => {
const { useSubBlockStore } = await import('@/stores/workflows/subblock/store')
const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store')

const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId
if (!activeWorkflowId) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}

const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId]
const blockValues = workflowValues?.[blockId]
const modelValue = blockValues?.model as string

if (!modelValue) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}

const validOptions = getReasoningEffortValuesForModel(modelValue)
if (!validOptions) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}

return validOptions.map((opt) => ({ label: opt, id: opt }))
},
value: () => 'medium',
condition: {
field: 'model',
@@ -136,6 +173,43 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
],
dependsOn: ['model'],
fetchOptions: async (blockId: string) => {
const { useSubBlockStore } = await import('@/stores/workflows/subblock/store')
const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store')

const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId
if (!activeWorkflowId) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}

const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId]
const blockValues = workflowValues?.[blockId]
const modelValue = blockValues?.model as string

if (!modelValue) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}

const validOptions = getVerbosityValuesForModel(modelValue)
if (!validOptions) {
return [
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
]
}

return validOptions.map((opt) => ({ label: opt, id: opt }))
},
value: () => 'medium',
condition: {
field: 'model',
@@ -166,6 +240,28 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
value: providers['azure-openai'].models,
},
},
{
id: 'vertexProject',
title: 'Vertex AI Project',
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'vertexLocation',
title: 'Vertex AI Location',
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'tools',
title: 'Tools',
@@ -465,6 +561,8 @@ Example 3 (Array Input):
apiKey: { type: 'string', description: 'Provider API key' },
azureEndpoint: { type: 'string', description: 'Azure OpenAI endpoint URL' },
azureApiVersion: { type: 'string', description: 'Azure API version' },
vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' },
vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' },
responseFormat: {
type: 'json',
description: 'JSON response format schema',

@@ -239,6 +239,28 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
value: providers['azure-openai'].models,
},
},
{
id: 'vertexProject',
title: 'Vertex AI Project',
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'vertexLocation',
title: 'Vertex AI Location',
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'temperature',
title: 'Temperature',
@@ -356,6 +378,14 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
apiKey: { type: 'string' as ParamType, description: 'Provider API key' },
azureEndpoint: { type: 'string' as ParamType, description: 'Azure OpenAI endpoint URL' },
azureApiVersion: { type: 'string' as ParamType, description: 'Azure API version' },
vertexProject: {
type: 'string' as ParamType,
description: 'Google Cloud project ID for Vertex AI',
},
vertexLocation: {
type: 'string' as ParamType,
description: 'Google Cloud location for Vertex AI',
},
temperature: {
type: 'number' as ParamType,
description: 'Response randomness level (low for consistent evaluation)',

@@ -188,6 +188,28 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
value: providers['azure-openai'].models,
},
},
{
id: 'vertexProject',
title: 'Vertex AI Project',
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'vertexLocation',
title: 'Vertex AI Location',
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'temperature',
title: 'Temperature',
@@ -235,6 +257,8 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
apiKey: { type: 'string', description: 'Provider API key' },
azureEndpoint: { type: 'string', description: 'Azure OpenAI endpoint URL' },
azureApiVersion: { type: 'string', description: 'Azure API version' },
vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' },
vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' },
temperature: {
type: 'number',
description: 'Response randomness level (low for consistent routing)',

@@ -99,6 +99,28 @@ export const TranslateBlock: BlockConfig = {
value: providers['azure-openai'].models,
},
},
{
id: 'vertexProject',
title: 'Vertex AI Project',
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'vertexLocation',
title: 'Vertex AI Location',
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'systemPrompt',
title: 'System Prompt',
@@ -120,6 +142,8 @@ export const TranslateBlock: BlockConfig = {
apiKey: params.apiKey,
azureEndpoint: params.azureEndpoint,
azureApiVersion: params.azureApiVersion,
vertexProject: params.vertexProject,
vertexLocation: params.vertexLocation,
}),
},
},
@@ -129,6 +153,8 @@ export const TranslateBlock: BlockConfig = {
apiKey: { type: 'string', description: 'Provider API key' },
azureEndpoint: { type: 'string', description: 'Azure OpenAI endpoint URL' },
azureApiVersion: { type: 'string', description: 'Azure API version' },
vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' },
vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' },
systemPrompt: { type: 'string', description: 'Translation instructions' },
},
outputs: {

@@ -2452,6 +2452,56 @@ export const GeminiIcon = (props: SVGProps<SVGSVGElement>) => (
</svg>
)

export const VertexIcon = (props: SVGProps<SVGSVGElement>) => (
<svg
{...props}
id='standard_product_icon'
xmlns='http://www.w3.org/2000/svg'
version='1.1'
viewBox='0 0 512 512'
>
<g id='bounding_box'>
<rect width='512' height='512' fill='none' />
</g>
<g id='art'>
<path
d='M128,244.99c-8.84,0-16-7.16-16-16v-95.97c0-8.84,7.16-16,16-16s16,7.16,16,16v95.97c0,8.84-7.16,16-16,16Z'
fill='#ea4335'
/>
<path
d='M256,458c-2.98,0-5.97-.83-8.59-2.5l-186-122c-7.46-4.74-9.65-14.63-4.91-22.09,4.75-7.46,14.64-9.65,22.09-4.91l177.41,116.53,177.41-116.53c7.45-4.74,17.34-2.55,22.09,4.91,4.74,7.46,2.55,17.34-4.91,22.09l-186,122c-2.62,1.67-5.61,2.5-8.59,2.5Z'
fill='#fbbc04'
/>
<path
d='M256,388.03c-8.84,0-16-7.16-16-16v-73.06c0-8.84,7.16-16,16-16s16,7.16,16,16v73.06c0,8.84-7.16,16-16,16Z'
fill='#34a853'
/>
<circle cx='128' cy='70' r='16' fill='#ea4335' />
<circle cx='128' cy='292' r='16' fill='#ea4335' />
<path
d='M384.23,308.01c-8.82,0-15.98-7.14-16-15.97l-.23-94.01c-.02-8.84,7.13-16.02,15.97-16.03h.04c8.82,0,15.98,7.14,16,15.97l.23,94.01c.02,8.84-7.13,16.02-15.97,16.03h-.04Z'
fill='#4285f4'
/>
<circle cx='384' cy='70' r='16' fill='#4285f4' />
<circle cx='384' cy='134' r='16' fill='#4285f4' />
<path
d='M320,220.36c-8.84,0-16-7.16-16-16v-103.02c0-8.84,7.16-16,16-16s16,7.16,16,16v103.02c0,8.84-7.16,16-16,16Z'
fill='#fbbc04'
/>
<circle cx='256' cy='171' r='16' fill='#34a853' />
<circle cx='256' cy='235' r='16' fill='#34a853' />
<circle cx='320' cy='265' r='16' fill='#fbbc04' />
<circle cx='320' cy='329' r='16' fill='#fbbc04' />
<path
d='M192,217.36c-8.84,0-16-7.16-16-16v-100.02c0-8.84,7.16-16,16-16s16,7.16,16,16v100.02c0,8.84-7.16,16-16,16Z'
fill='#fbbc04'
/>
<circle cx='192' cy='265' r='16' fill='#fbbc04' />
<circle cx='192' cy='329' r='16' fill='#fbbc04' />
</g>
</svg>
)

export const CerebrasIcon = (props: SVGProps<SVGSVGElement>) => (
<svg
{...props}

@@ -493,7 +493,7 @@ export class AgentBlockHandler implements BlockHandler {
const discoveredTools = await this.discoverMcpToolsForServer(ctx, serverId)
return { serverId, tools, discoveredTools, error: null as Error | null }
} catch (error) {
logger.error(`Failed to discover tools from server ${serverId}:`, error)
logger.error(`Failed to discover tools from server ${serverId}:`)
return { serverId, tools, discoveredTools: [] as any[], error: error as Error }
}
})
@@ -883,6 +883,8 @@ export class AgentBlockHandler implements BlockHandler {
apiKey: inputs.apiKey,
azureEndpoint: inputs.azureEndpoint,
azureApiVersion: inputs.azureApiVersion,
vertexProject: inputs.vertexProject,
vertexLocation: inputs.vertexLocation,
responseFormat,
workflowId: ctx.workflowId,
workspaceId: ctx.workspaceId,
@@ -975,6 +977,8 @@ export class AgentBlockHandler implements BlockHandler {
apiKey: finalApiKey,
azureEndpoint: providerRequest.azureEndpoint,
azureApiVersion: providerRequest.azureApiVersion,
vertexProject: providerRequest.vertexProject,
vertexLocation: providerRequest.vertexLocation,
responseFormat: providerRequest.responseFormat,
workflowId: providerRequest.workflowId,
workspaceId: providerRequest.workspaceId,

@@ -19,6 +19,8 @@ export interface AgentInputs {
apiKey?: string
azureEndpoint?: string
azureApiVersion?: string
vertexProject?: string
vertexLocation?: string
reasoningEffort?: string
verbosity?: string
}

@@ -148,7 +148,14 @@ export type CopilotProviderConfig =
endpoint?: string
}
| {
provider: Exclude<ProviderId, 'azure-openai'>
provider: 'vertex'
model: string
apiKey?: string
vertexProject?: string
vertexLocation?: string
}
| {
provider: Exclude<ProviderId, 'azure-openai' | 'vertex'>
model?: string
apiKey?: string
}

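For illustration, a value satisfying the new 'vertex' member of this union might look like the following; the project and location strings are placeholders, and the model id is an assumption rather than something this hunk specifies.

// Hypothetical example value for the new union member; field names match the type above.
const vertexCopilotConfig: CopilotProviderConfig = {
  provider: 'vertex',
  model: 'gemini-2.5-pro', // assumed Vertex-hosted model id
  apiKey: process.env.COPILOT_API_KEY,
  vertexProject: 'my-gcp-project', // placeholder
  vertexLocation: 'us-central1', // matches the documented default region
}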
@@ -98,6 +98,10 @@ export const env = createEnv({
OCR_AZURE_MODEL_NAME: z.string().optional(), // Azure Mistral OCR model name for document processing
OCR_AZURE_API_KEY: z.string().min(1).optional(), // Azure Mistral OCR API key

// Vertex AI Configuration
VERTEX_PROJECT: z.string().optional(), // Google Cloud project ID for Vertex AI
VERTEX_LOCATION: z.string().optional(), // Google Cloud location/region for Vertex AI (defaults to us-central1)

// Monitoring & Analytics
TELEMETRY_ENDPOINT: z.string().url().optional(), // Custom telemetry/analytics endpoint
COST_MULTIPLIER: z.number().optional(), // Multiplier for cost calculations

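Both variables are optional, so downstream code is expected to fall back to a sensible region when VERTEX_LOCATION is unset. A plausible resolution helper, sketched here as an assumption (the actual fallback is not shown in this diff):

// Assumption: 'us-central1' is the fallback, per the comment on VERTEX_LOCATION above.
function resolveVertexLocation(requested?: string): string {
  return requested ?? env.VERTEX_LOCATION ?? 'us-central1'
}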
@@ -404,15 +404,11 @@ class McpService {
failedCount++
const errorMessage =
result.reason instanceof Error ? result.reason.message : 'Unknown error'
logger.warn(
`[${requestId}] Failed to discover tools from server ${server.name}:`,
result.reason
)
logger.warn(`[${requestId}] Failed to discover tools from server ${server.name}:`)
statusUpdates.push(this.updateServerStatus(server.id!, workspaceId, false, errorMessage))
}
})

// Update server statuses in parallel (don't block on this)
Promise.allSettled(statusUpdates).catch((err) => {
logger.error(`[${requestId}] Error updating server statuses:`, err)
})

@@ -8,7 +8,7 @@
"node": ">=20.0.0"
},
"scripts": {
"dev": "next dev --port 3000",
"dev": "next dev --port 7321",
"dev:webpack": "next dev --webpack",
"dev:sockets": "bun run socket-server/index.ts",
"dev:full": "concurrently -n \"App,Realtime\" -c \"cyan,magenta\" \"bun run dev\" \"bun run dev:sockets\"",

@@ -1,35 +1,24 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import {
|
||||
checkForForcedToolUsage,
|
||||
createReadableStreamFromAnthropicStream,
|
||||
generateToolUseId,
|
||||
} from '@/providers/anthropic/utils'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
ProviderRequest,
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
import { getProviderDefaultModel, getProviderModels } from '../models'
|
||||
import type { ProviderConfig, ProviderRequest, ProviderResponse, TimeSegment } from '../types'
|
||||
import { prepareToolExecution, prepareToolsWithUsageControl, trackForcedToolUsage } from '../utils'
|
||||
|
||||
const logger = createLogger('AnthropicProvider')
|
||||
|
||||
/**
|
||||
* Helper to wrap Anthropic streaming into a browser-friendly ReadableStream
|
||||
*/
|
||||
function createReadableStreamFromAnthropicStream(
|
||||
anthropicStream: AsyncIterable<any>
|
||||
): ReadableStream {
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === 'content_block_delta' && event.delta?.text) {
|
||||
controller.enqueue(new TextEncoder().encode(event.delta.text))
|
||||
}
|
||||
}
|
||||
controller.close()
|
||||
} catch (err) {
|
||||
controller.error(err)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
export const anthropicProvider: ProviderConfig = {
|
||||
id: 'anthropic',
|
||||
name: 'Anthropic',
|
||||
@@ -47,11 +36,6 @@ export const anthropicProvider: ProviderConfig = {
|
||||
|
||||
const anthropic = new Anthropic({ apiKey: request.apiKey })
|
||||
|
||||
// Helper function to generate a simple unique ID for tool uses
|
||||
const generateToolUseId = (toolName: string) => {
|
||||
return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}`
|
||||
}
|
||||
|
||||
// Transform messages to Anthropic format
|
||||
const messages: any[] = []
|
||||
|
||||
@@ -373,7 +357,6 @@ ${fieldDescriptions}
|
||||
const toolResults = []
|
||||
const currentMessages = [...messages]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10 // Prevent infinite loops
|
||||
|
||||
// Track if a forced tool has been used
|
||||
let hasUsedForcedTool = false
|
||||
@@ -393,47 +376,20 @@ ${fieldDescriptions}
|
||||
},
|
||||
]
|
||||
|
||||
// Helper function to check for forced tool usage in Anthropic responses
|
||||
const checkForForcedToolUsage = (response: any, toolChoice: any) => {
|
||||
if (
|
||||
typeof toolChoice === 'object' &&
|
||||
toolChoice !== null &&
|
||||
Array.isArray(response.content)
|
||||
) {
|
||||
const toolUses = response.content.filter((item: any) => item.type === 'tool_use')
|
||||
|
||||
if (toolUses.length > 0) {
|
||||
// Convert Anthropic tool_use format to a format trackForcedToolUsage can understand
|
||||
const adaptedToolCalls = toolUses.map((tool: any) => ({
|
||||
name: tool.name,
|
||||
}))
|
||||
|
||||
// Convert Anthropic tool_choice format to match OpenAI format for tracking
|
||||
const adaptedToolChoice =
|
||||
toolChoice.type === 'tool' ? { function: { name: toolChoice.name } } : toolChoice
|
||||
|
||||
const result = trackForcedToolUsage(
|
||||
adaptedToolCalls,
|
||||
adaptedToolChoice,
|
||||
logger,
|
||||
'anthropic',
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
// Make the behavior consistent with the initial check
|
||||
hasUsedForcedTool = result.hasUsedForcedTool
|
||||
usedForcedTools = result.usedForcedTools
|
||||
return result
|
||||
}
|
||||
}
|
||||
return null
|
||||
// Check if a forced tool was used in the first response
|
||||
const firstCheckResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
originalToolChoice,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
if (firstCheckResult) {
|
||||
hasUsedForcedTool = firstCheckResult.hasUsedForcedTool
|
||||
usedForcedTools = firstCheckResult.usedForcedTools
|
||||
}
|
||||
|
||||
// Check if a forced tool was used in the first response
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
|
||||
try {
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
const toolUses = currentResponse.content.filter((item) => item.type === 'tool_use')
|
||||
if (!toolUses || toolUses.length === 0) {
|
||||
@@ -576,7 +532,16 @@ ${fieldDescriptions}
|
||||
currentResponse = await anthropic.messages.create(nextPayload)
|
||||
|
||||
// Check if any forced tools were used in this response
|
||||
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
|
||||
const nextCheckResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
nextPayload.tool_choice,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
if (nextCheckResult) {
|
||||
hasUsedForcedTool = nextCheckResult.hasUsedForcedTool
|
||||
usedForcedTools = nextCheckResult.usedForcedTools
|
||||
}
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
@@ -727,7 +692,6 @@ ${fieldDescriptions}
|
||||
const toolResults = []
|
||||
const currentMessages = [...messages]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10 // Prevent infinite loops
|
||||
|
||||
// Track if a forced tool has been used
|
||||
let hasUsedForcedTool = false
|
||||
@@ -747,47 +711,20 @@ ${fieldDescriptions}
|
||||
},
|
||||
]
|
||||
|
||||
// Helper function to check for forced tool usage in Anthropic responses
|
||||
const checkForForcedToolUsage = (response: any, toolChoice: any) => {
|
||||
if (
|
||||
typeof toolChoice === 'object' &&
|
||||
toolChoice !== null &&
|
||||
Array.isArray(response.content)
|
||||
) {
|
||||
const toolUses = response.content.filter((item: any) => item.type === 'tool_use')
|
||||
|
||||
if (toolUses.length > 0) {
|
||||
// Convert Anthropic tool_use format to a format trackForcedToolUsage can understand
|
||||
const adaptedToolCalls = toolUses.map((tool: any) => ({
|
||||
name: tool.name,
|
||||
}))
|
||||
|
||||
// Convert Anthropic tool_choice format to match OpenAI format for tracking
|
||||
const adaptedToolChoice =
|
||||
toolChoice.type === 'tool' ? { function: { name: toolChoice.name } } : toolChoice
|
||||
|
||||
const result = trackForcedToolUsage(
|
||||
adaptedToolCalls,
|
||||
adaptedToolChoice,
|
||||
logger,
|
||||
'anthropic',
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
// Make the behavior consistent with the initial check
|
||||
hasUsedForcedTool = result.hasUsedForcedTool
|
||||
usedForcedTools = result.usedForcedTools
|
||||
return result
|
||||
}
|
||||
}
|
||||
return null
|
||||
// Check if a forced tool was used in the first response
|
||||
const firstCheckResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
originalToolChoice,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
if (firstCheckResult) {
|
||||
hasUsedForcedTool = firstCheckResult.hasUsedForcedTool
|
||||
usedForcedTools = firstCheckResult.usedForcedTools
|
||||
}
|
||||
|
||||
// Check if a forced tool was used in the first response
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
|
||||
try {
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
const toolUses = currentResponse.content.filter((item) => item.type === 'tool_use')
|
||||
if (!toolUses || toolUses.length === 0) {
|
||||
@@ -926,7 +863,16 @@ ${fieldDescriptions}
|
||||
currentResponse = await anthropic.messages.create(nextPayload)
|
||||
|
||||
// Check if any forced tools were used in this response
|
||||
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
|
||||
const nextCheckResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
nextPayload.tool_choice,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
if (nextCheckResult) {
|
||||
hasUsedForcedTool = nextCheckResult.hasUsedForcedTool
|
||||
usedForcedTools = nextCheckResult.usedForcedTools
|
||||
}
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
apps/sim/providers/anthropic/utils.ts (new file, 70 lines)
@@ -0,0 +1,70 @@
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { trackForcedToolUsage } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('AnthropicUtils')
|
||||
|
||||
/**
|
||||
* Helper to wrap Anthropic streaming into a browser-friendly ReadableStream
|
||||
*/
|
||||
export function createReadableStreamFromAnthropicStream(
|
||||
anthropicStream: AsyncIterable<any>
|
||||
): ReadableStream {
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === 'content_block_delta' && event.delta?.text) {
|
||||
controller.enqueue(new TextEncoder().encode(event.delta.text))
|
||||
}
|
||||
}
|
||||
controller.close()
|
||||
} catch (err) {
|
||||
controller.error(err)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to generate a simple unique ID for tool uses
|
||||
*/
|
||||
export function generateToolUseId(toolName: string): string {
|
||||
return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to check for forced tool usage in Anthropic responses
|
||||
*/
|
||||
export function checkForForcedToolUsage(
|
||||
response: any,
|
||||
toolChoice: any,
|
||||
forcedTools: string[],
|
||||
usedForcedTools: string[]
|
||||
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } | null {
|
||||
if (typeof toolChoice === 'object' && toolChoice !== null && Array.isArray(response.content)) {
|
||||
const toolUses = response.content.filter((item: any) => item.type === 'tool_use')
|
||||
|
||||
if (toolUses.length > 0) {
|
||||
// Convert Anthropic tool_use format to a format trackForcedToolUsage can understand
|
||||
const adaptedToolCalls = toolUses.map((tool: any) => ({
|
||||
name: tool.name,
|
||||
}))
|
||||
|
||||
// Convert Anthropic tool_choice format to match OpenAI format for tracking
|
||||
const adaptedToolChoice =
|
||||
toolChoice.type === 'tool' ? { function: { name: toolChoice.name } } : toolChoice
|
||||
|
||||
const result = trackForcedToolUsage(
|
||||
adaptedToolCalls,
|
||||
adaptedToolChoice,
|
||||
logger,
|
||||
'anthropic',
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
@@ -2,6 +2,11 @@ import { AzureOpenAI } from 'openai'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import {
|
||||
checkForForcedToolUsage,
|
||||
createReadableStreamFromAzureOpenAIStream,
|
||||
} from '@/providers/azure-openai/utils'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
@@ -9,55 +14,11 @@ import type {
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('AzureOpenAIProvider')
|
||||
|
||||
/**
|
||||
* Helper function to convert an Azure OpenAI stream to a standard ReadableStream
|
||||
* and collect completion metrics
|
||||
*/
|
||||
function createReadableStreamFromAzureOpenAIStream(
|
||||
azureOpenAIStream: any,
|
||||
onComplete?: (content: string, usage?: any) => void
|
||||
): ReadableStream {
|
||||
let fullContent = ''
|
||||
let usageData: any = null
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of azureOpenAIStream) {
|
||||
// Check for usage data in the final chunk
|
||||
if (chunk.usage) {
|
||||
usageData = chunk.usage
|
||||
}
|
||||
|
||||
const content = chunk.choices[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
|
||||
// Once stream is complete, call the completion callback with the final content and usage
|
||||
if (onComplete) {
|
||||
onComplete(fullContent, usageData)
|
||||
}
|
||||
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Azure OpenAI provider configuration
|
||||
*/
|
||||
@@ -303,26 +264,6 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
const forcedTools = preparedTools?.forcedTools || []
|
||||
let usedForcedTools: string[] = []
|
||||
|
||||
// Helper function to check for forced tool usage in responses
|
||||
const checkForForcedToolUsage = (
|
||||
response: any,
|
||||
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }
|
||||
) => {
|
||||
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
|
||||
const toolCallsResponse = response.choices[0].message.tool_calls
|
||||
const result = trackForcedToolUsage(
|
||||
toolCallsResponse,
|
||||
toolChoice,
|
||||
logger,
|
||||
'azure-openai',
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = result.hasUsedForcedTool
|
||||
usedForcedTools = result.usedForcedTools
|
||||
}
|
||||
}
|
||||
|
||||
let currentResponse = await azureOpenAI.chat.completions.create(payload)
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
@@ -337,7 +278,6 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10 // Prevent infinite loops
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
@@ -358,9 +298,17 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
]
|
||||
|
||||
// Check if a forced tool was used in the first response
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
const firstCheckResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
originalToolChoice,
|
||||
logger,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = firstCheckResult.hasUsedForcedTool
|
||||
usedForcedTools = firstCheckResult.usedForcedTools
|
||||
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
@@ -368,7 +316,7 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
// Track time for tool calls in this batch
|
||||
@@ -491,7 +439,15 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
currentResponse = await azureOpenAI.chat.completions.create(nextPayload)
|
||||
|
||||
// Check if any forced tools were used in this response
|
||||
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
|
||||
const nextCheckResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
nextPayload.tool_choice,
|
||||
logger,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = nextCheckResult.hasUsedForcedTool
|
||||
usedForcedTools = nextCheckResult.usedForcedTools
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
apps/sim/providers/azure-openai/utils.ts (new file, 70 lines)
@@ -0,0 +1,70 @@
|
||||
import type { Logger } from '@/lib/logs/console/logger'
|
||||
import { trackForcedToolUsage } from '@/providers/utils'
|
||||
|
||||
/**
|
||||
* Helper function to convert an Azure OpenAI stream to a standard ReadableStream
|
||||
* and collect completion metrics
|
||||
*/
|
||||
export function createReadableStreamFromAzureOpenAIStream(
|
||||
azureOpenAIStream: any,
|
||||
onComplete?: (content: string, usage?: any) => void
|
||||
): ReadableStream {
|
||||
let fullContent = ''
|
||||
let usageData: any = null
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of azureOpenAIStream) {
|
||||
if (chunk.usage) {
|
||||
usageData = chunk.usage
|
||||
}
|
||||
|
||||
const content = chunk.choices[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
|
||||
if (onComplete) {
|
||||
onComplete(fullContent, usageData)
|
||||
}
|
||||
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to check for forced tool usage in responses
|
||||
*/
|
||||
export function checkForForcedToolUsage(
|
||||
response: any,
|
||||
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
|
||||
logger: Logger,
|
||||
forcedTools: string[],
|
||||
usedForcedTools: string[]
|
||||
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
|
||||
let hasUsedForcedTool = false
|
||||
let updatedUsedForcedTools = [...usedForcedTools]
|
||||
|
||||
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
|
||||
const toolCallsResponse = response.choices[0].message.tool_calls
|
||||
const result = trackForcedToolUsage(
|
||||
toolCallsResponse,
|
||||
toolChoice,
|
||||
logger,
|
||||
'azure-openai',
|
||||
forcedTools,
|
||||
updatedUsedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = result.hasUsedForcedTool
|
||||
updatedUsedForcedTools = result.usedForcedTools
|
||||
}
|
||||
|
||||
return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
import { Cerebras } from '@cerebras/cerebras_cloud_sdk'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import type { CerebrasResponse } from '@/providers/cerebras/types'
|
||||
import { createReadableStreamFromCerebrasStream } from '@/providers/cerebras/utils'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
@@ -14,35 +17,9 @@ import {
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
import type { CerebrasResponse } from './types'
|
||||
|
||||
const logger = createLogger('CerebrasProvider')
|
||||
|
||||
/**
|
||||
* Helper to convert a Cerebras streaming response (async iterable) into a ReadableStream.
|
||||
* Enqueues only the model's text delta chunks as UTF-8 encoded bytes.
|
||||
*/
|
||||
function createReadableStreamFromCerebrasStream(
|
||||
cerebrasStream: AsyncIterable<any>
|
||||
): ReadableStream {
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of cerebrasStream) {
|
||||
// Expecting delta content similar to OpenAI: chunk.choices[0]?.delta?.content
|
||||
const content = chunk.choices?.[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
export const cerebrasProvider: ProviderConfig = {
|
||||
id: 'cerebras',
|
||||
name: 'Cerebras',
|
||||
@@ -223,7 +200,6 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10 // Prevent infinite loops
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
@@ -246,7 +222,7 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
const toolCallSignatures = new Set()
|
||||
|
||||
try {
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
|
||||
|
||||
apps/sim/providers/cerebras/utils.ts (new file, 23 lines)
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* Helper to convert a Cerebras streaming response (async iterable) into a ReadableStream.
|
||||
* Enqueues only the model's text delta chunks as UTF-8 encoded bytes.
|
||||
*/
|
||||
export function createReadableStreamFromCerebrasStream(
|
||||
cerebrasStream: AsyncIterable<any>
|
||||
): ReadableStream {
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of cerebrasStream) {
|
||||
const content = chunk.choices?.[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
import OpenAI from 'openai'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import { createReadableStreamFromDeepseekStream } from '@/providers/deepseek/utils'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
@@ -17,28 +19,6 @@ import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('DeepseekProvider')
|
||||
|
||||
/**
|
||||
* Helper function to convert a DeepSeek (OpenAI-compatible) stream to a ReadableStream
|
||||
* of text chunks that can be consumed by the browser.
|
||||
*/
|
||||
function createReadableStreamFromDeepseekStream(deepseekStream: any): ReadableStream {
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of deepseekStream) {
|
||||
const content = chunk.choices[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
export const deepseekProvider: ProviderConfig = {
|
||||
id: 'deepseek',
|
||||
name: 'Deepseek',
|
||||
@@ -231,7 +211,6 @@ export const deepseekProvider: ProviderConfig = {
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10 // Prevent infinite loops
|
||||
|
||||
// Track if a forced tool has been used
|
||||
let hasUsedForcedTool = false
|
||||
@@ -270,7 +249,7 @@ export const deepseekProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
try {
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
|
||||
apps/sim/providers/deepseek/utils.ts (new file, 21 lines)
@@ -0,0 +1,21 @@
|
||||
/**
|
||||
* Helper function to convert a DeepSeek (OpenAI-compatible) stream to a ReadableStream
|
||||
* of text chunks that can be consumed by the browser.
|
||||
*/
|
||||
export function createReadableStreamFromDeepseekStream(deepseekStream: any): ReadableStream {
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of deepseekStream) {
|
||||
const content = chunk.choices[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -1,5 +1,12 @@
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import {
|
||||
cleanSchemaForGemini,
|
||||
convertToGeminiFormat,
|
||||
extractFunctionCall,
|
||||
extractTextContent,
|
||||
} from '@/providers/google/utils'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
@@ -19,7 +26,13 @@ const logger = createLogger('GoogleProvider')
|
||||
/**
|
||||
* Creates a ReadableStream from Google's Gemini stream response
|
||||
*/
|
||||
function createReadableStreamFromGeminiStream(response: Response): ReadableStream<Uint8Array> {
|
||||
function createReadableStreamFromGeminiStream(
|
||||
response: Response,
|
||||
onComplete?: (
|
||||
content: string,
|
||||
usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
|
||||
) => void
|
||||
): ReadableStream<Uint8Array> {
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) {
|
||||
throw new Error('Failed to get reader from response body')
|
||||
@@ -29,18 +42,24 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
|
||||
async start(controller) {
|
||||
try {
|
||||
let buffer = ''
|
||||
let fullContent = ''
|
||||
let usageData: {
|
||||
promptTokenCount?: number
|
||||
candidatesTokenCount?: number
|
||||
totalTokenCount?: number
|
||||
} | null = null
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
// Try to parse any remaining buffer as complete JSON
|
||||
if (buffer.trim()) {
|
||||
// Processing final buffer
|
||||
try {
|
||||
const data = JSON.parse(buffer.trim())
|
||||
if (data.usageMetadata) {
|
||||
usageData = data.usageMetadata
|
||||
}
|
||||
const candidate = data.candidates?.[0]
|
||||
if (candidate?.content?.parts) {
|
||||
// Check if this is a function call
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
if (functionCall) {
|
||||
logger.debug(
|
||||
@@ -49,26 +68,27 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
|
||||
functionName: functionCall.name,
|
||||
}
|
||||
)
|
||||
// Function calls should not be streamed - end the stream early
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
const content = extractTextContent(candidate)
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
// Final buffer not valid JSON, checking if it contains JSON array
|
||||
// Try parsing as JSON array if it starts with [
|
||||
if (buffer.trim().startsWith('[')) {
|
||||
try {
|
||||
const dataArray = JSON.parse(buffer.trim())
|
||||
if (Array.isArray(dataArray)) {
|
||||
for (const item of dataArray) {
|
||||
if (item.usageMetadata) {
|
||||
usageData = item.usageMetadata
|
||||
}
|
||||
const candidate = item.candidates?.[0]
|
||||
if (candidate?.content?.parts) {
|
||||
// Check if this is a function call
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
if (functionCall) {
|
||||
logger.debug(
|
||||
@@ -77,11 +97,13 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
|
||||
functionName: functionCall.name,
|
||||
}
|
||||
)
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
const content = extractTextContent(candidate)
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
@@ -93,6 +115,7 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
|
||||
}
|
||||
}
|
||||
}
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
controller.close()
|
||||
break
|
||||
}
|
||||
@@ -100,14 +123,11 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
|
||||
const text = new TextDecoder().decode(value)
|
||||
buffer += text
|
||||
|
||||
// Try to find complete JSON objects in buffer
|
||||
// Look for patterns like: {...}\n{...} or just a single {...}
|
||||
let searchIndex = 0
|
||||
while (searchIndex < buffer.length) {
|
||||
const openBrace = buffer.indexOf('{', searchIndex)
|
||||
if (openBrace === -1) break
|
||||
|
||||
// Try to find the matching closing brace
|
||||
let braceCount = 0
|
||||
let inString = false
|
||||
let escaped = false
|
||||
@@ -138,28 +158,34 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
|
||||
}
|
||||
|
||||
if (closeBrace !== -1) {
|
||||
// Found a complete JSON object
|
||||
const jsonStr = buffer.substring(openBrace, closeBrace + 1)
|
||||
|
||||
try {
|
||||
const data = JSON.parse(jsonStr)
|
||||
// JSON parsed successfully from stream
|
||||
|
||||
if (data.usageMetadata) {
|
||||
usageData = data.usageMetadata
|
||||
}
|
||||
|
||||
const candidate = data.candidates?.[0]
|
||||
|
||||
// Handle specific finish reasons
|
||||
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
|
||||
logger.warn('Gemini returned UNEXPECTED_TOOL_CALL in streaming mode', {
|
||||
finishReason: candidate.finishReason,
|
||||
hasContent: !!candidate?.content,
|
||||
hasParts: !!candidate?.content?.parts,
|
||||
})
|
||||
// This indicates a configuration issue - tools might be improperly configured for streaming
|
||||
continue
|
||||
const textContent = extractTextContent(candidate)
|
||||
if (textContent) {
|
||||
fullContent += textContent
|
||||
controller.enqueue(new TextEncoder().encode(textContent))
|
||||
}
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
|
||||
if (candidate?.content?.parts) {
|
||||
// Check if this is a function call
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
if (functionCall) {
|
||||
logger.debug(
|
||||
@@ -168,13 +194,13 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
|
||||
functionName: functionCall.name,
|
||||
}
|
||||
)
|
||||
// Function calls should not be streamed - we need to end the stream
|
||||
// and let the non-streaming tool execution flow handle this
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
const content = extractTextContent(candidate)
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
@@ -185,7 +211,6 @@ function createReadableStreamFromGeminiStream(response: Response): ReadableStrea
|
||||
})
|
||||
}
|
||||
|
||||
// Remove processed JSON from buffer and continue searching
|
||||
buffer = buffer.substring(closeBrace + 1)
|
||||
searchIndex = 0
|
||||
} else {
|
||||
@@ -232,45 +257,36 @@ export const googleProvider: ProviderConfig = {
|
||||
streaming: !!request.stream,
|
||||
})
|
||||
|
||||
// Start execution timer for the entire provider execution
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
// Convert messages to Gemini format
|
||||
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
|
||||
|
||||
const requestedModel = request.model || 'gemini-2.5-pro'
|
||||
|
||||
// Build request payload
|
||||
const payload: any = {
|
||||
contents,
|
||||
generationConfig: {},
|
||||
}
|
||||
|
||||
// Add temperature if specified
|
||||
if (request.temperature !== undefined && request.temperature !== null) {
|
||||
payload.generationConfig.temperature = request.temperature
|
||||
}
|
||||
|
||||
// Add max tokens if specified
|
||||
if (request.maxTokens !== undefined) {
|
||||
payload.generationConfig.maxOutputTokens = request.maxTokens
|
||||
}
|
||||
|
||||
// Add system instruction if provided
|
||||
if (systemInstruction) {
|
||||
payload.systemInstruction = systemInstruction
|
||||
}
|
||||
|
||||
// Add structured output format if requested (but not when tools are present)
|
||||
if (request.responseFormat && !tools?.length) {
|
||||
const responseFormatSchema = request.responseFormat.schema || request.responseFormat
|
||||
|
||||
// Clean the schema using our helper function
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
// Use Gemini's native structured output approach
|
||||
payload.generationConfig.responseMimeType = 'application/json'
|
||||
payload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
@@ -284,7 +300,6 @@ export const googleProvider: ProviderConfig = {
|
||||
)
|
||||
}
|
||||
|
||||
// Handle tools and tool usage control
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
|
||||
if (tools?.length) {
|
||||
@@ -298,7 +313,6 @@ export const googleProvider: ProviderConfig = {
|
||||
},
|
||||
]
|
||||
|
||||
// Add Google-specific tool configuration
|
||||
if (toolConfig) {
|
||||
payload.toolConfig = toolConfig
|
||||
}
|
||||
@@ -313,14 +327,10 @@ export const googleProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
// Make the API request
|
||||
const initialCallTime = Date.now()
|
||||
|
||||
// Disable streaming for initial requests when tools are present to avoid function calls in streams
|
||||
// Only enable streaming for the final response after tool execution
|
||||
const shouldStream = request.stream && !tools?.length
|
||||
|
||||
// Use streamGenerateContent for streaming requests
|
||||
const endpoint = shouldStream
|
||||
? `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:streamGenerateContent?key=${request.apiKey}`
|
||||
: `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:generateContent?key=${request.apiKey}`
|
||||
@@ -352,16 +362,11 @@ export const googleProvider: ProviderConfig = {
|
||||
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
// Handle streaming response
|
||||
if (shouldStream) {
|
||||
logger.info('Handling Google Gemini streaming response')
|
||||
|
||||
// Create a ReadableStream from the Google Gemini stream
|
||||
const stream = createReadableStreamFromGeminiStream(response)
|
||||
|
||||
// Create an object that combines the stream with execution metadata
|
||||
const streamingExecution: StreamingExecution = {
|
||||
stream,
|
||||
const streamingResult: StreamingExecution = {
|
||||
stream: null as any,
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
@@ -389,7 +394,6 @@ export const googleProvider: ProviderConfig = {
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
],
|
||||
// Cost will be calculated in logger
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
@@ -402,18 +406,49 @@ export const googleProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
|
||||
return streamingExecution
|
||||
streamingResult.stream = createReadableStreamFromGeminiStream(
|
||||
response,
|
||||
(content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
if (streamingResult.execution.output.providerTiming) {
|
||||
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
|
||||
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
|
||||
streamEndTime
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.promptTokenCount || 0,
|
||||
completion: usage.candidatesTokenCount || 0,
|
||||
total:
|
||||
usage.totalTokenCount ||
|
||||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0),
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return streamingResult
|
||||
}
|
||||
|
||||
let geminiResponse = await response.json()
|
||||
|
||||
// Check structured output format
|
||||
if (payload.generationConfig?.responseSchema) {
|
||||
const candidate = geminiResponse.candidates?.[0]
|
||||
if (candidate?.content?.parts?.[0]?.text) {
|
||||
const text = candidate.content.parts[0].text
|
||||
try {
|
||||
// Validate JSON structure
|
||||
JSON.parse(text)
|
||||
logger.info('Successfully received structured JSON output')
|
||||
} catch (_e) {
|
||||
@@ -422,7 +457,6 @@ export const googleProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize response tracking variables
|
||||
let content = ''
|
||||
let tokens = {
|
||||
prompt: 0,
|
||||
@@ -432,16 +466,13 @@ export const googleProvider: ProviderConfig = {
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10 // Prevent infinite loops
|
||||
|
||||
// Track forced tools and their usage (similar to OpenAI pattern)
|
||||
const originalToolConfig = preparedTools?.toolConfig
|
||||
const forcedTools = preparedTools?.forcedTools || []
|
||||
let usedForcedTools: string[] = []
|
||||
let hasUsedForcedTool = false
|
||||
let currentToolConfig = originalToolConfig
|
||||
|
||||
// Helper function to check for forced tool usage in responses
|
||||
const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => {
|
||||
if (currentToolConfig && forcedTools.length > 0) {
|
||||
const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }]
|
||||
@@ -466,11 +497,9 @@ export const googleProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
|
||||
// Track each model and tool call segment with timestamps
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
@@ -482,46 +511,50 @@ export const googleProvider: ProviderConfig = {
|
||||
]
|
||||
|
||||
try {
|
||||
// Extract content or function calls from initial response
|
||||
const candidate = geminiResponse.candidates?.[0]
|
||||
|
||||
// Check if response contains function calls
|
||||
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
|
||||
logger.warn(
|
||||
'Gemini returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
|
||||
{
|
||||
finishReason: candidate.finishReason,
|
||||
hasContent: !!candidate?.content,
|
||||
hasParts: !!candidate?.content?.parts,
|
||||
}
|
||||
)
|
||||
content = extractTextContent(candidate)
|
||||
}
|
||||
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
|
||||
if (functionCall) {
|
||||
logger.info(`Received function call from Gemini: ${functionCall.name}`)
|
||||
|
||||
// Process function calls in a loop
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
// Get the latest function calls
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
const latestResponse = geminiResponse.candidates?.[0]
|
||||
const latestFunctionCall = extractFunctionCall(latestResponse)
|
||||
|
||||
if (!latestFunctionCall) {
|
||||
// No more function calls - extract final text content
|
||||
content = extractTextContent(latestResponse)
|
||||
break
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
|
||||
`Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
// Track time for tool calls
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
try {
|
||||
const toolName = latestFunctionCall.name
|
||||
const toolArgs = latestFunctionCall.args || {}
|
||||
|
||||
// Get the tool from the tools registry
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) {
|
||||
logger.warn(`Tool ${toolName} not found in registry, skipping`)
|
||||
break
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
@@ -529,7 +562,6 @@ export const googleProvider: ProviderConfig = {
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
// Add to time segments for both success and failure
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
@@ -538,13 +570,11 @@ export const googleProvider: ProviderConfig = {
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
// Prepare result content for the LLM
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
// Include error information so LLM can respond appropriately
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
@@ -562,14 +592,10 @@ export const googleProvider: ProviderConfig = {
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
// Prepare for next request with simplified messages
|
||||
// Use simple format: original query + most recent function call + result
|
||||
const simplifiedMessages = [
|
||||
// Original user request - find the first user request
|
||||
...(contents.filter((m) => m.role === 'user').length > 0
|
||||
? [contents.filter((m) => m.role === 'user')[0]]
|
||||
: [contents[0]]),
|
||||
// Function call from model
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
@@ -581,7 +607,6 @@ export const googleProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
// Function response - but use USER role since Gemini only accepts user or model
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
@@ -592,35 +617,27 @@ export const googleProvider: ProviderConfig = {
|
||||
},
|
||||
]
|
||||
|
||||
// Calculate tool call time
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
// Check for forced tool usage and update configuration
|
||||
checkForForcedToolUsage(latestFunctionCall)
|
||||
|
||||
// Make the next request with updated messages
|
||||
const nextModelStartTime = Date.now()
|
||||
|
||||
try {
|
||||
// Check if we should stream the final response after tool calls
|
||||
if (request.stream) {
|
||||
// Create a payload for the streaming response after tool calls
|
||||
const streamingPayload = {
|
||||
...payload,
|
||||
contents: simplifiedMessages,
|
||||
}
|
||||
|
||||
// Check if we should remove tools and enable structured output for final response
|
||||
const allForcedToolsUsed =
|
||||
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
|
||||
|
||||
if (allForcedToolsUsed && request.responseFormat) {
|
||||
// All forced tools have been used, we can now remove tools and enable structured output
|
||||
streamingPayload.tools = undefined
|
||||
streamingPayload.toolConfig = undefined
|
||||
|
||||
// Add structured output format for final response
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
@@ -633,7 +650,6 @@ export const googleProvider: ProviderConfig = {
|
||||
|
||||
logger.info('Using structured output for final response after tool execution')
|
||||
} else {
|
||||
// Use updated tool configuration if available, otherwise default to AUTO
|
||||
if (currentToolConfig) {
|
||||
streamingPayload.toolConfig = currentToolConfig
|
||||
} else {
|
||||
@@ -641,11 +657,8 @@ export const googleProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we should handle this as a potential forced tool call
|
||||
// First make a non-streaming request to see if we get a function call
|
||||
const checkPayload = {
|
||||
...streamingPayload,
|
||||
// Remove stream property to get non-streaming response
|
||||
}
|
||||
checkPayload.stream = undefined
|
||||
|
||||
@@ -677,7 +690,6 @@ export const googleProvider: ProviderConfig = {
|
||||
const checkFunctionCall = extractFunctionCall(checkCandidate)
|
||||
|
||||
if (checkFunctionCall) {
|
||||
// We have a function call - handle it in non-streaming mode
|
||||
logger.info(
|
||||
'Function call detected in follow-up, handling in non-streaming mode',
|
||||
{
|
||||
@@ -685,10 +697,8 @@ export const googleProvider: ProviderConfig = {
|
||||
}
|
||||
)
|
||||
|
||||
// Update geminiResponse to continue the tool execution loop
|
||||
geminiResponse = checkResult
|
||||
|
||||
// Update token counts if available
|
||||
if (checkResult.usageMetadata) {
|
||||
tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0
|
||||
tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0
|
||||
@@ -697,12 +707,10 @@ export const googleProvider: ProviderConfig = {
|
||||
(checkResult.usageMetadata.candidatesTokenCount || 0)
|
||||
}
|
||||
|
||||
// Calculate timing for this model call
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
modelTime += thisModelTime
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
@@ -711,14 +719,32 @@ export const googleProvider: ProviderConfig = {
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
// Continue the loop to handle the function call
|
||||
iterationCount++
|
||||
continue
|
||||
}
|
||||
// No function call - proceed with streaming
|
||||
logger.info('No function call detected, proceeding with streaming response')
|
||||
|
||||
// Make the streaming request with streamGenerateContent endpoint
|
||||
// Apply structured output for the final response if responseFormat is specified
|
||||
// This works regardless of whether tools were forced or auto
|
||||
if (request.responseFormat) {
|
||||
streamingPayload.tools = undefined
|
||||
streamingPayload.toolConfig = undefined
|
||||
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
if (!streamingPayload.generationConfig) {
|
||||
streamingPayload.generationConfig = {}
|
||||
}
|
||||
streamingPayload.generationConfig.responseMimeType = 'application/json'
|
||||
streamingPayload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info(
|
||||
'Using structured output for final streaming response after tool execution'
|
||||
)
|
||||
}
|
||||
|
||||
const streamingResponse = await fetch(
|
||||
`https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:streamGenerateContent?key=${request.apiKey}`,
|
||||
{
|
||||
@@ -742,15 +768,10 @@ export const googleProvider: ProviderConfig = {
|
||||
)
|
||||
}
|
||||
|
||||
// Create a stream from the response
|
||||
const stream = createReadableStreamFromGeminiStream(streamingResponse)
|
||||
|
||||
// Calculate timing information
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
modelTime += thisModelTime
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: 'Final streaming response after tool calls',
|
||||
@@ -759,9 +780,8 @@ export const googleProvider: ProviderConfig = {
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
// Return a streaming execution with tool call information
|
||||
const streamingExecution: StreamingExecution = {
|
||||
stream,
|
||||
stream: null as any,
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
@@ -786,7 +806,6 @@ export const googleProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments,
|
||||
},
|
||||
// Cost will be calculated in logger
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
@@ -798,25 +817,55 @@ export const googleProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
|
||||
streamingExecution.stream = createReadableStreamFromGeminiStream(
|
||||
streamingResponse,
|
||||
(content, usage) => {
|
||||
streamingExecution.execution.output.content = content
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
if (streamingExecution.execution.output.providerTiming) {
|
||||
streamingExecution.execution.output.providerTiming.endTime =
|
||||
streamEndTimeISO
|
||||
streamingExecution.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
const existingTokens = streamingExecution.execution.output.tokens || {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
streamingExecution.execution.output.tokens = {
|
||||
prompt: (existingTokens.prompt || 0) + (usage.promptTokenCount || 0),
|
||||
completion:
|
||||
(existingTokens.completion || 0) + (usage.candidatesTokenCount || 0),
|
||||
total:
|
||||
(existingTokens.total || 0) +
|
||||
(usage.totalTokenCount ||
|
||||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return streamingExecution
|
||||
}
|
||||
|
||||
// Make the next request for non-streaming response
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
contents: simplifiedMessages,
|
||||
}
|
||||
|
||||
// Check if we should remove tools and enable structured output for final response
|
||||
const allForcedToolsUsed =
|
||||
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
|
||||
|
||||
if (allForcedToolsUsed && request.responseFormat) {
|
||||
// All forced tools have been used, we can now remove tools and enable structured output
|
||||
nextPayload.tools = undefined
|
||||
nextPayload.toolConfig = undefined
|
||||
|
||||
// Add structured output format for final response
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
@@ -831,7 +880,6 @@ export const googleProvider: ProviderConfig = {
|
||||
'Using structured output for final non-streaming response after tool execution'
|
||||
)
|
||||
} else {
|
||||
// Add updated tool configuration if available
|
||||
if (currentToolConfig) {
|
||||
nextPayload.toolConfig = currentToolConfig
|
||||
}
|
||||
@@ -864,7 +912,6 @@ export const googleProvider: ProviderConfig = {
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
@@ -873,15 +920,65 @@ export const googleProvider: ProviderConfig = {
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
// Add to model time
|
||||
modelTime += thisModelTime
|
||||
|
||||
// Check if we need to continue or break
|
||||
const nextCandidate = geminiResponse.candidates?.[0]
|
||||
const nextFunctionCall = extractFunctionCall(nextCandidate)
|
||||
|
||||
if (!nextFunctionCall) {
|
||||
content = extractTextContent(nextCandidate)
|
||||
// If responseFormat is specified, make one final request with structured output
|
||||
if (request.responseFormat) {
|
||||
const finalPayload = {
|
||||
...payload,
|
||||
contents: nextPayload.contents,
|
||||
tools: undefined,
|
||||
toolConfig: undefined,
|
||||
}
|
||||
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
if (!finalPayload.generationConfig) {
|
||||
finalPayload.generationConfig = {}
|
||||
}
|
||||
finalPayload.generationConfig.responseMimeType = 'application/json'
|
||||
finalPayload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info('Making final request with structured output after tool execution')
|
||||
|
||||
const finalResponse = await fetch(
|
||||
`https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:generateContent?key=${request.apiKey}`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(finalPayload),
|
||||
}
|
||||
)
|
||||
|
||||
if (finalResponse.ok) {
|
||||
const finalResult = await finalResponse.json()
|
||||
const finalCandidate = finalResult.candidates?.[0]
|
||||
content = extractTextContent(finalCandidate)
|
||||
|
||||
if (finalResult.usageMetadata) {
|
||||
tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0
|
||||
tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0
|
||||
tokens.total +=
|
||||
(finalResult.usageMetadata.promptTokenCount || 0) +
|
||||
(finalResult.usageMetadata.candidatesTokenCount || 0)
|
||||
}
|
||||
} else {
|
||||
logger.warn(
|
||||
'Failed to get structured output, falling back to regular response'
|
||||
)
|
||||
content = extractTextContent(nextCandidate)
|
||||
}
|
||||
} else {
|
||||
content = extractTextContent(nextCandidate)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@@ -902,7 +999,6 @@ export const googleProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Regular text response
|
||||
content = extractTextContent(candidate)
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -911,18 +1007,15 @@ export const googleProvider: ProviderConfig = {
|
||||
iterationCount,
|
||||
})
|
||||
|
||||
// Don't rethrow, so we can still return partial results
|
||||
if (!content && toolCalls.length > 0) {
|
||||
content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.`
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate overall timing
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
|
||||
// Extract token usage if available
|
||||
if (geminiResponse.usageMetadata) {
|
||||
tokens = {
|
||||
prompt: geminiResponse.usageMetadata.promptTokenCount || 0,
|
||||
@@ -949,10 +1042,8 @@ export const googleProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
// Cost will be calculated in logger
|
||||
}
|
||||
} catch (error) {
|
||||
// Include timing information even for errors
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -962,7 +1053,6 @@ export const googleProvider: ProviderConfig = {
|
||||
duration: totalDuration,
|
||||
})
|
||||
|
||||
// Create a new error with timing information
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
enhancedError.timing = {
|
||||
@@ -975,200 +1065,3 @@ export const googleProvider: ProviderConfig = {
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to remove additionalProperties from a schema object
|
||||
* and perform a deep copy of the schema to avoid modifying the original
|
||||
*/
|
||||
function cleanSchemaForGemini(schema: any): any {
|
||||
// Handle base cases
|
||||
if (schema === null || schema === undefined) return schema
|
||||
if (typeof schema !== 'object') return schema
|
||||
if (Array.isArray(schema)) {
|
||||
return schema.map((item) => cleanSchemaForGemini(item))
|
||||
}
|
||||
|
||||
// Create a new object for the deep copy
|
||||
const cleanedSchema: any = {}
|
||||
|
||||
// Process each property in the schema
|
||||
for (const key in schema) {
|
||||
// Skip additionalProperties
|
||||
if (key === 'additionalProperties') continue
|
||||
|
||||
// Deep copy nested objects
|
||||
cleanedSchema[key] = cleanSchemaForGemini(schema[key])
|
||||
}
|
||||
|
||||
return cleanedSchema
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to extract content from a Gemini response, handling structured output
|
||||
*/
|
||||
function extractTextContent(candidate: any): string {
|
||||
if (!candidate?.content?.parts) return ''
|
||||
|
||||
// Check for JSON response (typically from structured output)
|
||||
if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) {
|
||||
const text = candidate.content.parts[0].text
|
||||
if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) {
|
||||
try {
|
||||
JSON.parse(text) // Validate JSON
|
||||
return text // Return valid JSON as-is
|
||||
} catch (_e) {
|
||||
/* Not valid JSON, continue with normal extraction */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Standard text extraction
|
||||
return candidate.content.parts
|
||||
.filter((part: any) => part.text)
|
||||
.map((part: any) => part.text)
|
||||
.join('\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to extract a function call from a Gemini response
|
||||
*/
|
||||
function extractFunctionCall(candidate: any): { name: string; args: any } | null {
|
||||
if (!candidate?.content?.parts) return null
|
||||
|
||||
// Check for functionCall in parts
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.functionCall) {
|
||||
const args = part.functionCall.args || {}
|
||||
// Parse string args if they look like JSON
|
||||
if (
|
||||
typeof part.functionCall.args === 'string' &&
|
||||
part.functionCall.args.trim().startsWith('{')
|
||||
) {
|
||||
try {
|
||||
return { name: part.functionCall.name, args: JSON.parse(part.functionCall.args) }
|
||||
} catch (_e) {
|
||||
return { name: part.functionCall.name, args: part.functionCall.args }
|
||||
}
|
||||
}
|
||||
return { name: part.functionCall.name, args }
|
||||
}
|
||||
}
|
||||
|
||||
// Check for alternative function_call format
|
||||
if (candidate.content.function_call) {
|
||||
const args =
|
||||
typeof candidate.content.function_call.arguments === 'string'
|
||||
? JSON.parse(candidate.content.function_call.arguments || '{}')
|
||||
: candidate.content.function_call.arguments || {}
|
||||
return { name: candidate.content.function_call.name, args }
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert OpenAI-style request format to Gemini format
|
||||
*/
|
||||
function convertToGeminiFormat(request: ProviderRequest): {
|
||||
contents: any[]
|
||||
tools: any[] | undefined
|
||||
systemInstruction: any | undefined
|
||||
} {
|
||||
const contents = []
|
||||
let systemInstruction
|
||||
|
||||
// Handle system prompt
|
||||
if (request.systemPrompt) {
|
||||
systemInstruction = { parts: [{ text: request.systemPrompt }] }
|
||||
}
|
||||
|
||||
// Add context as user message if present
|
||||
if (request.context) {
|
||||
contents.push({ role: 'user', parts: [{ text: request.context }] })
|
||||
}
|
||||
|
||||
// Process messages
|
||||
if (request.messages && request.messages.length > 0) {
|
||||
for (const message of request.messages) {
|
||||
if (message.role === 'system') {
|
||||
// Add to system instruction
|
||||
if (!systemInstruction) {
|
||||
systemInstruction = { parts: [{ text: message.content }] }
|
||||
} else {
|
||||
// Append to existing system instruction
|
||||
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}`
|
||||
}
|
||||
} else if (message.role === 'user' || message.role === 'assistant') {
|
||||
// Convert to Gemini role format
|
||||
const geminiRole = message.role === 'user' ? 'user' : 'model'
|
||||
|
||||
// Add text content
|
||||
if (message.content) {
|
||||
contents.push({ role: geminiRole, parts: [{ text: message.content }] })
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) {
|
||||
const functionCalls = message.tool_calls.map((toolCall) => ({
|
||||
functionCall: {
|
||||
name: toolCall.function?.name,
|
||||
args: JSON.parse(toolCall.function?.arguments || '{}'),
|
||||
},
|
||||
}))
|
||||
|
||||
contents.push({ role: 'model', parts: functionCalls })
|
||||
}
|
||||
} else if (message.role === 'tool') {
|
||||
// Convert tool response (Gemini only accepts user/model roles)
|
||||
contents.push({
|
||||
role: 'user',
|
||||
parts: [{ text: `Function result: ${message.content}` }],
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tools to Gemini function declarations
|
||||
const tools = request.tools?.map((tool) => {
|
||||
const toolParameters = { ...(tool.parameters || {}) }
|
||||
|
||||
// Process schema properties
|
||||
if (toolParameters.properties) {
|
||||
const properties = { ...toolParameters.properties }
|
||||
const required = toolParameters.required ? [...toolParameters.required] : []
|
||||
|
||||
// Remove defaults and optional parameters
|
||||
for (const key in properties) {
|
||||
const prop = properties[key] as any
|
||||
|
||||
if (prop.default !== undefined) {
|
||||
const { default: _, ...cleanProp } = prop
|
||||
properties[key] = cleanProp
|
||||
}
|
||||
}
|
||||
|
||||
// Build Gemini-compatible parameters schema
|
||||
const parameters = {
|
||||
type: toolParameters.type || 'object',
|
||||
properties,
|
||||
...(required.length > 0 ? { required } : {}),
|
||||
}
|
||||
|
||||
// Clean schema for Gemini
|
||||
return {
|
||||
name: tool.id,
|
||||
description: tool.description || `Execute the ${tool.id} function`,
|
||||
parameters: cleanSchemaForGemini(parameters),
|
||||
}
|
||||
}
|
||||
|
||||
// Simple schema case
|
||||
return {
|
||||
name: tool.id,
|
||||
description: tool.description || `Execute the ${tool.id} function`,
|
||||
parameters: cleanSchemaForGemini(toolParameters),
|
||||
}
|
||||
})
|
||||
|
||||
return { contents, tools, systemInstruction }
|
||||
}
|
||||
|
||||
171
apps/sim/providers/google/utils.ts
Normal file
@@ -0,0 +1,171 @@
import type { ProviderRequest } from '@/providers/types'

/**
 * Removes additionalProperties from a schema object (not supported by Gemini)
 */
export function cleanSchemaForGemini(schema: any): any {
  if (schema === null || schema === undefined) return schema
  if (typeof schema !== 'object') return schema
  if (Array.isArray(schema)) {
    return schema.map((item) => cleanSchemaForGemini(item))
  }

  const cleanedSchema: any = {}

  for (const key in schema) {
    if (key === 'additionalProperties') continue
    cleanedSchema[key] = cleanSchemaForGemini(schema[key])
  }

  return cleanedSchema
}
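A small usage sketch (the schema below is hypothetical): every `additionalProperties` key is dropped at any depth, and the helper deep-copies as it walks, so the caller's schema object is not mutated.

const responseSchema = {
  type: 'object',
  additionalProperties: false,
  properties: {
    city: { type: 'string' },
    forecast: {
      type: 'object',
      additionalProperties: false,
      properties: { tempC: { type: 'number' } },
    },
  },
}

const geminiSchema = cleanSchemaForGemini(responseSchema)
// geminiSchema has no `additionalProperties` at the top level or under `forecast`,
// while `responseSchema` itself is left unchanged.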

/**
 * Extracts text content from a Gemini response candidate, handling structured output
 */
export function extractTextContent(candidate: any): string {
  if (!candidate?.content?.parts) return ''

  if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) {
    const text = candidate.content.parts[0].text
    if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) {
      try {
        JSON.parse(text)
        return text
      } catch (_e) {
        /* Not valid JSON, continue with normal extraction */
      }
    }
  }

  return candidate.content.parts
    .filter((part: any) => part.text)
    .map((part: any) => part.text)
    .join('\n')
}

/**
 * Extracts a function call from a Gemini response candidate
 */
export function extractFunctionCall(candidate: any): { name: string; args: any } | null {
  if (!candidate?.content?.parts) return null

  for (const part of candidate.content.parts) {
    if (part.functionCall) {
      const args = part.functionCall.args || {}
      if (
        typeof part.functionCall.args === 'string' &&
        part.functionCall.args.trim().startsWith('{')
      ) {
        try {
          return { name: part.functionCall.name, args: JSON.parse(part.functionCall.args) }
        } catch (_e) {
          return { name: part.functionCall.name, args: part.functionCall.args }
        }
      }
      return { name: part.functionCall.name, args }
    }
  }

  if (candidate.content.function_call) {
    const args =
      typeof candidate.content.function_call.arguments === 'string'
        ? JSON.parse(candidate.content.function_call.arguments || '{}')
        : candidate.content.function_call.arguments || {}
    return { name: candidate.content.function_call.name, args }
  }

  return null
}
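A quick sketch of how the two extractors behave on hypothetical Gemini candidates: `extractFunctionCall` returns the first `functionCall` part (parsing string args when they look like JSON), and `extractTextContent` joins the text parts, returning structured JSON verbatim when a lone part is valid JSON.

const candidate = {
  content: {
    parts: [{ functionCall: { name: 'get_weather', args: { city: 'Paris' } } }],
  },
}

const call = extractFunctionCall(candidate)
// call => { name: 'get_weather', args: { city: 'Paris' } }

const textCandidate = { content: { parts: [{ text: 'Hello' }, { text: 'world' }] } }
const text = extractTextContent(textCandidate)
// text => 'Hello\nworld'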

/**
 * Converts OpenAI-style request format to Gemini format
 */
export function convertToGeminiFormat(request: ProviderRequest): {
  contents: any[]
  tools: any[] | undefined
  systemInstruction: any | undefined
} {
  const contents: any[] = []
  let systemInstruction

  if (request.systemPrompt) {
    systemInstruction = { parts: [{ text: request.systemPrompt }] }
  }

  if (request.context) {
    contents.push({ role: 'user', parts: [{ text: request.context }] })
  }

  if (request.messages && request.messages.length > 0) {
    for (const message of request.messages) {
      if (message.role === 'system') {
        if (!systemInstruction) {
          systemInstruction = { parts: [{ text: message.content }] }
        } else {
          systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}`
        }
      } else if (message.role === 'user' || message.role === 'assistant') {
        const geminiRole = message.role === 'user' ? 'user' : 'model'

        if (message.content) {
          contents.push({ role: geminiRole, parts: [{ text: message.content }] })
        }

        if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) {
          const functionCalls = message.tool_calls.map((toolCall) => ({
            functionCall: {
              name: toolCall.function?.name,
              args: JSON.parse(toolCall.function?.arguments || '{}'),
            },
          }))

          contents.push({ role: 'model', parts: functionCalls })
        }
      } else if (message.role === 'tool') {
        contents.push({
          role: 'user',
          parts: [{ text: `Function result: ${message.content}` }],
        })
      }
    }
  }

  const tools = request.tools?.map((tool) => {
    const toolParameters = { ...(tool.parameters || {}) }

    if (toolParameters.properties) {
      const properties = { ...toolParameters.properties }
      const required = toolParameters.required ? [...toolParameters.required] : []

      for (const key in properties) {
        const prop = properties[key] as any

        if (prop.default !== undefined) {
          const { default: _, ...cleanProp } = prop
          properties[key] = cleanProp
        }
      }

      const parameters = {
        type: toolParameters.type || 'object',
        properties,
        ...(required.length > 0 ? { required } : {}),
      }

      return {
        name: tool.id,
        description: tool.description || `Execute the ${tool.id} function`,
        parameters: cleanSchemaForGemini(parameters),
      }
    }

    return {
      name: tool.id,
      description: tool.description || `Execute the ${tool.id} function`,
      parameters: cleanSchemaForGemini(toolParameters),
    }
  })

  return { contents, tools, systemInstruction }
}
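A hypothetical conversion, showing how the OpenAI-style fields map onto Gemini's shape: the system prompt becomes `systemInstruction`, assistant messages become `model` turns, tool-result messages are folded back in as `user` turns, and tool schemas are cleaned for Gemini.

const { contents, tools, systemInstruction } = convertToGeminiFormat({
  systemPrompt: 'You are a helpful assistant.',
  messages: [
    { role: 'user', content: 'What is the weather in Paris?' },
    { role: 'assistant', content: 'Let me check.' },
    { role: 'tool', content: '{"tempC": 21}' },
  ],
  tools: [
    {
      id: 'get_weather',
      description: 'Look up current weather',
      parameters: {
        type: 'object',
        additionalProperties: false,
        properties: { city: { type: 'string' } },
        required: ['city'],
      },
    },
  ],
} as any)

// systemInstruction => { parts: [{ text: 'You are a helpful assistant.' }] }
// contents          => a user turn, a model turn, then a user turn wrapping "Function result: {...}"
// tools[0]          => { name: 'get_weather', description: 'Look up current weather', parameters: {...} }
//                      with additionalProperties stripped by cleanSchemaForGemini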
|
||||
@@ -1,6 +1,8 @@
|
||||
import { Groq } from 'groq-sdk'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import { createReadableStreamFromGroqStream } from '@/providers/groq/utils'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
@@ -17,27 +19,6 @@ import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('GroqProvider')
|
||||
|
||||
/**
|
||||
* Helper to wrap Groq streaming into a browser-friendly ReadableStream
|
||||
* of raw assistant text chunks.
|
||||
*/
|
||||
function createReadableStreamFromGroqStream(groqStream: any): ReadableStream {
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of groqStream) {
|
||||
if (chunk.choices[0]?.delta?.content) {
|
||||
controller.enqueue(new TextEncoder().encode(chunk.choices[0].delta.content))
|
||||
}
|
||||
}
|
||||
controller.close()
|
||||
} catch (err) {
|
||||
controller.error(err)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
export const groqProvider: ProviderConfig = {
|
||||
id: 'groq',
|
||||
name: 'Groq',
|
||||
@@ -225,7 +206,6 @@ export const groqProvider: ProviderConfig = {
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10 // Prevent infinite loops
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
@@ -243,7 +223,7 @@ export const groqProvider: ProviderConfig = {
|
||||
]
|
||||
|
||||
try {
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
|
||||
23
apps/sim/providers/groq/utils.ts
Normal file
@@ -0,0 +1,23 @@
/**
 * Helper to wrap Groq streaming into a browser-friendly ReadableStream
 * of raw assistant text chunks.
 *
 * @param groqStream - The Groq streaming response
 * @returns A ReadableStream that emits text chunks
 */
export function createReadableStreamFromGroqStream(groqStream: any): ReadableStream {
  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of groqStream) {
          if (chunk.choices[0]?.delta?.content) {
            controller.enqueue(new TextEncoder().encode(chunk.choices[0].delta.content))
          }
        }
        controller.close()
      } catch (err) {
        controller.error(err)
      }
    },
  })
}
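For reference, a minimal consumer of the returned ReadableStream (standard Web Streams API, not Sim-specific code): the helper emits UTF-8 encoded text chunks, so the reader just decodes and appends.

async function collectStream(stream: ReadableStream): Promise<string> {
  const reader = stream.getReader()
  const decoder = new TextDecoder()
  let text = ''

  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    text += decoder.decode(value, { stream: true })
  }

  return text
}

// const content = await collectStream(createReadableStreamFromGroqStream(groqStream))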
|
||||
@@ -12,6 +12,12 @@ import {

const logger = createLogger('Providers')

/**
 * Maximum number of iterations for tool call loops to prevent infinite loops.
 * Used across all providers that support tool/function calling.
 */
export const MAX_TOOL_ITERATIONS = 20
|
||||
|
||||
function sanitizeRequest(request: ProviderRequest): ProviderRequest {
|
||||
const sanitizedRequest = { ...request }
|
||||
|
||||
@@ -44,7 +50,6 @@ export async function executeProviderRequest(
|
||||
}
|
||||
const sanitizedRequest = sanitizeRequest(request)
|
||||
|
||||
// If responseFormat is provided, modify the system prompt to enforce structured output
|
||||
if (sanitizedRequest.responseFormat) {
|
||||
if (
|
||||
typeof sanitizedRequest.responseFormat === 'string' &&
|
||||
@@ -53,12 +58,10 @@ export async function executeProviderRequest(
|
||||
logger.info('Empty response format provided, ignoring it')
|
||||
sanitizedRequest.responseFormat = undefined
|
||||
} else {
|
||||
// Generate structured output instructions
|
||||
const structuredOutputInstructions = generateStructuredOutputInstructions(
|
||||
sanitizedRequest.responseFormat
|
||||
)
|
||||
|
||||
// Only add additional instructions if they're not empty
|
||||
if (structuredOutputInstructions.trim()) {
|
||||
const originalPrompt = sanitizedRequest.systemPrompt || ''
|
||||
sanitizedRequest.systemPrompt =
|
||||
@@ -69,10 +72,8 @@ export async function executeProviderRequest(
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the request using the provider's implementation
|
||||
const response = await provider.executeRequest(sanitizedRequest)
|
||||
|
||||
// If we received a StreamingExecution or ReadableStream, just pass it through
|
||||
if (isStreamingExecution(response)) {
|
||||
logger.info('Provider returned StreamingExecution')
|
||||
return response
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import OpenAI from 'openai'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import { createReadableStreamFromMistralStream } from '@/providers/mistral/utils'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
@@ -17,40 +19,6 @@ import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('MistralProvider')
|
||||
|
||||
function createReadableStreamFromMistralStream(
|
||||
mistralStream: any,
|
||||
onComplete?: (content: string, usage?: any) => void
|
||||
): ReadableStream {
|
||||
let fullContent = ''
|
||||
let usageData: any = null
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of mistralStream) {
|
||||
if (chunk.usage) {
|
||||
usageData = chunk.usage
|
||||
}
|
||||
|
||||
const content = chunk.choices[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
|
||||
if (onComplete) {
|
||||
onComplete(fullContent, usageData)
|
||||
}
|
||||
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Mistral AI provider configuration
|
||||
*/
|
||||
@@ -288,7 +256,6 @@ export const mistralProvider: ProviderConfig = {
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10
|
||||
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
@@ -307,14 +274,14 @@ export const mistralProvider: ProviderConfig = {
|
||||
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
39
apps/sim/providers/mistral/utils.ts
Normal file
@@ -0,0 +1,39 @@
/**
 * Creates a ReadableStream from a Mistral AI streaming response
 * @param mistralStream - The Mistral AI stream object
 * @param onComplete - Optional callback when streaming completes
 * @returns A ReadableStream that yields text chunks
 */
export function createReadableStreamFromMistralStream(
  mistralStream: any,
  onComplete?: (content: string, usage?: any) => void
): ReadableStream {
  let fullContent = ''
  let usageData: any = null

  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of mistralStream) {
          if (chunk.usage) {
            usageData = chunk.usage
          }

          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }

        if (onComplete) {
          onComplete(fullContent, usageData)
        }

        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
}
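A brief sketch of how the optional `onComplete` callback is meant to be used (variable names are illustrative): the caller gets the stream immediately, and the accumulated text plus the final usage chunk arrive once the stream has drained.

let finalContent = ''
let finalUsage: any = null

const stream = createReadableStreamFromMistralStream(mistralStream, (content, usage) => {
  // Invoked after the last chunk: capture totals for logging and cost calculation.
  finalContent = content
  finalUsage = usage
})

// Pipe `stream` to the client; `finalContent` / `finalUsage` are populated when it closes.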
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
OllamaIcon,
|
||||
OpenAIIcon,
|
||||
OpenRouterIcon,
|
||||
VertexIcon,
|
||||
VllmIcon,
|
||||
xAIIcon,
|
||||
} from '@/components/icons'
|
||||
@@ -130,7 +131,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
capabilities: {
|
||||
reasoningEffort: {
|
||||
values: ['none', 'low', 'medium', 'high'],
|
||||
values: ['none', 'minimal', 'low', 'medium', 'high', 'xhigh'],
|
||||
},
|
||||
verbosity: {
|
||||
values: ['low', 'medium', 'high'],
|
||||
@@ -283,7 +284,11 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
output: 60,
|
||||
updatedAt: '2025-06-17',
|
||||
},
|
||||
capabilities: {},
|
||||
capabilities: {
|
||||
reasoningEffort: {
|
||||
values: ['low', 'medium', 'high'],
|
||||
},
|
||||
},
|
||||
contextWindow: 200000,
|
||||
},
|
||||
{
|
||||
@@ -294,7 +299,11 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
output: 8,
|
||||
updatedAt: '2025-06-17',
|
||||
},
|
||||
capabilities: {},
|
||||
capabilities: {
|
||||
reasoningEffort: {
|
||||
values: ['low', 'medium', 'high'],
|
||||
},
|
||||
},
|
||||
contextWindow: 128000,
|
||||
},
|
||||
{
|
||||
@@ -305,7 +314,11 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
output: 4.4,
|
||||
updatedAt: '2025-06-17',
|
||||
},
|
||||
capabilities: {},
|
||||
capabilities: {
|
||||
reasoningEffort: {
|
||||
values: ['low', 'medium', 'high'],
|
||||
},
|
||||
},
|
||||
contextWindow: 128000,
|
||||
},
|
||||
{
|
||||
@@ -383,7 +396,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
capabilities: {
|
||||
reasoningEffort: {
|
||||
values: ['none', 'low', 'medium', 'high'],
|
||||
values: ['none', 'minimal', 'low', 'medium', 'high', 'xhigh'],
|
||||
},
|
||||
verbosity: {
|
||||
values: ['low', 'medium', 'high'],
|
||||
@@ -536,7 +549,11 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
output: 40,
|
||||
updatedAt: '2025-06-15',
|
||||
},
|
||||
capabilities: {},
|
||||
capabilities: {
|
||||
reasoningEffort: {
|
||||
values: ['low', 'medium', 'high'],
|
||||
},
|
||||
},
|
||||
contextWindow: 128000,
|
||||
},
|
||||
{
|
||||
@@ -547,7 +564,11 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
output: 4.4,
|
||||
updatedAt: '2025-06-15',
|
||||
},
|
||||
capabilities: {},
|
||||
capabilities: {
|
||||
reasoningEffort: {
|
||||
values: ['low', 'medium', 'high'],
|
||||
},
|
||||
},
|
||||
contextWindow: 128000,
|
||||
},
|
||||
{
|
||||
@@ -708,9 +729,22 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
id: 'gemini-3-pro-preview',
|
||||
pricing: {
|
||||
input: 2.0,
|
||||
cachedInput: 1.0,
|
||||
cachedInput: 0.2,
|
||||
output: 12.0,
|
||||
updatedAt: '2025-11-18',
|
||||
updatedAt: '2025-12-17',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
{
|
||||
id: 'gemini-3-flash-preview',
|
||||
pricing: {
|
||||
input: 0.5,
|
||||
cachedInput: 0.05,
|
||||
output: 3.0,
|
||||
updatedAt: '2025-12-17',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
@@ -756,6 +790,132 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
contextWindow: 1048576,
|
||||
},
|
||||
{
|
||||
id: 'gemini-2.0-flash',
|
||||
pricing: {
|
||||
input: 0.1,
|
||||
output: 0.4,
|
||||
updatedAt: '2025-12-17',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
{
|
||||
id: 'gemini-2.0-flash-lite',
|
||||
pricing: {
|
||||
input: 0.075,
|
||||
output: 0.3,
|
||||
updatedAt: '2025-12-17',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
],
|
||||
},
|
||||
vertex: {
|
||||
id: 'vertex',
|
||||
name: 'Vertex AI',
|
||||
description: "Google's Vertex AI platform for Gemini models",
|
||||
defaultModel: 'vertex/gemini-2.5-pro',
|
||||
modelPatterns: [/^vertex\//],
|
||||
icon: VertexIcon,
|
||||
capabilities: {
|
||||
toolUsageControl: true,
|
||||
},
|
||||
models: [
|
||||
{
|
||||
id: 'vertex/gemini-3-pro-preview',
|
||||
pricing: {
|
||||
input: 2.0,
|
||||
cachedInput: 0.2,
|
||||
output: 12.0,
|
||||
updatedAt: '2025-12-17',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
{
|
||||
id: 'vertex/gemini-3-flash-preview',
|
||||
pricing: {
|
||||
input: 0.5,
|
||||
cachedInput: 0.05,
|
||||
output: 3.0,
|
||||
updatedAt: '2025-12-17',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
{
|
||||
id: 'vertex/gemini-2.5-pro',
|
||||
pricing: {
|
||||
input: 1.25,
|
||||
cachedInput: 0.125,
|
||||
output: 10.0,
|
||||
updatedAt: '2025-12-02',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
},
|
||||
contextWindow: 1048576,
|
||||
},
|
||||
{
|
||||
id: 'vertex/gemini-2.5-flash',
|
||||
pricing: {
|
||||
input: 0.3,
|
||||
cachedInput: 0.03,
|
||||
output: 2.5,
|
||||
updatedAt: '2025-12-02',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
},
|
||||
contextWindow: 1048576,
|
||||
},
|
||||
{
|
||||
id: 'vertex/gemini-2.5-flash-lite',
|
||||
pricing: {
|
||||
input: 0.1,
|
||||
cachedInput: 0.01,
|
||||
output: 0.4,
|
||||
updatedAt: '2025-12-02',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
},
|
||||
contextWindow: 1048576,
|
||||
},
|
||||
{
|
||||
id: 'vertex/gemini-2.0-flash',
|
||||
pricing: {
|
||||
input: 0.1,
|
||||
output: 0.4,
|
||||
updatedAt: '2025-12-17',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
{
|
||||
id: 'vertex/gemini-2.0-flash-lite',
|
||||
pricing: {
|
||||
input: 0.075,
|
||||
output: 0.3,
|
||||
updatedAt: '2025-12-17',
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
],
|
||||
},
|
||||
deepseek: {
|
||||
@@ -1708,6 +1868,20 @@ export function getModelsWithReasoningEffort(): string[] {
|
||||
return models
|
||||
}
|
||||
|
||||
/**
 * Get the reasoning effort values for a specific model
 * Returns the valid options for that model, or null if the model doesn't support reasoning effort
 */
export function getReasoningEffortValuesForModel(modelId: string): string[] | null {
  for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
    const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase())
    if (model?.capabilities.reasoningEffort) {
      return model.capabilities.reasoningEffort.values
    }
  }
  return null
}

/**
 * Get all models that support verbosity
 */
@@ -1722,3 +1896,17 @@ export function getModelsWithVerbosity(): string[] {
  }
  return models
}

/**
 * Get the verbosity values for a specific model
 * Returns the valid options for that model, or null if the model doesn't support verbosity
 */
export function getVerbosityValuesForModel(modelId: string): string[] | null {
  for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
    const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase())
    if (model?.capabilities.verbosity) {
      return model.capabilities.verbosity.values
    }
  }
  return null
}
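A small sketch of the intended lookup (the helper below is hypothetical UI glue, not part of the commit): both functions scan every provider's model list case-insensitively and return the declared option values, or null when the model does not advertise the capability, which lets UI code hide or repopulate a dropdown per model.

function buildCapabilityOptions(modelId: string) {
  const reasoningValues = getReasoningEffortValuesForModel(modelId) // string[] | null
  const verbosityValues = getVerbosityValuesForModel(modelId) // string[] | null

  return {
    // Empty arrays signal the UI to hide the corresponding dropdown for this model.
    reasoning: (reasoningValues ?? []).map((value) => ({ label: value, id: value })),
    verbosity: (verbosityValues ?? []).map((value) => ({ label: value, id: value })),
  }
}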
|
||||
|
||||
@@ -2,7 +2,9 @@ import OpenAI from 'openai'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import type { ModelsObject } from '@/providers/ollama/types'
|
||||
import { createReadableStreamFromOllamaStream } from '@/providers/ollama/utils'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
ProviderRequest,
|
||||
@@ -16,46 +18,6 @@ import { executeTool } from '@/tools'
|
||||
const logger = createLogger('OllamaProvider')
|
||||
const OLLAMA_HOST = env.OLLAMA_URL || 'http://localhost:11434'
|
||||
|
||||
/**
|
||||
* Helper function to convert an Ollama stream to a standard ReadableStream
|
||||
* and collect completion metrics
|
||||
*/
|
||||
function createReadableStreamFromOllamaStream(
|
||||
ollamaStream: any,
|
||||
onComplete?: (content: string, usage?: any) => void
|
||||
): ReadableStream {
|
||||
let fullContent = ''
|
||||
let usageData: any = null
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of ollamaStream) {
|
||||
// Check for usage data in the final chunk
|
||||
if (chunk.usage) {
|
||||
usageData = chunk.usage
|
||||
}
|
||||
|
||||
const content = chunk.choices[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
|
||||
// Once stream is complete, call the completion callback with the final content and usage
|
||||
if (onComplete) {
|
||||
onComplete(fullContent, usageData)
|
||||
}
|
||||
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
export const ollamaProvider: ProviderConfig = {
|
||||
id: 'ollama',
|
||||
name: 'Ollama',
|
||||
@@ -334,7 +296,6 @@ export const ollamaProvider: ProviderConfig = {
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10 // Prevent infinite loops
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
@@ -351,7 +312,7 @@ export const ollamaProvider: ProviderConfig = {
|
||||
},
|
||||
]
|
||||
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
@@ -359,7 +320,7 @@ export const ollamaProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
// Track time for tool calls in this batch
|
||||
|
||||
37
apps/sim/providers/ollama/utils.ts
Normal file
@@ -0,0 +1,37 @@
/**
 * Helper function to convert an Ollama stream to a standard ReadableStream
 * and collect completion metrics
 */
export function createReadableStreamFromOllamaStream(
  ollamaStream: any,
  onComplete?: (content: string, usage?: any) => void
): ReadableStream {
  let fullContent = ''
  let usageData: any = null

  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of ollamaStream) {
          if (chunk.usage) {
            usageData = chunk.usage
          }

          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }

        if (onComplete) {
          onComplete(fullContent, usageData)
        }

        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
}
|
||||
@@ -1,7 +1,9 @@
|
||||
import OpenAI from 'openai'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import { createReadableStreamFromOpenAIStream } from '@/providers/openai/utils'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
ProviderRequest,
|
||||
@@ -17,46 +19,6 @@ import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('OpenAIProvider')
|
||||
|
||||
/**
|
||||
* Helper function to convert an OpenAI stream to a standard ReadableStream
|
||||
* and collect completion metrics
|
||||
*/
|
||||
function createReadableStreamFromOpenAIStream(
|
||||
openaiStream: any,
|
||||
onComplete?: (content: string, usage?: any) => void
|
||||
): ReadableStream {
|
||||
let fullContent = ''
|
||||
let usageData: any = null
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of openaiStream) {
|
||||
// Check for usage data in the final chunk
|
||||
if (chunk.usage) {
|
||||
usageData = chunk.usage
|
||||
}
|
||||
|
||||
const content = chunk.choices[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
|
||||
// Once stream is complete, call the completion callback with the final content and usage
|
||||
if (onComplete) {
|
||||
onComplete(fullContent, usageData)
|
||||
}
|
||||
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI provider configuration
|
||||
*/
|
||||
@@ -319,7 +281,6 @@ export const openaiProvider: ProviderConfig = {
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10 // Prevent infinite loops
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
@@ -342,7 +303,7 @@ export const openaiProvider: ProviderConfig = {
|
||||
// Check if a forced tool was used in the first response
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
@@ -350,7 +311,7 @@ export const openaiProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
// Track time for tool calls in this batch
|
||||
|
||||
37
apps/sim/providers/openai/utils.ts
Normal file
@@ -0,0 +1,37 @@
/**
 * Helper function to convert an OpenAI stream to a standard ReadableStream
 * and collect completion metrics
 */
export function createReadableStreamFromOpenAIStream(
  openaiStream: any,
  onComplete?: (content: string, usage?: any) => void
): ReadableStream {
  let fullContent = ''
  let usageData: any = null

  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of openaiStream) {
          if (chunk.usage) {
            usageData = chunk.usage
          }

          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }

        if (onComplete) {
          onComplete(fullContent, usageData)
        }

        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
}
|
||||
@@ -1,56 +1,23 @@
|
||||
import OpenAI from 'openai'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import {
|
||||
checkForForcedToolUsage,
|
||||
createReadableStreamFromOpenAIStream,
|
||||
} from '@/providers/openrouter/utils'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
ProviderRequest,
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('OpenRouterProvider')
|
||||
|
||||
function createReadableStreamFromOpenAIStream(
|
||||
openaiStream: any,
|
||||
onComplete?: (content: string, usage?: any) => void
|
||||
): ReadableStream {
|
||||
let fullContent = ''
|
||||
let usageData: any = null
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of openaiStream) {
|
||||
if (chunk.usage) {
|
||||
usageData = chunk.usage
|
||||
}
|
||||
|
||||
const content = chunk.choices[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
|
||||
if (onComplete) {
|
||||
onComplete(fullContent, usageData)
|
||||
}
|
||||
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
export const openRouterProvider: ProviderConfig = {
|
||||
id: 'openrouter',
|
||||
name: 'OpenRouter',
|
||||
@@ -227,7 +194,6 @@ export const openRouterProvider: ProviderConfig = {
|
||||
const toolResults = [] as any[]
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
const MAX_ITERATIONS = 10
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
let hasUsedForcedTool = false
|
||||
@@ -241,28 +207,16 @@ export const openRouterProvider: ProviderConfig = {
|
||||
},
|
||||
]
|
||||
|
||||
const checkForForcedToolUsage = (
|
||||
response: any,
|
||||
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }
|
||||
) => {
|
||||
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
|
||||
const toolCallsResponse = response.choices[0].message.tool_calls
|
||||
const result = trackForcedToolUsage(
|
||||
toolCallsResponse,
|
||||
toolChoice,
|
||||
logger,
|
||||
'openrouter',
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = result.hasUsedForcedTool
|
||||
usedForcedTools = result.usedForcedTools
|
||||
}
|
||||
}
|
||||
const forcedToolResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
originalToolChoice,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = forcedToolResult.hasUsedForcedTool
|
||||
usedForcedTools = forcedToolResult.usedForcedTools
|
||||
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
|
||||
while (iterationCount < MAX_ITERATIONS) {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
@@ -359,7 +313,14 @@ export const openRouterProvider: ProviderConfig = {
|
||||
|
||||
const nextModelStartTime = Date.now()
|
||||
currentResponse = await client.chat.completions.create(nextPayload)
|
||||
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
|
||||
const nextForcedToolResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
nextPayload.tool_choice,
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = nextForcedToolResult.hasUsedForcedTool
|
||||
usedForcedTools = nextForcedToolResult.usedForcedTools
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
timeSegments.push({
|
||||
|
||||
78
apps/sim/providers/openrouter/utils.ts
Normal file
@@ -0,0 +1,78 @@
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { trackForcedToolUsage } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('OpenRouterProvider')
|
||||
|
||||
/**
|
||||
* Creates a ReadableStream from an OpenAI-compatible stream response
|
||||
* @param openaiStream - The OpenAI stream to convert
|
||||
* @param onComplete - Optional callback when streaming is complete with content and usage data
|
||||
* @returns ReadableStream that emits text chunks
|
||||
*/
|
||||
export function createReadableStreamFromOpenAIStream(
|
||||
openaiStream: any,
|
||||
onComplete?: (content: string, usage?: any) => void
|
||||
): ReadableStream {
|
||||
let fullContent = ''
|
||||
let usageData: any = null
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of openaiStream) {
|
||||
if (chunk.usage) {
|
||||
usageData = chunk.usage
|
||||
}
|
||||
|
||||
const content = chunk.choices[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
|
||||
if (onComplete) {
|
||||
onComplete(fullContent, usageData)
|
||||
}
|
||||
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a forced tool was used in the response and updates tracking
|
||||
* @param response - The API response containing tool calls
|
||||
* @param toolChoice - The tool choice configuration (string or object)
|
||||
* @param forcedTools - Array of forced tool names
|
||||
* @param usedForcedTools - Array of already used forced tools
|
||||
* @returns Object with hasUsedForcedTool flag and updated usedForcedTools array
|
||||
*/
|
||||
export function checkForForcedToolUsage(
|
||||
response: any,
|
||||
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
|
||||
forcedTools: string[],
|
||||
usedForcedTools: string[]
|
||||
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
|
||||
let hasUsedForcedTool = false
|
||||
let updatedUsedForcedTools = usedForcedTools
|
||||
|
||||
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
|
||||
const toolCallsResponse = response.choices[0].message.tool_calls
|
||||
const result = trackForcedToolUsage(
|
||||
toolCallsResponse,
|
||||
toolChoice,
|
||||
logger,
|
||||
'openrouter',
|
||||
forcedTools,
|
||||
updatedUsedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = result.hasUsedForcedTool
|
||||
updatedUsedForcedTools = result.usedForcedTools
|
||||
}
|
||||
|
||||
return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
|
||||
}
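A small usage sketch with an illustrative response shape; whether the flag flips depends on trackForcedToolUsage matching the forced tool:

// Hypothetical response fragment containing one tool call named 'search_docs'.
const response = {
  choices: [{ message: { tool_calls: [{ function: { name: 'search_docs', arguments: '{}' } }] } }],
}

const result = checkForForcedToolUsage(
  response,
  { type: 'function', function: { name: 'search_docs' } }, // forced tool_choice
  ['search_docs'], // forcedTools
  [] // usedForcedTools so far
)
// If the forced tool was matched, result.hasUsedForcedTool is true and
// result.usedForcedTools carries the updated list into the next iteration.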
|
||||
@@ -5,6 +5,7 @@ export type ProviderId =
|
||||
| 'azure-openai'
|
||||
| 'anthropic'
|
||||
| 'google'
|
||||
| 'vertex'
|
||||
| 'deepseek'
|
||||
| 'xai'
|
||||
| 'cerebras'
|
||||
@@ -163,6 +164,9 @@ export interface ProviderRequest {
|
||||
// Azure OpenAI specific parameters
|
||||
azureEndpoint?: string
|
||||
azureApiVersion?: string
|
||||
// Vertex AI specific parameters
|
||||
vertexProject?: string
|
||||
vertexLocation?: string
|
||||
// GPT-5 specific parameters
|
||||
reasoningEffort?: string
|
||||
verbosity?: string
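A hedged example of a request carrying the new Vertex AI fields; values are placeholders, the VERTEX_ACCESS_TOKEN variable is illustrative, and only the relevant fields are shown:

const request: Partial<ProviderRequest> = {
  model: 'vertex/gemini-2.5-pro',
  systemPrompt: 'You are a helpful assistant.',
  apiKey: process.env.VERTEX_ACCESS_TOKEN ?? '', // OAuth access token for Vertex, not a Gemini API key
  vertexProject: 'my-gcp-project',
  vertexLocation: 'us-central1',
}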
|
||||
|
||||
@@ -383,6 +383,17 @@ describe('Model Capabilities', () => {
|
||||
expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/gpt-5-mini')
|
||||
expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/gpt-5-nano')
|
||||
|
||||
// Should contain gpt-5.2 models
|
||||
expect(MODELS_WITH_REASONING_EFFORT).toContain('gpt-5.2')
|
||||
expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/gpt-5.2')
|
||||
|
||||
// Should contain o-series reasoning models (reasoning_effort added Dec 17, 2024)
|
||||
expect(MODELS_WITH_REASONING_EFFORT).toContain('o1')
|
||||
expect(MODELS_WITH_REASONING_EFFORT).toContain('o3')
|
||||
expect(MODELS_WITH_REASONING_EFFORT).toContain('o4-mini')
|
||||
expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/o3')
|
||||
expect(MODELS_WITH_REASONING_EFFORT).toContain('azure/o4-mini')
|
||||
|
||||
// Should NOT contain non-reasoning GPT-5 models
|
||||
expect(MODELS_WITH_REASONING_EFFORT).not.toContain('gpt-5-chat-latest')
|
||||
expect(MODELS_WITH_REASONING_EFFORT).not.toContain('azure/gpt-5-chat-latest')
|
||||
@@ -390,7 +401,6 @@ describe('Model Capabilities', () => {
|
||||
// Should NOT contain other models
|
||||
expect(MODELS_WITH_REASONING_EFFORT).not.toContain('gpt-4o')
|
||||
expect(MODELS_WITH_REASONING_EFFORT).not.toContain('claude-sonnet-4-0')
|
||||
expect(MODELS_WITH_REASONING_EFFORT).not.toContain('o1')
|
||||
})
|
||||
|
||||
it.concurrent('should have correct models in MODELS_WITH_VERBOSITY', () => {
|
||||
@@ -409,19 +419,37 @@ describe('Model Capabilities', () => {
|
||||
expect(MODELS_WITH_VERBOSITY).toContain('azure/gpt-5-mini')
|
||||
expect(MODELS_WITH_VERBOSITY).toContain('azure/gpt-5-nano')
|
||||
|
||||
// Should contain gpt-5.2 models
|
||||
expect(MODELS_WITH_VERBOSITY).toContain('gpt-5.2')
|
||||
expect(MODELS_WITH_VERBOSITY).toContain('azure/gpt-5.2')
|
||||
|
||||
// Should NOT contain non-reasoning GPT-5 models
|
||||
expect(MODELS_WITH_VERBOSITY).not.toContain('gpt-5-chat-latest')
|
||||
expect(MODELS_WITH_VERBOSITY).not.toContain('azure/gpt-5-chat-latest')
|
||||
|
||||
// Should NOT contain o-series models (they support reasoning_effort but not verbosity)
|
||||
expect(MODELS_WITH_VERBOSITY).not.toContain('o1')
|
||||
expect(MODELS_WITH_VERBOSITY).not.toContain('o3')
|
||||
expect(MODELS_WITH_VERBOSITY).not.toContain('o4-mini')
|
||||
|
||||
// Should NOT contain other models
|
||||
expect(MODELS_WITH_VERBOSITY).not.toContain('gpt-4o')
|
||||
expect(MODELS_WITH_VERBOSITY).not.toContain('claude-sonnet-4-0')
|
||||
expect(MODELS_WITH_VERBOSITY).not.toContain('o1')
|
||||
})
|
||||
|
||||
it.concurrent('should have same models in both reasoning effort and verbosity arrays', () => {
|
||||
// GPT-5 models that support reasoning effort should also support verbosity and vice versa
|
||||
expect(MODELS_WITH_REASONING_EFFORT.sort()).toEqual(MODELS_WITH_VERBOSITY.sort())
|
||||
it.concurrent('should have GPT-5 models in both reasoning effort and verbosity arrays', () => {
|
||||
// GPT-5 series models support both reasoning effort and verbosity
|
||||
const gpt5ModelsWithReasoningEffort = MODELS_WITH_REASONING_EFFORT.filter(
|
||||
(m) => m.includes('gpt-5') && !m.includes('chat-latest')
|
||||
)
|
||||
const gpt5ModelsWithVerbosity = MODELS_WITH_VERBOSITY.filter(
|
||||
(m) => m.includes('gpt-5') && !m.includes('chat-latest')
|
||||
)
|
||||
expect(gpt5ModelsWithReasoningEffort.sort()).toEqual(gpt5ModelsWithVerbosity.sort())
|
||||
|
||||
// o-series models have reasoning effort but NOT verbosity
|
||||
expect(MODELS_WITH_REASONING_EFFORT).toContain('o1')
|
||||
expect(MODELS_WITH_VERBOSITY).not.toContain('o1')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -21,6 +21,8 @@ import {
|
||||
getModelsWithVerbosity,
|
||||
getProviderModels as getProviderModelsFromDefinitions,
|
||||
getProvidersWithToolUsageControl,
|
||||
getReasoningEffortValuesForModel as getReasoningEffortValuesForModelFromDefinitions,
|
||||
getVerbosityValuesForModel as getVerbosityValuesForModelFromDefinitions,
|
||||
PROVIDER_DEFINITIONS,
|
||||
supportsTemperature as supportsTemperatureFromDefinitions,
|
||||
supportsToolUsageControl as supportsToolUsageControlFromDefinitions,
|
||||
@@ -30,6 +32,7 @@ import { ollamaProvider } from '@/providers/ollama'
|
||||
import { openaiProvider } from '@/providers/openai'
|
||||
import { openRouterProvider } from '@/providers/openrouter'
|
||||
import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types'
|
||||
import { vertexProvider } from '@/providers/vertex'
|
||||
import { vllmProvider } from '@/providers/vllm'
|
||||
import { xAIProvider } from '@/providers/xai'
|
||||
import { useCustomToolsStore } from '@/stores/custom-tools/store'
|
||||
@@ -67,6 +70,11 @@ export const providers: Record<
|
||||
models: getProviderModelsFromDefinitions('google'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.google.modelPatterns,
|
||||
},
|
||||
vertex: {
|
||||
...vertexProvider,
|
||||
models: getProviderModelsFromDefinitions('vertex'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.vertex.modelPatterns,
|
||||
},
|
||||
deepseek: {
|
||||
...deepseekProvider,
|
||||
models: getProviderModelsFromDefinitions('deepseek'),
|
||||
@@ -274,16 +282,12 @@ export function getProviderIcon(model: string): React.ComponentType<{ className?
|
||||
}
|
||||
|
||||
export function generateStructuredOutputInstructions(responseFormat: any): string {
|
||||
// Handle null/undefined input
|
||||
if (!responseFormat) return ''
|
||||
|
||||
// If using the new JSON Schema format, don't add additional instructions
|
||||
// This is necessary because providers now handle the schema directly
|
||||
if (responseFormat.schema || (responseFormat.type === 'object' && responseFormat.properties)) {
|
||||
return ''
|
||||
}
|
||||
|
||||
// Handle legacy format with fields array
|
||||
if (!responseFormat.fields) return ''
|
||||
|
||||
function generateFieldStructure(field: any): string {
|
||||
@@ -335,10 +339,8 @@ Each metric should be an object containing 'score' (number) and 'reasoning' (str
|
||||
}
|
||||
|
||||
export function extractAndParseJSON(content: string): any {
|
||||
// First clean up the string
|
||||
const trimmed = content.trim()
|
||||
|
||||
// Find the first '{' and last '}'
|
||||
const firstBrace = trimmed.indexOf('{')
|
||||
const lastBrace = trimmed.lastIndexOf('}')
|
||||
|
||||
@@ -346,17 +348,15 @@ export function extractAndParseJSON(content: string): any {
|
||||
throw new Error('No JSON object found in content')
|
||||
}
|
||||
|
||||
// Extract just the JSON part
|
||||
const jsonStr = trimmed.slice(firstBrace, lastBrace + 1)
|
||||
|
||||
try {
|
||||
return JSON.parse(jsonStr)
|
||||
} catch (_error) {
|
||||
// If parsing fails, try to clean up common issues
|
||||
const cleaned = jsonStr
|
||||
.replace(/\n/g, ' ') // Remove newlines
|
||||
.replace(/\s+/g, ' ') // Normalize whitespace
|
||||
.replace(/,\s*([}\]])/g, '$1') // Remove trailing commas
|
||||
.replace(/\n/g, ' ')
|
||||
.replace(/\s+/g, ' ')
|
||||
.replace(/,\s*([}\]])/g, '$1')
|
||||
|
||||
try {
|
||||
return JSON.parse(cleaned)
|
||||
@@ -386,10 +386,10 @@ export function transformCustomTool(customTool: any): ProviderToolConfig {
|
||||
}
|
||||
|
||||
return {
|
||||
id: `custom_${customTool.id}`, // Prefix with 'custom_' to identify custom tools
|
||||
id: `custom_${customTool.id}`,
|
||||
name: schema.function.name,
|
||||
description: schema.function.description || '',
|
||||
params: {}, // This will be derived from parameters
|
||||
params: {},
|
||||
parameters: {
|
||||
type: schema.function.parameters.type,
|
||||
properties: schema.function.parameters.properties,
|
||||
@@ -402,10 +402,8 @@ export function transformCustomTool(customTool: any): ProviderToolConfig {
|
||||
* Gets all available custom tools as provider tool configs
|
||||
*/
|
||||
export function getCustomTools(): ProviderToolConfig[] {
|
||||
// Get custom tools from the store
|
||||
const customTools = useCustomToolsStore.getState().getAllTools()
|
||||
|
||||
// Transform each custom tool into a provider tool config
|
||||
return customTools.map(transformCustomTool)
|
||||
}
|
||||
|
||||
@@ -427,20 +425,16 @@ export async function transformBlockTool(
|
||||
): Promise<ProviderToolConfig | null> {
|
||||
const { selectedOperation, getAllBlocks, getTool, getToolAsync } = options
|
||||
|
||||
// Get the block definition
|
||||
const blockDef = getAllBlocks().find((b: any) => b.type === block.type)
|
||||
if (!blockDef) {
|
||||
logger.warn(`Block definition not found for type: ${block.type}`)
|
||||
return null
|
||||
}
|
||||
|
||||
// If the block has multiple operations, use the selected one or the first one
|
||||
let toolId: string | null = null
|
||||
|
||||
if ((blockDef.tools?.access?.length || 0) > 1) {
|
||||
// If we have an operation dropdown in the block and a selected operation
|
||||
if (selectedOperation && blockDef.tools?.config?.tool) {
|
||||
// Use the block's tool selection function to get the right tool
|
||||
try {
|
||||
toolId = blockDef.tools.config.tool({
|
||||
...block.params,
|
||||
@@ -455,11 +449,9 @@ export async function transformBlockTool(
|
||||
return null
|
||||
}
|
||||
} else {
|
||||
// Default to first tool if no operation specified
|
||||
toolId = blockDef.tools.access[0]
|
||||
}
|
||||
} else {
|
||||
// Single tool case
|
||||
toolId = blockDef.tools?.access?.[0] || null
|
||||
}
|
||||
|
||||
@@ -468,14 +460,11 @@ export async function transformBlockTool(
|
||||
return null
|
||||
}
|
||||
|
||||
// Get the tool config - check if it's a custom tool that needs async fetching
|
||||
let toolConfig: any
|
||||
|
||||
if (toolId.startsWith('custom_') && getToolAsync) {
|
||||
// Use the async version for custom tools
|
||||
toolConfig = await getToolAsync(toolId)
|
||||
} else {
|
||||
// Use the synchronous version for built-in tools
|
||||
toolConfig = getTool(toolId)
|
||||
}
|
||||
|
||||
@@ -484,16 +473,12 @@ export async function transformBlockTool(
|
||||
return null
|
||||
}
|
||||
|
||||
// Import the new tool parameter utilities
|
||||
const { createLLMToolSchema } = await import('@/tools/params')
|
||||
|
||||
// Get user-provided parameters from the block
|
||||
const userProvidedParams = block.params || {}
|
||||
|
||||
// Create LLM schema that excludes user-provided parameters
|
||||
const llmSchema = await createLLMToolSchema(toolConfig, userProvidedParams)
|
||||
|
||||
// Return formatted tool config
|
||||
return {
|
||||
id: toolConfig.id,
|
||||
name: toolConfig.name,
|
||||
@@ -521,15 +506,12 @@ export function calculateCost(
|
||||
inputMultiplier?: number,
|
||||
outputMultiplier?: number
|
||||
) {
|
||||
// First check if it's an embedding model
|
||||
let pricing = getEmbeddingModelPricing(model)
|
||||
|
||||
// If not found, check chat models
|
||||
if (!pricing) {
|
||||
pricing = getModelPricingFromDefinitions(model)
|
||||
}
|
||||
|
||||
// If no pricing found, return default pricing
|
||||
if (!pricing) {
|
||||
const defaultPricing = {
|
||||
input: 1.0,
|
||||
@@ -545,8 +527,6 @@ export function calculateCost(
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate costs in USD
|
||||
// Convert from "per million tokens" to "per token" by dividing by 1,000,000
|
||||
const inputCost =
|
||||
promptTokens *
|
||||
(useCachedInput && pricing.cachedInput
|
||||
@@ -559,7 +539,7 @@ export function calculateCost(
|
||||
const finalTotalCost = finalInputCost + finalOutputCost
|
||||
|
||||
return {
|
||||
input: Number.parseFloat(finalInputCost.toFixed(8)), // Use 8 decimal places for small costs
|
||||
input: Number.parseFloat(finalInputCost.toFixed(8)),
|
||||
output: Number.parseFloat(finalOutputCost.toFixed(8)),
|
||||
total: Number.parseFloat(finalTotalCost.toFixed(8)),
|
||||
pricing,
|
||||
@@ -997,6 +977,22 @@ export function supportsToolUsageControl(provider: string): boolean {
|
||||
return supportsToolUsageControlFromDefinitions(provider)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get reasoning effort values for a specific model
|
||||
* Returns the valid options for that model, or null if the model doesn't support reasoning effort
|
||||
*/
|
||||
export function getReasoningEffortValuesForModel(model: string): string[] | null {
|
||||
return getReasoningEffortValuesForModelFromDefinitions(model)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get verbosity values for a specific model
|
||||
* Returns the valid options for that model, or null if the model doesn't support verbosity
|
||||
*/
|
||||
export function getVerbosityValuesForModel(model: string): string[] | null {
|
||||
return getVerbosityValuesForModelFromDefinitions(model)
|
||||
}
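A sketch of how the two new helpers can drive dynamic UI options; the concrete option values depend on the model definitions, so the comments are examples rather than guarantees:

const model = 'gpt-5.2'
const effortOptions = getReasoningEffortValuesForModel(model) // e.g. ['low', 'medium', 'high'], or null if unsupported
const verbosityOptions = getVerbosityValuesForModel(model) // e.g. ['low', 'medium', 'high'], or null if unsupported

if (effortOptions) {
  // render a reasoning-effort dropdown from effortOptions
}
if (verbosityOptions) {
  // render a verbosity dropdown from verbosityOptions
}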
|
||||
|
||||
/**
|
||||
* Prepare tool execution parameters, separating tool parameters from system parameters
|
||||
*/
|
||||
|
||||
899
apps/sim/providers/vertex/index.ts
Normal file
@@ -0,0 +1,899 @@
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import {
|
||||
cleanSchemaForGemini,
|
||||
convertToGeminiFormat,
|
||||
extractFunctionCall,
|
||||
extractTextContent,
|
||||
} from '@/providers/google/utils'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
ProviderRequest,
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { buildVertexEndpoint, createReadableStreamFromVertexStream } from '@/providers/vertex/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('VertexProvider')
|
||||
|
||||
/**
|
||||
* Vertex AI provider configuration
|
||||
*/
|
||||
export const vertexProvider: ProviderConfig = {
|
||||
id: 'vertex',
|
||||
name: 'Vertex AI',
|
||||
description: "Google's Vertex AI platform for Gemini models",
|
||||
version: '1.0.0',
|
||||
models: getProviderModels('vertex'),
|
||||
defaultModel: getProviderDefaultModel('vertex'),
|
||||
|
||||
executeRequest: async (
|
||||
request: ProviderRequest
|
||||
): Promise<ProviderResponse | StreamingExecution> => {
|
||||
const vertexProject = env.VERTEX_PROJECT || request.vertexProject
|
||||
const vertexLocation = env.VERTEX_LOCATION || request.vertexLocation || 'us-central1'
|
||||
|
||||
if (!vertexProject) {
|
||||
throw new Error(
|
||||
'Vertex AI project is required. Please provide it via VERTEX_PROJECT environment variable or vertexProject parameter.'
|
||||
)
|
||||
}
|
||||
|
||||
if (!request.apiKey) {
|
||||
throw new Error(
|
||||
'Access token is required for Vertex AI. Run `gcloud auth print-access-token` to get one, or use a service account.'
|
||||
)
|
||||
}
|
||||
|
||||
logger.info('Preparing Vertex AI request', {
|
||||
model: request.model || 'vertex/gemini-2.5-pro',
|
||||
hasSystemPrompt: !!request.systemPrompt,
|
||||
hasMessages: !!request.messages?.length,
|
||||
hasTools: !!request.tools?.length,
|
||||
toolCount: request.tools?.length || 0,
|
||||
hasResponseFormat: !!request.responseFormat,
|
||||
streaming: !!request.stream,
|
||||
project: vertexProject,
|
||||
location: vertexLocation,
|
||||
})
|
||||
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
|
||||
|
||||
const requestedModel = (request.model || 'vertex/gemini-2.5-pro').replace('vertex/', '')
|
||||
|
||||
const payload: any = {
|
||||
contents,
|
||||
generationConfig: {},
|
||||
}
|
||||
|
||||
if (request.temperature !== undefined && request.temperature !== null) {
|
||||
payload.generationConfig.temperature = request.temperature
|
||||
}
|
||||
|
||||
if (request.maxTokens !== undefined) {
|
||||
payload.generationConfig.maxOutputTokens = request.maxTokens
|
||||
}
|
||||
|
||||
if (systemInstruction) {
|
||||
payload.systemInstruction = systemInstruction
|
||||
}
|
||||
|
||||
if (request.responseFormat && !tools?.length) {
|
||||
const responseFormatSchema = request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
payload.generationConfig.responseMimeType = 'application/json'
|
||||
payload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info('Using Vertex AI native structured output format', {
|
||||
hasSchema: !!cleanSchema,
|
||||
mimeType: 'application/json',
|
||||
})
|
||||
} else if (request.responseFormat && tools?.length) {
|
||||
logger.warn(
|
||||
'Vertex AI does not support structured output (responseFormat) with function calling (tools). Structured output will be ignored.'
|
||||
)
|
||||
}
|
||||
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
|
||||
if (tools?.length) {
|
||||
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'google')
|
||||
const { tools: filteredTools, toolConfig } = preparedTools
|
||||
|
||||
if (filteredTools?.length) {
|
||||
payload.tools = [
|
||||
{
|
||||
functionDeclarations: filteredTools,
|
||||
},
|
||||
]
|
||||
|
||||
if (toolConfig) {
|
||||
payload.toolConfig = toolConfig
|
||||
}
|
||||
|
||||
logger.info('Vertex AI request with tools:', {
|
||||
toolCount: filteredTools.length,
|
||||
model: requestedModel,
|
||||
tools: filteredTools.map((t) => t.name),
|
||||
hasToolConfig: !!toolConfig,
|
||||
toolConfig: toolConfig,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const initialCallTime = Date.now()
|
||||
const shouldStream = !!(request.stream && !tools?.length)
|
||||
|
||||
const endpoint = buildVertexEndpoint(
|
||||
vertexProject,
|
||||
vertexLocation,
|
||||
requestedModel,
|
||||
shouldStream
|
||||
)
|
||||
|
||||
if (request.stream && tools?.length) {
|
||||
logger.info('Streaming disabled for initial request due to tools presence', {
|
||||
toolCount: tools.length,
|
||||
willStreamAfterTools: true,
|
||||
})
|
||||
}
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${request.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const responseText = await response.text()
|
||||
logger.error('Vertex AI API error details:', {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
responseBody: responseText,
|
||||
})
|
||||
throw new Error(`Vertex AI API error: ${response.status} ${response.statusText}`)
|
||||
}
|
||||
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
if (shouldStream) {
|
||||
logger.info('Handling Vertex AI streaming response')
|
||||
|
||||
const streamingResult: StreamingExecution = {
|
||||
stream: null as any,
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
},
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: firstResponseTime,
|
||||
modelTime: firstResponseTime,
|
||||
toolsTime: 0,
|
||||
firstResponseTime,
|
||||
iterations: 1,
|
||||
timeSegments: [
|
||||
{
|
||||
type: 'model',
|
||||
name: 'Initial streaming response',
|
||||
startTime: initialCallTime,
|
||||
endTime: initialCallTime + firstResponseTime,
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
isStreaming: true,
|
||||
},
|
||||
}
|
||||
|
||||
streamingResult.stream = createReadableStreamFromVertexStream(
|
||||
response,
|
||||
(content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
if (streamingResult.execution.output.providerTiming) {
|
||||
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
|
||||
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
|
||||
streamEndTime
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.promptTokenCount || 0,
|
||||
completion: usage.candidatesTokenCount || 0,
|
||||
total:
|
||||
usage.totalTokenCount ||
|
||||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0),
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return streamingResult
|
||||
}
|
||||
|
||||
let geminiResponse = await response.json()
|
||||
|
||||
if (payload.generationConfig?.responseSchema) {
|
||||
const candidate = geminiResponse.candidates?.[0]
|
||||
if (candidate?.content?.parts?.[0]?.text) {
|
||||
const text = candidate.content.parts[0].text
|
||||
try {
|
||||
JSON.parse(text)
|
||||
logger.info('Successfully received structured JSON output')
|
||||
} catch (_e) {
|
||||
logger.warn('Failed to parse structured output as JSON')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let content = ''
|
||||
let tokens = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
let iterationCount = 0
|
||||
|
||||
const originalToolConfig = preparedTools?.toolConfig
|
||||
const forcedTools = preparedTools?.forcedTools || []
|
||||
let usedForcedTools: string[] = []
|
||||
let hasUsedForcedTool = false
|
||||
let currentToolConfig = originalToolConfig
|
||||
|
||||
const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => {
|
||||
if (currentToolConfig && forcedTools.length > 0) {
|
||||
const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }]
|
||||
const result = trackForcedToolUsage(
|
||||
toolCallsForTracking,
|
||||
currentToolConfig,
|
||||
logger,
|
||||
'google',
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = result.hasUsedForcedTool
|
||||
usedForcedTools = result.usedForcedTools
|
||||
|
||||
if (result.nextToolConfig) {
|
||||
currentToolConfig = result.nextToolConfig
|
||||
logger.info('Updated tool config for next iteration', {
|
||||
hasNextToolConfig: !!currentToolConfig,
|
||||
usedForcedTools: usedForcedTools,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
name: 'Initial response',
|
||||
startTime: initialCallTime,
|
||||
endTime: initialCallTime + firstResponseTime,
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
]
|
||||
|
||||
try {
|
||||
const candidate = geminiResponse.candidates?.[0]
|
||||
|
||||
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
|
||||
logger.warn(
|
||||
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
|
||||
{
|
||||
finishReason: candidate.finishReason,
|
||||
hasContent: !!candidate?.content,
|
||||
hasParts: !!candidate?.content?.parts,
|
||||
}
|
||||
)
|
||||
content = extractTextContent(candidate)
|
||||
}
|
||||
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
|
||||
if (functionCall) {
|
||||
logger.info(`Received function call from Vertex AI: ${functionCall.name}`)
|
||||
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
const latestResponse = geminiResponse.candidates?.[0]
|
||||
const latestFunctionCall = extractFunctionCall(latestResponse)
|
||||
|
||||
if (!latestFunctionCall) {
|
||||
content = extractTextContent(latestResponse)
|
||||
break
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
try {
|
||||
const toolName = latestFunctionCall.name
|
||||
const toolArgs = latestFunctionCall.args || {}
|
||||
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) {
|
||||
logger.warn(`Tool ${toolName} not found in registry, skipping`)
|
||||
break
|
||||
}
|
||||
|
||||
const toolCallStartTime = Date.now()
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
const simplifiedMessages = [
|
||||
...(contents.filter((m) => m.role === 'user').length > 0
|
||||
? [contents.filter((m) => m.role === 'user')[0]]
|
||||
: [contents[0]]),
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionCall: {
|
||||
name: latestFunctionCall.name,
|
||||
args: latestFunctionCall.args,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: `Function ${latestFunctionCall.name} result: ${JSON.stringify(resultContent)}`,
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
checkForForcedToolUsage(latestFunctionCall)
|
||||
|
||||
const nextModelStartTime = Date.now()
|
||||
|
||||
try {
|
||||
if (request.stream) {
|
||||
const streamingPayload = {
|
||||
...payload,
|
||||
contents: simplifiedMessages,
|
||||
}
|
||||
|
||||
const allForcedToolsUsed =
|
||||
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
|
||||
|
||||
if (allForcedToolsUsed && request.responseFormat) {
|
||||
streamingPayload.tools = undefined
|
||||
streamingPayload.toolConfig = undefined
|
||||
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
if (!streamingPayload.generationConfig) {
|
||||
streamingPayload.generationConfig = {}
|
||||
}
|
||||
streamingPayload.generationConfig.responseMimeType = 'application/json'
|
||||
streamingPayload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info('Using structured output for final response after tool execution')
|
||||
} else {
|
||||
if (currentToolConfig) {
|
||||
streamingPayload.toolConfig = currentToolConfig
|
||||
} else {
|
||||
streamingPayload.toolConfig = { functionCallingConfig: { mode: 'AUTO' } }
|
||||
}
|
||||
}
|
||||
|
||||
const checkPayload = {
|
||||
...streamingPayload,
|
||||
}
|
||||
checkPayload.stream = undefined
|
||||
|
||||
const checkEndpoint = buildVertexEndpoint(
|
||||
vertexProject,
|
||||
vertexLocation,
|
||||
requestedModel,
|
||||
false
|
||||
)
|
||||
|
||||
const checkResponse = await fetch(checkEndpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${request.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(checkPayload),
|
||||
})
|
||||
|
||||
if (!checkResponse.ok) {
|
||||
const errorBody = await checkResponse.text()
|
||||
logger.error('Error in Vertex AI check request:', {
|
||||
status: checkResponse.status,
|
||||
statusText: checkResponse.statusText,
|
||||
responseBody: errorBody,
|
||||
})
|
||||
throw new Error(
|
||||
`Vertex AI API check error: ${checkResponse.status} ${checkResponse.statusText}`
|
||||
)
|
||||
}
|
||||
|
||||
const checkResult = await checkResponse.json()
|
||||
const checkCandidate = checkResult.candidates?.[0]
|
||||
const checkFunctionCall = extractFunctionCall(checkCandidate)
|
||||
|
||||
if (checkFunctionCall) {
|
||||
logger.info(
|
||||
'Function call detected in follow-up, handling in non-streaming mode',
|
||||
{
|
||||
functionName: checkFunctionCall.name,
|
||||
}
|
||||
)
|
||||
|
||||
geminiResponse = checkResult
|
||||
|
||||
if (checkResult.usageMetadata) {
|
||||
tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0
|
||||
tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0
|
||||
tokens.total +=
|
||||
(checkResult.usageMetadata.promptTokenCount || 0) +
|
||||
(checkResult.usageMetadata.candidatesTokenCount || 0)
|
||||
}
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
modelTime += thisModelTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
startTime: nextModelStartTime,
|
||||
endTime: nextModelEndTime,
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
iterationCount++
|
||||
continue
|
||||
}
|
||||
|
||||
logger.info('No function call detected, proceeding with streaming response')
|
||||
|
||||
// Apply structured output for the final response if responseFormat is specified
|
||||
// This works regardless of whether tools were forced or auto
|
||||
if (request.responseFormat) {
|
||||
streamingPayload.tools = undefined
|
||||
streamingPayload.toolConfig = undefined
|
||||
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
if (!streamingPayload.generationConfig) {
|
||||
streamingPayload.generationConfig = {}
|
||||
}
|
||||
streamingPayload.generationConfig.responseMimeType = 'application/json'
|
||||
streamingPayload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info(
|
||||
'Using structured output for final streaming response after tool execution'
|
||||
)
|
||||
}
|
||||
|
||||
const streamEndpoint = buildVertexEndpoint(
|
||||
vertexProject,
|
||||
vertexLocation,
|
||||
requestedModel,
|
||||
true
|
||||
)
|
||||
|
||||
const streamingResponse = await fetch(streamEndpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${request.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(streamingPayload),
|
||||
})
|
||||
|
||||
if (!streamingResponse.ok) {
|
||||
const errorBody = await streamingResponse.text()
|
||||
logger.error('Error in Vertex AI streaming follow-up request:', {
|
||||
status: streamingResponse.status,
|
||||
statusText: streamingResponse.statusText,
|
||||
responseBody: errorBody,
|
||||
})
|
||||
throw new Error(
|
||||
`Vertex AI API streaming error: ${streamingResponse.status} ${streamingResponse.statusText}`
|
||||
)
|
||||
}
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
modelTime += thisModelTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: 'Final streaming response after tool calls',
|
||||
startTime: nextModelStartTime,
|
||||
endTime: nextModelEndTime,
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
const streamingExecution: StreamingExecution = {
|
||||
stream: null as any,
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens,
|
||||
toolCalls:
|
||||
toolCalls.length > 0
|
||||
? {
|
||||
list: toolCalls,
|
||||
count: toolCalls.length,
|
||||
}
|
||||
: undefined,
|
||||
toolResults,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
modelTime,
|
||||
toolsTime,
|
||||
firstResponseTime,
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments,
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
},
|
||||
isStreaming: true,
|
||||
},
|
||||
}
|
||||
|
||||
streamingExecution.stream = createReadableStreamFromVertexStream(
|
||||
streamingResponse,
|
||||
(content, usage) => {
|
||||
streamingExecution.execution.output.content = content
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
if (streamingExecution.execution.output.providerTiming) {
|
||||
streamingExecution.execution.output.providerTiming.endTime =
|
||||
streamEndTimeISO
|
||||
streamingExecution.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
const existingTokens = streamingExecution.execution.output.tokens || {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
streamingExecution.execution.output.tokens = {
|
||||
prompt: (existingTokens.prompt || 0) + (usage.promptTokenCount || 0),
|
||||
completion:
|
||||
(existingTokens.completion || 0) + (usage.candidatesTokenCount || 0),
|
||||
total:
|
||||
(existingTokens.total || 0) +
|
||||
(usage.totalTokenCount ||
|
||||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return streamingExecution
|
||||
}
|
||||
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
contents: simplifiedMessages,
|
||||
}
|
||||
|
||||
const allForcedToolsUsed =
|
||||
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
|
||||
|
||||
if (allForcedToolsUsed && request.responseFormat) {
|
||||
nextPayload.tools = undefined
|
||||
nextPayload.toolConfig = undefined
|
||||
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
if (!nextPayload.generationConfig) {
|
||||
nextPayload.generationConfig = {}
|
||||
}
|
||||
nextPayload.generationConfig.responseMimeType = 'application/json'
|
||||
nextPayload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info(
|
||||
'Using structured output for final non-streaming response after tool execution'
|
||||
)
|
||||
} else {
|
||||
if (currentToolConfig) {
|
||||
nextPayload.toolConfig = currentToolConfig
|
||||
}
|
||||
}
|
||||
|
||||
const nextEndpoint = buildVertexEndpoint(
|
||||
vertexProject,
|
||||
vertexLocation,
|
||||
requestedModel,
|
||||
false
|
||||
)
|
||||
|
||||
const nextResponse = await fetch(nextEndpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${request.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(nextPayload),
|
||||
})
|
||||
|
||||
if (!nextResponse.ok) {
|
||||
const errorBody = await nextResponse.text()
|
||||
logger.error('Error in Vertex AI follow-up request:', {
|
||||
status: nextResponse.status,
|
||||
statusText: nextResponse.statusText,
|
||||
responseBody: errorBody,
|
||||
iterationCount,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
geminiResponse = await nextResponse.json()
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
startTime: nextModelStartTime,
|
||||
endTime: nextModelEndTime,
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
modelTime += thisModelTime
|
||||
|
||||
const nextCandidate = geminiResponse.candidates?.[0]
|
||||
const nextFunctionCall = extractFunctionCall(nextCandidate)
|
||||
|
||||
if (!nextFunctionCall) {
|
||||
// If responseFormat is specified, make one final request with structured output
|
||||
if (request.responseFormat) {
|
||||
const finalPayload = {
|
||||
...payload,
|
||||
contents: nextPayload.contents,
|
||||
tools: undefined,
|
||||
toolConfig: undefined,
|
||||
}
|
||||
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
if (!finalPayload.generationConfig) {
|
||||
finalPayload.generationConfig = {}
|
||||
}
|
||||
finalPayload.generationConfig.responseMimeType = 'application/json'
|
||||
finalPayload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info('Making final request with structured output after tool execution')
|
||||
|
||||
const finalEndpoint = buildVertexEndpoint(
|
||||
vertexProject,
|
||||
vertexLocation,
|
||||
requestedModel,
|
||||
false
|
||||
)
|
||||
|
||||
const finalResponse = await fetch(finalEndpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${request.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(finalPayload),
|
||||
})
|
||||
|
||||
if (finalResponse.ok) {
|
||||
const finalResult = await finalResponse.json()
|
||||
const finalCandidate = finalResult.candidates?.[0]
|
||||
content = extractTextContent(finalCandidate)
|
||||
|
||||
if (finalResult.usageMetadata) {
|
||||
tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0
|
||||
tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0
|
||||
tokens.total +=
|
||||
(finalResult.usageMetadata.promptTokenCount || 0) +
|
||||
(finalResult.usageMetadata.candidatesTokenCount || 0)
|
||||
}
|
||||
} else {
|
||||
logger.warn(
|
||||
'Failed to get structured output, falling back to regular response'
|
||||
)
|
||||
content = extractTextContent(nextCandidate)
|
||||
}
|
||||
} else {
|
||||
content = extractTextContent(nextCandidate)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
iterationCount++
|
||||
} catch (error) {
|
||||
logger.error('Error in Vertex AI follow-up request:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
iterationCount,
|
||||
})
|
||||
break
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error processing function call:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
functionName: latestFunctionCall?.name || 'unknown',
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
content = extractTextContent(candidate)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error processing Vertex AI response:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
iterationCount,
|
||||
})
|
||||
|
||||
if (!content && toolCalls.length > 0) {
|
||||
content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.`
|
||||
}
|
||||
}
|
||||
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
|
||||
if (geminiResponse.usageMetadata) {
|
||||
tokens = {
|
||||
prompt: geminiResponse.usageMetadata.promptTokenCount || 0,
|
||||
completion: geminiResponse.usageMetadata.candidatesTokenCount || 0,
|
||||
total:
|
||||
(geminiResponse.usageMetadata.promptTokenCount || 0) +
|
||||
(geminiResponse.usageMetadata.candidatesTokenCount || 0),
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
content,
|
||||
model: request.model,
|
||||
tokens,
|
||||
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
|
||||
toolResults: toolResults.length > 0 ? toolResults : undefined,
|
||||
timing: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
duration: totalDuration,
|
||||
modelTime: modelTime,
|
||||
toolsTime: toolsTime,
|
||||
firstResponseTime: firstResponseTime,
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
|
||||
logger.error('Error in Vertex AI request:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
duration: totalDuration,
|
||||
})
|
||||
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
duration: totalDuration,
|
||||
}
|
||||
|
||||
throw enhancedError
|
||||
}
|
||||
},
|
||||
}
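A hedged sketch of invoking the provider directly. In the app the call goes through the provider registry; the message shape is assumed, and the token would normally come from `gcloud auth print-access-token` or a service account:

import { vertexProvider } from '@/providers/vertex'
import type { ProviderRequest } from '@/providers/types'

const accessToken = process.env.GOOGLE_ACCESS_TOKEN ?? '' // assumption: token supplied via env for this sketch

const result = await vertexProvider.executeRequest({
  model: 'vertex/gemini-2.5-pro',
  systemPrompt: 'Answer concisely.',
  messages: [{ role: 'user', content: 'What is Vertex AI?' }],
  apiKey: accessToken, // OAuth access token, not an API key
  vertexProject: 'my-gcp-project',
  vertexLocation: 'us-central1',
  stream: false,
} as ProviderRequest)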
|
||||
233
apps/sim/providers/vertex/utils.ts
Normal file
@@ -0,0 +1,233 @@
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { extractFunctionCall, extractTextContent } from '@/providers/google/utils'
|
||||
|
||||
const logger = createLogger('VertexUtils')
|
||||
|
||||
/**
|
||||
* Creates a ReadableStream from Vertex AI's Gemini stream response
|
||||
*/
|
||||
export function createReadableStreamFromVertexStream(
|
||||
response: Response,
|
||||
onComplete?: (
|
||||
content: string,
|
||||
usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
|
||||
) => void
|
||||
): ReadableStream<Uint8Array> {
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) {
|
||||
throw new Error('Failed to get reader from response body')
|
||||
}
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
let buffer = ''
|
||||
let fullContent = ''
|
||||
let usageData: {
|
||||
promptTokenCount?: number
|
||||
candidatesTokenCount?: number
|
||||
totalTokenCount?: number
|
||||
} | null = null
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
if (buffer.trim()) {
|
||||
try {
|
||||
const data = JSON.parse(buffer.trim())
|
||||
if (data.usageMetadata) {
|
||||
usageData = data.usageMetadata
|
||||
}
|
||||
const candidate = data.candidates?.[0]
|
||||
if (candidate?.content?.parts) {
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
if (functionCall) {
|
||||
logger.debug(
|
||||
'Function call detected in final buffer, ending stream to execute tool',
|
||||
{
|
||||
functionName: functionCall.name,
|
||||
}
|
||||
)
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
const content = extractTextContent(candidate)
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
if (buffer.trim().startsWith('[')) {
|
||||
try {
|
||||
const dataArray = JSON.parse(buffer.trim())
|
||||
if (Array.isArray(dataArray)) {
|
||||
for (const item of dataArray) {
|
||||
if (item.usageMetadata) {
|
||||
usageData = item.usageMetadata
|
||||
}
|
||||
const candidate = item.candidates?.[0]
|
||||
if (candidate?.content?.parts) {
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
if (functionCall) {
|
||||
logger.debug(
|
||||
'Function call detected in array item, ending stream to execute tool',
|
||||
{
|
||||
functionName: functionCall.name,
|
||||
}
|
||||
)
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
const content = extractTextContent(candidate)
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (arrayError) {
|
||||
// Buffer is not valid JSON array
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
controller.close()
|
||||
break
|
||||
}
|
||||
|
||||
const text = new TextDecoder().decode(value)
|
||||
buffer += text
|
||||
|
||||
let searchIndex = 0
|
||||
while (searchIndex < buffer.length) {
|
||||
const openBrace = buffer.indexOf('{', searchIndex)
|
||||
if (openBrace === -1) break
|
||||
|
||||
let braceCount = 0
|
||||
let inString = false
|
||||
let escaped = false
|
||||
let closeBrace = -1
|
||||
|
||||
for (let i = openBrace; i < buffer.length; i++) {
|
||||
const char = buffer[i]
|
||||
|
||||
if (!inString) {
|
||||
if (char === '"' && !escaped) {
|
||||
inString = true
|
||||
} else if (char === '{') {
|
||||
braceCount++
|
||||
} else if (char === '}') {
|
||||
braceCount--
|
||||
if (braceCount === 0) {
|
||||
closeBrace = i
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (char === '"' && !escaped) {
|
||||
inString = false
|
||||
}
|
||||
}
|
||||
|
||||
escaped = char === '\\' && !escaped
|
||||
}
|
||||
|
||||
if (closeBrace !== -1) {
|
||||
const jsonStr = buffer.substring(openBrace, closeBrace + 1)
|
||||
|
||||
try {
|
||||
const data = JSON.parse(jsonStr)
|
||||
|
||||
if (data.usageMetadata) {
|
||||
usageData = data.usageMetadata
|
||||
}
|
||||
|
||||
const candidate = data.candidates?.[0]
|
||||
|
||||
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
|
||||
logger.warn(
|
||||
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
|
||||
{
|
||||
finishReason: candidate.finishReason,
|
||||
hasContent: !!candidate?.content,
|
||||
hasParts: !!candidate?.content?.parts,
|
||||
}
|
||||
)
|
||||
const textContent = extractTextContent(candidate)
|
||||
if (textContent) {
|
||||
fullContent += textContent
|
||||
controller.enqueue(new TextEncoder().encode(textContent))
|
||||
}
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
|
||||
if (candidate?.content?.parts) {
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
if (functionCall) {
|
||||
logger.debug(
|
||||
'Function call detected in stream, ending stream to execute tool',
|
||||
{
|
||||
functionName: functionCall.name,
|
||||
}
|
||||
)
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
const content = extractTextContent(candidate)
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.error('Error parsing JSON from stream', {
|
||||
error: e instanceof Error ? e.message : String(e),
|
||||
jsonPreview: jsonStr.substring(0, 200),
|
||||
})
|
||||
}
|
||||
|
||||
buffer = buffer.substring(closeBrace + 1)
|
||||
searchIndex = 0
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.error('Error reading Vertex AI stream', {
|
||||
error: e instanceof Error ? e.message : String(e),
|
||||
})
|
||||
controller.error(e)
|
||||
}
|
||||
},
|
||||
async cancel() {
|
||||
await reader.cancel()
|
||||
},
|
||||
})
|
||||
}
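For completeness, a sketch of consuming the returned stream; vertexFetchResponse stands in for a fetch Response from the streamGenerateContent endpoint and is not defined here:

const stream = createReadableStreamFromVertexStream(vertexFetchResponse, (content, usage) => {
  console.log('final content length:', content.length, 'total tokens:', usage?.totalTokenCount)
})

const reader = stream.getReader()
const decoder = new TextDecoder()
for (let chunk = await reader.read(); !chunk.done; chunk = await reader.read()) {
  process.stdout.write(decoder.decode(chunk.value)) // emit text deltas as they arrive
}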
|
||||
|
||||
/**
 * Build Vertex AI endpoint URL
 */
export function buildVertexEndpoint(
  project: string,
  location: string,
  model: string,
  isStreaming: boolean
): string {
  const action = isStreaming ? 'streamGenerateContent' : 'generateContent'

  if (location === 'global') {
    return `https://aiplatform.googleapis.com/v1/projects/${project}/locations/global/publishers/google/models/${model}:${action}`
  }

  return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/google/models/${model}:${action}`
}
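For reference, the URLs this helper produces; project and model names are placeholders:

buildVertexEndpoint('my-proj', 'us-central1', 'gemini-2.5-pro', false)
// => https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1/publishers/google/models/gemini-2.5-pro:generateContent

buildVertexEndpoint('my-proj', 'global', 'gemini-2.5-pro', true)
// => https://aiplatform.googleapis.com/v1/projects/my-proj/locations/global/publishers/google/models/gemini-2.5-pro:streamGenerateContent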
@@ -2,6 +2,7 @@ import OpenAI from 'openai'
import { env } from '@/lib/core/config/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
@@ -14,50 +15,13 @@ import {
prepareToolsWithUsageControl,
trackForcedToolUsage,
} from '@/providers/utils'
import { createReadableStreamFromVLLMStream } from '@/providers/vllm/utils'
import { useProvidersStore } from '@/stores/providers/store'
import { executeTool } from '@/tools'

const logger = createLogger('VLLMProvider')
const VLLM_VERSION = '1.0.0'

/**
* Helper function to convert a vLLM stream to a standard ReadableStream
* and collect completion metrics
*/
function createReadableStreamFromVLLMStream(
vllmStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null

return new ReadableStream({
async start(controller) {
try {
for await (const chunk of vllmStream) {
if (chunk.usage) {
usageData = chunk.usage
}

const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}

if (onComplete) {
onComplete(fullContent, usageData)
}

controller.close()
} catch (error) {
controller.error(error)
}
},
})
}

export const vllmProvider: ProviderConfig = {
id: 'vllm',
name: 'vLLM',
@@ -341,7 +305,6 @@ export const vllmProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10

let modelTime = firstResponseTime
let toolsTime = 0
@@ -360,14 +323,14 @@ export const vllmProvider: ProviderConfig = {

checkForForcedToolUsage(currentResponse, originalToolChoice)

while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}

logger.info(
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_ITERATIONS})`
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)

const toolsStartTime = Date.now()

37
apps/sim/providers/vllm/utils.ts
Normal file
@@ -0,0 +1,37 @@
/**
* Helper function to convert a vLLM stream to a standard ReadableStream
* and collect completion metrics
*/
export function createReadableStreamFromVLLMStream(
vllmStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null

return new ReadableStream({
async start(controller) {
try {
for await (const chunk of vllmStream) {
if (chunk.usage) {
usageData = chunk.usage
}

const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}

if (onComplete) {
onComplete(fullContent, usageData)
}

controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
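As a rough sketch (not part of the commit; `client`, `model`, and `messages` are assumed context from an OpenAI-compatible client pointed at a vLLM server), this is how the helper might be consumed:

// Illustrative only; the callback runs once the stream is exhausted.
const completion = await client.chat.completions.create({ model, messages, stream: true })
const stream = createReadableStreamFromVLLMStream(completion, (content, usage) => {
  logger.info('vLLM stream finished', { chars: content.length, usage })
})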
@@ -1,6 +1,7 @@
import OpenAI from 'openai'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
@@ -8,37 +9,16 @@ import type {
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
import {
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
} from '@/providers/utils'
import {
checkForForcedToolUsage,
createReadableStreamFromXAIStream,
createResponseFormatPayload,
} from '@/providers/xai/utils'
import { executeTool } from '@/tools'

const logger = createLogger('XAIProvider')

/**
* Helper to wrap XAI (OpenAI-compatible) streaming into a browser-friendly
* ReadableStream of raw assistant text chunks.
*/
function createReadableStreamFromXAIStream(xaiStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of xaiStream) {
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
}

export const xAIProvider: ProviderConfig = {
id: 'xai',
name: 'xAI',
@@ -115,27 +95,6 @@ export const xAIProvider: ProviderConfig = {
if (request.temperature !== undefined) basePayload.temperature = request.temperature
if (request.maxTokens !== undefined) basePayload.max_tokens = request.maxTokens

// Function to create response format configuration
const createResponseFormatPayload = (messages: any[] = allMessages) => {
const payload = {
...basePayload,
messages,
}

if (request.responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: request.responseFormat.name || 'structured_response',
schema: request.responseFormat.schema || request.responseFormat,
strict: request.responseFormat.strict !== false,
},
}
}

return payload
}

// Handle tools and tool usage control
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null

@@ -154,7 +113,7 @@ export const xAIProvider: ProviderConfig = {

// Use response format payload if needed, otherwise use base payload
const streamingPayload = request.responseFormat
? createResponseFormatPayload()
? createResponseFormatPayload(basePayload, allMessages, request.responseFormat)
: { ...basePayload, stream: true }

if (!request.responseFormat) {
@@ -243,7 +202,11 @@ export const xAIProvider: ProviderConfig = {
originalToolChoice = toolChoice
} else if (request.responseFormat) {
// Only add response format if there are no tools
const responseFormatPayload = createResponseFormatPayload()
const responseFormatPayload = createResponseFormatPayload(
basePayload,
allMessages,
request.responseFormat
)
Object.assign(initialPayload, responseFormatPayload)
}

@@ -260,7 +223,6 @@ export const xAIProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
const MAX_ITERATIONS = 10

// Track if a forced tool has been used
let hasUsedForcedTool = false
@@ -280,33 +242,20 @@ export const xAIProvider: ProviderConfig = {
},
]

// Helper function to check for forced tool usage in responses
const checkForForcedToolUsage = (
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }
) => {
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'xai',
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
}
}

// Check if a forced tool was used in the first response
if (originalToolChoice) {
checkForForcedToolUsage(currentResponse, originalToolChoice)
const result = checkForForcedToolUsage(
currentResponse,
originalToolChoice,
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
}

try {
while (iterationCount < MAX_ITERATIONS) {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
@@ -432,7 +381,12 @@ export const xAIProvider: ProviderConfig = {
} else {
// All forced tools have been used, check if we need response format for final response
if (request.responseFormat) {
nextPayload = createResponseFormatPayload(currentMessages)
nextPayload = createResponseFormatPayload(
basePayload,
allMessages,
request.responseFormat,
currentMessages
)
} else {
nextPayload = {
...basePayload,
@@ -446,7 +400,12 @@ export const xAIProvider: ProviderConfig = {
// Normal tool processing - check if this might be the final response
if (request.responseFormat) {
// Use response format for what might be the final response
nextPayload = createResponseFormatPayload(currentMessages)
nextPayload = createResponseFormatPayload(
basePayload,
allMessages,
request.responseFormat,
currentMessages
)
} else {
nextPayload = {
...basePayload,
@@ -464,7 +423,14 @@ export const xAIProvider: ProviderConfig = {

// Check if any forced tools were used in this response
if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') {
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
const result = checkForForcedToolUsage(
currentResponse,
nextPayload.tool_choice,
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
}

const nextModelEndTime = Date.now()
@@ -509,7 +475,12 @@ export const xAIProvider: ProviderConfig = {
if (request.responseFormat) {
// Use response format, no tools
finalStreamingPayload = {
...createResponseFormatPayload(currentMessages),
...createResponseFormatPayload(
basePayload,
allMessages,
request.responseFormat,
currentMessages
),
stream: true,
}
} else {

83
apps/sim/providers/xai/utils.ts
Normal file
@@ -0,0 +1,83 @@
import { createLogger } from '@/lib/logs/console/logger'
import { trackForcedToolUsage } from '@/providers/utils'

const logger = createLogger('XAIProvider')

/**
* Helper to wrap XAI (OpenAI-compatible) streaming into a browser-friendly
* ReadableStream of raw assistant text chunks.
*/
export function createReadableStreamFromXAIStream(xaiStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of xaiStream) {
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
}
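A quick consumption sketch (not from the diff; `client`, `messages`, and the model name are placeholders for an OpenAI-compatible client configured for the xAI API):

// Illustrative only: wrap the SDK stream and hand it to a web Response.
const xaiStream = await client.chat.completions.create({ model: 'grok-3', messages, stream: true })
const readable = createReadableStreamFromXAIStream(xaiStream)
return new Response(readable, { headers: { 'Content-Type': 'text/plain; charset=utf-8' } })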

/**
* Creates a response format payload for XAI API requests.
*/
export function createResponseFormatPayload(
basePayload: any,
allMessages: any[],
responseFormat: any,
currentMessages?: any[]
) {
const payload = {
...basePayload,
messages: currentMessages || allMessages,
}

if (responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: responseFormat.name || 'structured_response',
schema: responseFormat.schema || responseFormat,
strict: responseFormat.strict !== false,
},
}
}

return payload
}
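A small sketch (not from the diff; the base payload, message, and schema values are made up) of what this helper produces:

// Illustrative only; with no `strict` flag provided, strict defaults to true.
const payload = createResponseFormatPayload(
  { model: 'grok-3', temperature: 0 },
  [{ role: 'user', content: 'Summarize this.' }],
  { name: 'summary', schema: { type: 'object', properties: { text: { type: 'string' } } } }
)
// payload.response_format = { type: 'json_schema', json_schema: { name: 'summary', schema: {...}, strict: true } }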

/**
* Helper function to check for forced tool usage in responses.
*/
export function checkForForcedToolUsage(
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
forcedTools: string[],
usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
let hasUsedForcedTool = false
let updatedUsedForcedTools = usedForcedTools

if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'xai',
forcedTools,
updatedUsedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
updatedUsedForcedTools = result.usedForcedTools
}

return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
}
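For context, a call sketch (not from the diff; `currentResponse` is assumed provider state and the tool names are placeholders) showing the pure-function shape the refactor gives this helper:

// Illustrative only: the caller threads the returned state back into its loop.
const { hasUsedForcedTool, usedForcedTools: updated } = checkForForcedToolUsage(
  currentResponse,
  { type: 'function', function: { name: 'search_docs' } },
  ['search_docs'],
  []
)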
@@ -13,6 +13,8 @@ interface LLMChatParams {
maxTokens?: number
azureEndpoint?: string
azureApiVersion?: string
vertexProject?: string
vertexLocation?: string
}

interface LLMChatResponse extends ToolResponse {
@@ -77,6 +79,18 @@ export const llmChatTool: ToolConfig<LLMChatParams, LLMChatResponse> = {
visibility: 'hidden',
description: 'Azure OpenAI API version',
},
vertexProject: {
type: 'string',
required: false,
visibility: 'hidden',
description: 'Google Cloud project ID for Vertex AI',
},
vertexLocation: {
type: 'string',
required: false,
visibility: 'hidden',
description: 'Google Cloud location for Vertex AI (defaults to us-central1)',
},
},

request: {
@@ -98,6 +112,8 @@ export const llmChatTool: ToolConfig<LLMChatParams, LLMChatResponse> = {
maxTokens: params.maxTokens,
azureEndpoint: params.azureEndpoint,
azureApiVersion: params.azureApiVersion,
vertexProject: params.vertexProject,
vertexLocation: params.vertexLocation,
}
},
},
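For illustration (values are hypothetical, not from the diff), the new optional fields forwarded for a Vertex-backed chat call might look like:

// Illustrative only; field names match LLMChatParams above, values are placeholders.
const vertexParams = {
  vertexProject: 'my-gcp-project', // Google Cloud project ID
  vertexLocation: 'us-central1',   // defaults to us-central1 when omitted
  maxTokens: 1024,
}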

7
bun.lock
@@ -1,5 +1,6 @@
{
"lockfileVersion": 1,
"configVersion": 0,
"workspaces": {
"": {
"name": "simstudio",
@@ -266,12 +267,12 @@
"sharp",
],
"overrides": {
"react": "19.2.1",
"react-dom": "19.2.1",
"next": "16.1.0-canary.21",
"@next/env": "16.1.0-canary.21",
"drizzle-orm": "^0.44.5",
"next": "16.1.0-canary.21",
"postgres": "^3.4.5",
"react": "19.2.1",
"react-dom": "19.2.1",
},
"packages": {
"@adobe/css-tools": ["@adobe/css-tools@4.4.4", "", {}, "sha512-Elp+iwUx5rN5+Y8xLt5/GRoG20WGoDCQ/1Fb+1LiGtvwbDavuSk0jhD/eZdckHAuzcDzccnkv+rEjyWfRx18gg=="],