From b0748c82f9ae2dd6c222e2474a50e597ef06b2d6 Mon Sep 17 00:00:00 2001 From: Waleed Date: Mon, 22 Dec 2025 23:57:11 -0800 Subject: [PATCH] fix(search): removed full text param from built-in search, anthropic provider streaming fix (#2542) * fix(search): removed full text param from built-in search, anthropic provider streaming fix * rewrite gemini provider with official sdk + add thinking capability * vertex gemini consolidation * never silently use different model * pass oauth client through the googleAuthOptions param directly * make server side provider registry * remove comments * take oauth selector below model selector --------- Co-authored-by: Vikhyath Mondreti --- apps/sim/app/api/tools/search/route.ts | 4 +- .../components/oauth-required-modal.tsx | 2 + apps/sim/blocks/blocks/agent.ts | 80 +- .../executor/handlers/agent/agent-handler.ts | 68 +- apps/sim/executor/handlers/agent/types.ts | 1 + apps/sim/lib/auth/auth.ts | 15 + apps/sim/lib/core/utils/display-filters.ts | 30 +- apps/sim/lib/oauth/oauth.ts | 18 + apps/sim/providers/anthropic/index.ts | 120 +- apps/sim/providers/azure-openai/index.ts | 4 +- apps/sim/providers/cerebras/index.ts | 6 +- apps/sim/providers/deepseek/index.ts | 6 +- apps/sim/providers/gemini/client.ts | 58 + apps/sim/providers/gemini/core.ts | 680 ++++++++++ apps/sim/providers/gemini/index.ts | 18 + apps/sim/providers/gemini/types.ts | 64 + apps/sim/providers/google/index.ts | 1092 +---------------- apps/sim/providers/google/utils.ts | 297 ++++- apps/sim/providers/groq/index.ts | 13 +- apps/sim/providers/index.ts | 6 +- apps/sim/providers/mistral/index.ts | 6 +- apps/sim/providers/models.ts | 66 + apps/sim/providers/openai/index.ts | 6 +- apps/sim/providers/openrouter/index.ts | 2 +- apps/sim/providers/registry.ts | 59 + apps/sim/providers/types.ts | 1 + apps/sim/providers/utils.ts | 155 +-- apps/sim/providers/vertex/index.ts | 920 +------------- apps/sim/providers/vertex/utils.ts | 231 ---- apps/sim/providers/vllm/index.ts | 2 +- apps/sim/providers/xai/index.ts | 8 +- 31 files changed, 1607 insertions(+), 2431 deletions(-) create mode 100644 apps/sim/providers/gemini/client.ts create mode 100644 apps/sim/providers/gemini/core.ts create mode 100644 apps/sim/providers/gemini/index.ts create mode 100644 apps/sim/providers/gemini/types.ts create mode 100644 apps/sim/providers/registry.ts delete mode 100644 apps/sim/providers/vertex/utils.ts diff --git a/apps/sim/app/api/tools/search/route.ts b/apps/sim/app/api/tools/search/route.ts index fb7815da8..e396cdf9a 100644 --- a/apps/sim/app/api/tools/search/route.ts +++ b/apps/sim/app/api/tools/search/route.ts @@ -56,7 +56,7 @@ export async function POST(request: NextRequest) { query: validated.query, type: 'auto', useAutoprompt: true, - text: true, + highlights: true, apiKey: env.EXA_API_KEY, }) @@ -77,7 +77,7 @@ export async function POST(request: NextRequest) { const results = (result.output.results || []).map((r: any, index: number) => ({ title: r.title || '', link: r.url || '', - snippet: r.text || '', + snippet: Array.isArray(r.highlights) ? r.highlights.join(' ... 
') : '', date: r.publishedDate || undefined, position: index + 1, })) diff --git a/apps/sim/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/editor/components/sub-block/components/credential-selector/components/oauth-required-modal.tsx b/apps/sim/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/editor/components/sub-block/components/credential-selector/components/oauth-required-modal.tsx index a72d30e75..2744a2b23 100644 --- a/apps/sim/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/editor/components/sub-block/components/credential-selector/components/oauth-required-modal.tsx +++ b/apps/sim/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/editor/components/sub-block/components/credential-selector/components/oauth-required-modal.tsx @@ -43,6 +43,8 @@ const SCOPE_DESCRIPTIONS: Record = { 'https://www.googleapis.com/auth/admin.directory.group.readonly': 'View Google Workspace groups', 'https://www.googleapis.com/auth/admin.directory.group.member.readonly': 'View Google Workspace group memberships', + 'https://www.googleapis.com/auth/cloud-platform': + 'Full access to Google Cloud resources for Vertex AI', 'read:confluence-content.all': 'Read all Confluence content', 'read:confluence-space.summary': 'Read Confluence space information', 'read:space:confluence': 'View Confluence spaces', diff --git a/apps/sim/blocks/blocks/agent.ts b/apps/sim/blocks/blocks/agent.ts index 31a54c91f..16227a290 100644 --- a/apps/sim/blocks/blocks/agent.ts +++ b/apps/sim/blocks/blocks/agent.ts @@ -9,8 +9,10 @@ import { getMaxTemperature, getProviderIcon, getReasoningEffortValuesForModel, + getThinkingLevelsForModel, getVerbosityValuesForModel, MODELS_WITH_REASONING_EFFORT, + MODELS_WITH_THINKING, MODELS_WITH_VERBOSITY, providers, supportsTemperature, @@ -108,7 +110,19 @@ export const AgentBlock: BlockConfig = { }) }, }, - + { + id: 'vertexCredential', + title: 'Google Cloud Account', + type: 'oauth-input', + serviceId: 'vertex-ai', + requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'], + placeholder: 'Select Google Cloud account', + required: true, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, { id: 'reasoningEffort', title: 'Reasoning Effort', @@ -215,6 +229,57 @@ export const AgentBlock: BlockConfig = { value: MODELS_WITH_VERBOSITY, }, }, + { + id: 'thinkingLevel', + title: 'Thinking Level', + type: 'dropdown', + placeholder: 'Select thinking level...', + options: [ + { label: 'minimal', id: 'minimal' }, + { label: 'low', id: 'low' }, + { label: 'medium', id: 'medium' }, + { label: 'high', id: 'high' }, + ], + dependsOn: ['model'], + fetchOptions: async (blockId: string) => { + const { useSubBlockStore } = await import('@/stores/workflows/subblock/store') + const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store') + + const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId + if (!activeWorkflowId) { + return [ + { label: 'low', id: 'low' }, + { label: 'high', id: 'high' }, + ] + } + + const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId] + const blockValues = workflowValues?.[blockId] + const modelValue = blockValues?.model as string + + if (!modelValue) { + return [ + { label: 'low', id: 'low' }, + { label: 'high', id: 'high' }, + ] + } + + const validOptions = getThinkingLevelsForModel(modelValue) + if (!validOptions) { + return [ + { label: 'low', id: 'low' }, + { label: 'high', id: 'high' }, + ] + } + + 
return validOptions.map((opt) => ({ label: opt, id: opt })) + }, + value: () => 'high', + condition: { + field: 'model', + value: MODELS_WITH_THINKING, + }, + }, { id: 'azureEndpoint', @@ -275,17 +340,21 @@ export const AgentBlock: BlockConfig = { password: true, connectionDroppable: false, required: true, - // Hide API key for hosted models, Ollama models, and vLLM models + // Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth) condition: isHosted ? { field: 'model', - value: getHostedModels(), + value: [...getHostedModels(), ...providers.vertex.models], not: true, // Show for all models EXCEPT those listed } : () => ({ field: 'model', - value: [...getCurrentOllamaModels(), ...getCurrentVLLMModels()], - not: true, // Show for all models EXCEPT Ollama and vLLM models + value: [ + ...getCurrentOllamaModels(), + ...getCurrentVLLMModels(), + ...providers.vertex.models, + ], + not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models }), }, { @@ -609,6 +678,7 @@ Example 3 (Array Input): temperature: { type: 'number', description: 'Response randomness level' }, reasoningEffort: { type: 'string', description: 'Reasoning effort level for GPT-5 models' }, verbosity: { type: 'string', description: 'Verbosity level for GPT-5 models' }, + thinkingLevel: { type: 'string', description: 'Thinking level for Gemini 3 models' }, tools: { type: 'json', description: 'Available tools configuration' }, }, outputs: { diff --git a/apps/sim/executor/handlers/agent/agent-handler.ts b/apps/sim/executor/handlers/agent/agent-handler.ts index e43533517..a358685ef 100644 --- a/apps/sim/executor/handlers/agent/agent-handler.ts +++ b/apps/sim/executor/handlers/agent/agent-handler.ts @@ -1,8 +1,9 @@ import { db } from '@sim/db' -import { mcpServers } from '@sim/db/schema' +import { account, mcpServers } from '@sim/db/schema' import { and, eq, inArray, isNull } from 'drizzle-orm' import { createLogger } from '@/lib/logs/console/logger' import { createMcpToolId } from '@/lib/mcp/utils' +import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils' import { getAllBlocks } from '@/blocks' import type { BlockOutput } from '@/blocks/types' import { AGENT, BlockType, DEFAULTS, HTTP } from '@/executor/constants' @@ -919,6 +920,7 @@ export class AgentBlockHandler implements BlockHandler { azureApiVersion: inputs.azureApiVersion, vertexProject: inputs.vertexProject, vertexLocation: inputs.vertexLocation, + vertexCredential: inputs.vertexCredential, responseFormat, workflowId: ctx.workflowId, workspaceId: ctx.workspaceId, @@ -997,7 +999,17 @@ export class AgentBlockHandler implements BlockHandler { responseFormat: any, providerStartTime: number ) { - const finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey) + let finalApiKey: string + + // For Vertex AI, resolve OAuth credential to access token + if (providerId === 'vertex' && providerRequest.vertexCredential) { + finalApiKey = await this.resolveVertexCredential( + providerRequest.vertexCredential, + ctx.workflowId + ) + } else { + finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey) + } const { blockData, blockNameMapping } = collectBlockData(ctx) @@ -1024,7 +1036,6 @@ export class AgentBlockHandler implements BlockHandler { blockNameMapping, }) - this.logExecutionSuccess(providerId, model, ctx, block, providerStartTime, response) return this.processProviderResponse(response, block, responseFormat) } @@ -1049,15 +1060,6 @@ export class AgentBlockHandler implements BlockHandler { throw 
new Error(errorMessage) } - this.logExecutionSuccess( - providerRequest.provider, - providerRequest.model, - ctx, - block, - providerStartTime, - 'HTTP response' - ) - const contentType = response.headers.get('Content-Type') if (contentType?.includes(HTTP.CONTENT_TYPE.EVENT_STREAM)) { return this.handleStreamingResponse(response, block, ctx, inputs) @@ -1117,21 +1119,33 @@ export class AgentBlockHandler implements BlockHandler { } } - private logExecutionSuccess( - provider: string, - model: string, - ctx: ExecutionContext, - block: SerializedBlock, - startTime: number, - response: any - ) { - const executionTime = Date.now() - startTime - const responseType = - response instanceof ReadableStream - ? 'stream' - : response && typeof response === 'object' && 'stream' in response - ? 'streaming-execution' - : 'json' + /** + * Resolves a Vertex AI OAuth credential to an access token + */ + private async resolveVertexCredential(credentialId: string, workflowId: string): Promise { + const requestId = `vertex-${Date.now()}` + + logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`) + + // Get the credential - we need to find the owner + // Since we're in a workflow context, we can query the credential directly + const credential = await db.query.account.findFirst({ + where: eq(account.id, credentialId), + }) + + if (!credential) { + throw new Error(`Vertex AI credential not found: ${credentialId}`) + } + + // Refresh the token if needed + const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId) + + if (!accessToken) { + throw new Error('Failed to get Vertex AI access token') + } + + logger.info(`[${requestId}] Successfully resolved Vertex AI credential`) + return accessToken } private handleExecutionError( diff --git a/apps/sim/executor/handlers/agent/types.ts b/apps/sim/executor/handlers/agent/types.ts index facd129a6..60694171b 100644 --- a/apps/sim/executor/handlers/agent/types.ts +++ b/apps/sim/executor/handlers/agent/types.ts @@ -21,6 +21,7 @@ export interface AgentInputs { azureApiVersion?: string vertexProject?: string vertexLocation?: string + vertexCredential?: string reasoningEffort?: string verbosity?: string } diff --git a/apps/sim/lib/auth/auth.ts b/apps/sim/lib/auth/auth.ts index 6910bbb50..30796dbdc 100644 --- a/apps/sim/lib/auth/auth.ts +++ b/apps/sim/lib/auth/auth.ts @@ -579,6 +579,21 @@ export const auth = betterAuth({ redirectURI: `${getBaseUrl()}/api/auth/oauth2/callback/google-groups`, }, + { + providerId: 'vertex-ai', + clientId: env.GOOGLE_CLIENT_ID as string, + clientSecret: env.GOOGLE_CLIENT_SECRET as string, + discoveryUrl: 'https://accounts.google.com/.well-known/openid-configuration', + accessType: 'offline', + scopes: [ + 'https://www.googleapis.com/auth/userinfo.email', + 'https://www.googleapis.com/auth/userinfo.profile', + 'https://www.googleapis.com/auth/cloud-platform', + ], + prompt: 'consent', + redirectURI: `${getBaseUrl()}/api/auth/oauth2/callback/vertex-ai`, + }, + { providerId: 'microsoft-teams', clientId: env.MICROSOFT_CLIENT_ID as string, diff --git a/apps/sim/lib/core/utils/display-filters.ts b/apps/sim/lib/core/utils/display-filters.ts index 6e04322e9..21194e48a 100644 --- a/apps/sim/lib/core/utils/display-filters.ts +++ b/apps/sim/lib/core/utils/display-filters.ts @@ -41,7 +41,7 @@ function filterUserFile(data: any): any { const DISPLAY_FILTERS = [filterUserFile] export function filterForDisplay(data: any): any { - const seen = new WeakSet() + const seen = new Set() return 
filterForDisplayInternal(data, seen, 0) } @@ -49,7 +49,7 @@ function getObjectType(data: unknown): string { return Object.prototype.toString.call(data).slice(8, -1) } -function filterForDisplayInternal(data: any, seen: WeakSet, depth: number): any { +function filterForDisplayInternal(data: any, seen: Set, depth: number): any { try { if (data === null || data === undefined) { return data @@ -93,6 +93,7 @@ function filterForDisplayInternal(data: any, seen: WeakSet, depth: numbe return '[Unknown Type]' } + // True circular reference: object is an ancestor in the current path if (seen.has(data)) { return '[Circular Reference]' } @@ -131,18 +132,24 @@ function filterForDisplayInternal(data: any, seen: WeakSet, depth: numbe return `[ArrayBuffer: ${(data as ArrayBuffer).byteLength} bytes]` case 'Map': { + seen.add(data) const obj: Record = {} for (const [key, value] of (data as Map).entries()) { const keyStr = typeof key === 'string' ? key : String(key) obj[keyStr] = filterForDisplayInternal(value, seen, depth + 1) } + seen.delete(data) return obj } - case 'Set': - return Array.from(data as Set).map((item) => + case 'Set': { + seen.add(data) + const result = Array.from(data as Set).map((item) => filterForDisplayInternal(item, seen, depth + 1) ) + seen.delete(data) + return result + } case 'WeakMap': return '[WeakMap]' @@ -161,17 +168,22 @@ function filterForDisplayInternal(data: any, seen: WeakSet, depth: numbe return `[${objectType}: ${(data as ArrayBufferView).byteLength} bytes]` } + // Add to current path before processing children seen.add(data) for (const filterFn of DISPLAY_FILTERS) { - const result = filterFn(data) - if (result !== data) { - return filterForDisplayInternal(result, seen, depth + 1) + const filtered = filterFn(data) + if (filtered !== data) { + const result = filterForDisplayInternal(filtered, seen, depth + 1) + seen.delete(data) + return result } } if (Array.isArray(data)) { - return data.map((item) => filterForDisplayInternal(item, seen, depth + 1)) + const result = data.map((item) => filterForDisplayInternal(item, seen, depth + 1)) + seen.delete(data) + return result } const result: Record = {} @@ -182,6 +194,8 @@ function filterForDisplayInternal(data: any, seen: WeakSet, depth: numbe result[key] = '[Error accessing property]' } } + // Remove from current path after processing children + seen.delete(data) return result } catch { return '[Unserializable]' diff --git a/apps/sim/lib/oauth/oauth.ts b/apps/sim/lib/oauth/oauth.ts index 936b4d103..f9937f3be 100644 --- a/apps/sim/lib/oauth/oauth.ts +++ b/apps/sim/lib/oauth/oauth.ts @@ -32,6 +32,7 @@ import { SlackIcon, SpotifyIcon, TrelloIcon, + VertexIcon, WealthboxIcon, WebflowIcon, WordpressIcon, @@ -80,6 +81,7 @@ export type OAuthService = | 'google-vault' | 'google-forms' | 'google-groups' + | 'vertex-ai' | 'github' | 'x' | 'confluence' @@ -237,6 +239,16 @@ export const OAUTH_PROVIDERS: Record = { ], scopeHints: ['admin.directory.group'], }, + 'vertex-ai': { + id: 'vertex-ai', + name: 'Vertex AI', + description: 'Access Google Cloud Vertex AI for Gemini models with OAuth.', + providerId: 'vertex-ai', + icon: (props) => VertexIcon(props), + baseProviderIcon: (props) => VertexIcon(props), + scopes: ['https://www.googleapis.com/auth/cloud-platform'], + scopeHints: ['cloud-platform', 'vertex', 'aiplatform'], + }, }, defaultService: 'gmail', }, @@ -1099,6 +1111,12 @@ export function parseProvider(provider: OAuthProvider): ProviderConfig { featureType: 'microsoft-planner', } } + if (provider === 'vertex-ai') { + return { + 
baseProvider: 'google', + featureType: 'vertex-ai', + } + } // Handle compound providers (e.g., 'google-email' -> { baseProvider: 'google', featureType: 'email' }) const [base, feature] = provider.split('-') diff --git a/apps/sim/providers/anthropic/index.ts b/apps/sim/providers/anthropic/index.ts index 526fa5cb9..16c92ef84 100644 --- a/apps/sim/providers/anthropic/index.ts +++ b/apps/sim/providers/anthropic/index.ts @@ -58,7 +58,7 @@ export const anthropicProvider: ProviderConfig = { throw new Error('API key is required for Anthropic') } - const modelId = request.model || 'claude-3-7-sonnet-20250219' + const modelId = request.model const useNativeStructuredOutputs = !!( request.responseFormat && supportsNativeStructuredOutputs(modelId) ) @@ -174,7 +174,7 @@ export const anthropicProvider: ProviderConfig = { } const payload: any = { - model: request.model || 'claude-3-7-sonnet-20250219', + model: request.model, messages, system: systemPrompt, max_tokens: Number.parseInt(String(request.maxTokens)) || 1024, @@ -561,37 +561,93 @@ export const anthropicProvider: ProviderConfig = { throw error } - const providerEndTime = Date.now() - const providerEndTimeISO = new Date(providerEndTime).toISOString() - const totalDuration = providerEndTime - providerStartTime + const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion) - return { - content, - model: request.model || 'claude-3-7-sonnet-20250219', - tokens, - toolCalls: - toolCalls.length > 0 - ? toolCalls.map((tc) => ({ - name: tc.name, - arguments: tc.arguments as Record, - startTime: tc.startTime, - endTime: tc.endTime, - duration: tc.duration, - result: tc.result, - })) - : undefined, - toolResults: toolResults.length > 0 ? toolResults : undefined, - timing: { - startTime: providerStartTimeISO, - endTime: providerEndTimeISO, - duration: totalDuration, - modelTime: modelTime, - toolsTime: toolsTime, - firstResponseTime: firstResponseTime, - iterations: iterationCount + 1, - timeSegments: timeSegments, + const streamingPayload = { + ...payload, + messages: currentMessages, + stream: true, + tool_choice: undefined, + } + + const streamResponse: any = await anthropic.messages.create(streamingPayload) + + const streamingResult = { + stream: createReadableStreamFromAnthropicStream( + streamResponse, + (streamContent, usage) => { + streamingResult.execution.output.content = streamContent + streamingResult.execution.output.tokens = { + prompt: tokens.prompt + usage.input_tokens, + completion: tokens.completion + usage.output_tokens, + total: tokens.total + usage.input_tokens + usage.output_tokens, + } + + const streamCost = calculateCost( + request.model, + usage.input_tokens, + usage.output_tokens + ) + streamingResult.execution.output.cost = { + input: accumulatedCost.input + streamCost.input, + output: accumulatedCost.output + streamCost.output, + total: accumulatedCost.total + streamCost.total, + } + + const streamEndTime = Date.now() + const streamEndTimeISO = new Date(streamEndTime).toISOString() + + if (streamingResult.execution.output.providerTiming) { + streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO + streamingResult.execution.output.providerTiming.duration = + streamEndTime - providerStartTime + } + } + ), + execution: { + success: true, + output: { + content: '', + model: request.model, + tokens: { + prompt: tokens.prompt, + completion: tokens.completion, + total: tokens.total, + }, + toolCalls: + toolCalls.length > 0 + ? 
{ + list: toolCalls, + count: toolCalls.length, + } + : undefined, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + modelTime: modelTime, + toolsTime: toolsTime, + firstResponseTime: firstResponseTime, + iterations: iterationCount + 1, + timeSegments: timeSegments, + }, + cost: { + input: accumulatedCost.input, + output: accumulatedCost.output, + total: accumulatedCost.total, + }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + isStreaming: true, }, } + + return streamingResult as StreamingExecution } catch (error) { const providerEndTime = Date.now() const providerEndTimeISO = new Date(providerEndTime).toISOString() @@ -934,7 +990,7 @@ export const anthropicProvider: ProviderConfig = { success: true, output: { content: '', - model: request.model || 'claude-3-7-sonnet-20250219', + model: request.model, tokens: { prompt: tokens.prompt, completion: tokens.completion, @@ -978,7 +1034,7 @@ export const anthropicProvider: ProviderConfig = { return { content, - model: request.model || 'claude-3-7-sonnet-20250219', + model: request.model, tokens, toolCalls: toolCalls.length > 0 diff --git a/apps/sim/providers/azure-openai/index.ts b/apps/sim/providers/azure-openai/index.ts index ecd178b9f..2d63967c7 100644 --- a/apps/sim/providers/azure-openai/index.ts +++ b/apps/sim/providers/azure-openai/index.ts @@ -39,7 +39,7 @@ export const azureOpenAIProvider: ProviderConfig = { request: ProviderRequest ): Promise => { logger.info('Preparing Azure OpenAI request', { - model: request.model || 'azure/gpt-4o', + model: request.model, hasSystemPrompt: !!request.systemPrompt, hasMessages: !!request.messages?.length, hasTools: !!request.tools?.length, @@ -95,7 +95,7 @@ export const azureOpenAIProvider: ProviderConfig = { })) : undefined - const deploymentName = (request.model || 'azure/gpt-4o').replace('azure/', '') + const deploymentName = request.model.replace('azure/', '') const payload: any = { model: deploymentName, messages: allMessages, diff --git a/apps/sim/providers/cerebras/index.ts b/apps/sim/providers/cerebras/index.ts index 23cede1bd..131aad545 100644 --- a/apps/sim/providers/cerebras/index.ts +++ b/apps/sim/providers/cerebras/index.ts @@ -73,7 +73,7 @@ export const cerebrasProvider: ProviderConfig = { : undefined const payload: any = { - model: (request.model || 'cerebras/llama-3.3-70b').replace('cerebras/', ''), + model: request.model.replace('cerebras/', ''), messages: allMessages, } if (request.temperature !== undefined) payload.temperature = request.temperature @@ -145,7 +145,7 @@ export const cerebrasProvider: ProviderConfig = { success: true, output: { content: '', - model: request.model || 'cerebras/llama-3.3-70b', + model: request.model, tokens: { prompt: 0, completion: 0, total: 0 }, toolCalls: undefined, providerTiming: { @@ -470,7 +470,7 @@ export const cerebrasProvider: ProviderConfig = { success: true, output: { content: '', - model: request.model || 'cerebras/llama-3.3-70b', + model: request.model, tokens: { prompt: tokens.prompt, completion: tokens.completion, diff --git a/apps/sim/providers/deepseek/index.ts b/apps/sim/providers/deepseek/index.ts index 4f1336f58..c82809dd3 100644 --- a/apps/sim/providers/deepseek/index.ts +++ b/apps/sim/providers/deepseek/index.ts @@ -105,7 +105,7 @@ export const deepseekProvider: ProviderConfig = { : toolChoice.type === 'any' ? 
`force:${toolChoice.any?.name || 'unknown'}` : 'unknown', - model: request.model || 'deepseek-v3', + model: request.model, }) } } @@ -145,7 +145,7 @@ export const deepseekProvider: ProviderConfig = { success: true, output: { content: '', - model: request.model || 'deepseek-chat', + model: request.model, tokens: { prompt: 0, completion: 0, total: 0 }, toolCalls: undefined, providerTiming: { @@ -469,7 +469,7 @@ export const deepseekProvider: ProviderConfig = { success: true, output: { content: '', - model: request.model || 'deepseek-chat', + model: request.model, tokens: { prompt: tokens.prompt, completion: tokens.completion, diff --git a/apps/sim/providers/gemini/client.ts b/apps/sim/providers/gemini/client.ts new file mode 100644 index 000000000..0b5e5bdca --- /dev/null +++ b/apps/sim/providers/gemini/client.ts @@ -0,0 +1,58 @@ +import { GoogleGenAI } from '@google/genai' +import { createLogger } from '@/lib/logs/console/logger' +import type { GeminiClientConfig } from './types' + +const logger = createLogger('GeminiClient') + +/** + * Creates a GoogleGenAI client configured for either Google Gemini API or Vertex AI + * + * For Google Gemini API: + * - Uses API key authentication + * + * For Vertex AI: + * - Uses OAuth access token via HTTP Authorization header + * - Requires project and location + */ +export function createGeminiClient(config: GeminiClientConfig): GoogleGenAI { + if (config.vertexai) { + if (!config.project) { + throw new Error('Vertex AI requires a project ID') + } + if (!config.accessToken) { + throw new Error('Vertex AI requires an access token') + } + + const location = config.location ?? 'us-central1' + + logger.info('Creating Vertex AI client', { + project: config.project, + location, + hasAccessToken: !!config.accessToken, + }) + + // Create client with Vertex AI configuration + // Use httpOptions.headers to pass the access token directly + return new GoogleGenAI({ + vertexai: true, + project: config.project, + location, + httpOptions: { + headers: { + Authorization: `Bearer ${config.accessToken}`, + }, + }, + }) + } + + // Google Gemini API with API key + if (!config.apiKey) { + throw new Error('Google Gemini API requires an API key') + } + + logger.info('Creating Google Gemini client') + + return new GoogleGenAI({ + apiKey: config.apiKey, + }) +} diff --git a/apps/sim/providers/gemini/core.ts b/apps/sim/providers/gemini/core.ts new file mode 100644 index 000000000..f7cff4bac --- /dev/null +++ b/apps/sim/providers/gemini/core.ts @@ -0,0 +1,680 @@ +import { + type Content, + FunctionCallingConfigMode, + type FunctionDeclaration, + type GenerateContentConfig, + type GenerateContentResponse, + type GoogleGenAI, + type Part, + type Schema, + type ThinkingConfig, + type ToolConfig, +} from '@google/genai' +import { createLogger } from '@/lib/logs/console/logger' +import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { + checkForForcedToolUsage, + cleanSchemaForGemini, + convertToGeminiFormat, + convertUsageMetadata, + createReadableStreamFromGeminiStream, + extractFunctionCallPart, + extractTextContent, + mapToThinkingLevel, +} from '@/providers/google/utils' +import { getThinkingCapability } from '@/providers/models' +import type { FunctionCallResponse, ProviderRequest, ProviderResponse } from '@/providers/types' +import { + calculateCost, + prepareToolExecution, + prepareToolsWithUsageControl, +} from '@/providers/utils' +import { executeTool } from '@/tools' +import type { ExecutionState, 
GeminiProviderType, GeminiUsage, ParsedFunctionCall } from './types' + +/** + * Creates initial execution state + */ +function createInitialState( + contents: Content[], + initialUsage: GeminiUsage, + firstResponseTime: number, + initialCallTime: number, + model: string, + toolConfig: ToolConfig | undefined +): ExecutionState { + const initialCost = calculateCost( + model, + initialUsage.promptTokenCount, + initialUsage.candidatesTokenCount + ) + + return { + contents, + tokens: { + prompt: initialUsage.promptTokenCount, + completion: initialUsage.candidatesTokenCount, + total: initialUsage.totalTokenCount, + }, + cost: initialCost, + toolCalls: [], + toolResults: [], + iterationCount: 0, + modelTime: firstResponseTime, + toolsTime: 0, + timeSegments: [ + { + type: 'model', + name: 'Initial response', + startTime: initialCallTime, + endTime: initialCallTime + firstResponseTime, + duration: firstResponseTime, + }, + ], + usedForcedTools: [], + currentToolConfig: toolConfig, + } +} + +/** + * Executes a tool call and updates state + */ +async function executeToolCall( + functionCallPart: Part, + functionCall: ParsedFunctionCall, + request: ProviderRequest, + state: ExecutionState, + forcedTools: string[], + logger: ReturnType +): Promise<{ success: boolean; state: ExecutionState }> { + const toolCallStartTime = Date.now() + const toolName = functionCall.name + + const tool = request.tools?.find((t) => t.id === toolName) + if (!tool) { + logger.warn(`Tool ${toolName} not found in registry, skipping`) + return { success: false, state } + } + + try { + const { toolParams, executionParams } = prepareToolExecution(tool, functionCall.args, request) + const result = await executeTool(toolName, executionParams, true) + const toolCallEndTime = Date.now() + const duration = toolCallEndTime - toolCallStartTime + + const resultContent: Record = result.success + ? (result.output as Record) + : { error: true, message: result.error || 'Tool execution failed', tool: toolName } + + const toolCall: FunctionCallResponse = { + name: toolName, + arguments: toolParams, + startTime: new Date(toolCallStartTime).toISOString(), + endTime: new Date(toolCallEndTime).toISOString(), + duration, + result: resultContent, + } + + const updatedContents: Content[] = [ + ...state.contents, + { + role: 'model', + parts: [functionCallPart], + }, + { + role: 'user', + parts: [ + { + functionResponse: { + name: functionCall.name, + response: resultContent, + }, + }, + ], + }, + ] + + const forcedToolCheck = checkForForcedToolUsage( + [{ name: functionCall.name, args: functionCall.args }], + state.currentToolConfig, + forcedTools, + state.usedForcedTools + ) + + return { + success: true, + state: { + ...state, + contents: updatedContents, + toolCalls: [...state.toolCalls, toolCall], + toolResults: result.success + ? [...state.toolResults, result.output as Record] + : state.toolResults, + toolsTime: state.toolsTime + duration, + timeSegments: [ + ...state.timeSegments, + { + type: 'tool', + name: toolName, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration, + }, + ], + usedForcedTools: forcedToolCheck?.usedForcedTools ?? state.usedForcedTools, + currentToolConfig: forcedToolCheck?.nextToolConfig ?? state.currentToolConfig, + }, + } + } catch (error) { + logger.error('Error processing function call:', { + error: error instanceof Error ? 
error.message : String(error), + functionName: toolName, + }) + return { success: false, state } + } +} + +/** + * Updates state with model response metadata + */ +function updateStateWithResponse( + state: ExecutionState, + response: GenerateContentResponse, + model: string, + startTime: number, + endTime: number +): ExecutionState { + const usage = convertUsageMetadata(response.usageMetadata) + const cost = calculateCost(model, usage.promptTokenCount, usage.candidatesTokenCount) + const duration = endTime - startTime + + return { + ...state, + tokens: { + prompt: state.tokens.prompt + usage.promptTokenCount, + completion: state.tokens.completion + usage.candidatesTokenCount, + total: state.tokens.total + usage.totalTokenCount, + }, + cost: { + input: state.cost.input + cost.input, + output: state.cost.output + cost.output, + total: state.cost.total + cost.total, + pricing: cost.pricing, // Use latest pricing + }, + modelTime: state.modelTime + duration, + timeSegments: [ + ...state.timeSegments, + { + type: 'model', + name: `Model response (iteration ${state.iterationCount + 1})`, + startTime, + endTime, + duration, + }, + ], + iterationCount: state.iterationCount + 1, + } +} + +/** + * Builds config for next iteration + */ +function buildNextConfig( + baseConfig: GenerateContentConfig, + state: ExecutionState, + forcedTools: string[], + request: ProviderRequest, + logger: ReturnType +): GenerateContentConfig { + const nextConfig = { ...baseConfig } + const allForcedToolsUsed = + forcedTools.length > 0 && state.usedForcedTools.length === forcedTools.length + + if (allForcedToolsUsed && request.responseFormat) { + nextConfig.tools = undefined + nextConfig.toolConfig = undefined + nextConfig.responseMimeType = 'application/json' + nextConfig.responseSchema = cleanSchemaForGemini(request.responseFormat.schema) as Schema + logger.info('Using structured output for final response after tool execution') + } else if (state.currentToolConfig) { + nextConfig.toolConfig = state.currentToolConfig + } else { + nextConfig.toolConfig = { functionCallingConfig: { mode: FunctionCallingConfigMode.AUTO } } + } + + return nextConfig +} + +/** + * Creates streaming execution result template + */ +function createStreamingResult( + providerStartTime: number, + providerStartTimeISO: string, + firstResponseTime: number, + initialCallTime: number, + state?: ExecutionState +): StreamingExecution { + return { + stream: undefined as unknown as ReadableStream, + execution: { + success: true, + output: { + content: '', + model: '', + tokens: state?.tokens ?? { prompt: 0, completion: 0, total: 0 }, + toolCalls: state?.toolCalls.length + ? { list: state.toolCalls, count: state.toolCalls.length } + : undefined, + toolResults: state?.toolResults, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + modelTime: state?.modelTime ?? firstResponseTime, + toolsTime: state?.toolsTime ?? 0, + firstResponseTime, + iterations: (state?.iterationCount ?? 0) + 1, + timeSegments: state?.timeSegments ?? [ + { + type: 'model', + name: 'Initial streaming response', + startTime: initialCallTime, + endTime: initialCallTime + firstResponseTime, + duration: firstResponseTime, + }, + ], + }, + cost: state?.cost ?? 
{ + input: 0, + output: 0, + total: 0, + pricing: { input: 0, output: 0, updatedAt: new Date().toISOString() }, + }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + isStreaming: true, + }, + } +} + +/** + * Configuration for executing a Gemini request + */ +export interface GeminiExecutionConfig { + ai: GoogleGenAI + model: string + request: ProviderRequest + providerType: GeminiProviderType +} + +/** + * Executes a request using the Gemini API + * + * This is the shared core logic for both Google and Vertex AI providers. + * The only difference is how the GoogleGenAI client is configured. + */ +export async function executeGeminiRequest( + config: GeminiExecutionConfig +): Promise { + const { ai, model, request, providerType } = config + const logger = createLogger(providerType === 'google' ? 'GoogleProvider' : 'VertexProvider') + + logger.info(`Preparing ${providerType} Gemini request`, { + model, + hasSystemPrompt: !!request.systemPrompt, + hasMessages: !!request.messages?.length, + hasTools: !!request.tools?.length, + toolCount: request.tools?.length ?? 0, + hasResponseFormat: !!request.responseFormat, + streaming: !!request.stream, + }) + + const providerStartTime = Date.now() + const providerStartTimeISO = new Date(providerStartTime).toISOString() + + try { + const { contents, tools, systemInstruction } = convertToGeminiFormat(request) + + // Build configuration + const geminiConfig: GenerateContentConfig = {} + + if (request.temperature !== undefined) { + geminiConfig.temperature = request.temperature + } + if (request.maxTokens !== undefined) { + geminiConfig.maxOutputTokens = request.maxTokens + } + if (systemInstruction) { + geminiConfig.systemInstruction = systemInstruction + } + + // Handle response format (only when no tools) + if (request.responseFormat && !tools?.length) { + geminiConfig.responseMimeType = 'application/json' + geminiConfig.responseSchema = cleanSchemaForGemini(request.responseFormat.schema) as Schema + logger.info('Using Gemini native structured output format') + } else if (request.responseFormat && tools?.length) { + logger.warn('Gemini does not support responseFormat with tools. Structured output ignored.') + } + + // Configure thinking for models that support it + const thinkingCapability = getThinkingCapability(model) + if (thinkingCapability) { + const level = request.thinkingLevel ?? thinkingCapability.default ?? 'high' + const thinkingConfig: ThinkingConfig = { + includeThoughts: false, + thinkingLevel: mapToThinkingLevel(level), + } + geminiConfig.thinkingConfig = thinkingConfig + } + + // Prepare tools + let preparedTools: ReturnType | null = null + let toolConfig: ToolConfig | undefined + + if (tools?.length) { + const functionDeclarations: FunctionDeclaration[] = tools.map((t) => ({ + name: t.name, + description: t.description, + parameters: t.parameters, + })) + + preparedTools = prepareToolsWithUsageControl( + functionDeclarations, + request.tools, + logger, + 'google' + ) + const { tools: filteredTools, toolConfig: preparedToolConfig } = preparedTools + + if (filteredTools?.length) { + geminiConfig.tools = [{ functionDeclarations: filteredTools as FunctionDeclaration[] }] + + if (preparedToolConfig) { + toolConfig = { + functionCallingConfig: { + mode: + { + AUTO: FunctionCallingConfigMode.AUTO, + ANY: FunctionCallingConfigMode.ANY, + NONE: FunctionCallingConfigMode.NONE, + }[preparedToolConfig.functionCallingConfig.mode] ?? 
FunctionCallingConfigMode.AUTO, + allowedFunctionNames: preparedToolConfig.functionCallingConfig.allowedFunctionNames, + }, + } + geminiConfig.toolConfig = toolConfig + } + + logger.info('Gemini request with tools:', { + toolCount: filteredTools.length, + model, + tools: filteredTools.map((t) => (t as FunctionDeclaration).name), + }) + } + } + + const initialCallTime = Date.now() + const shouldStream = request.stream && !tools?.length + + // Streaming without tools + if (shouldStream) { + logger.info('Handling Gemini streaming response') + + const streamGenerator = await ai.models.generateContentStream({ + model, + contents, + config: geminiConfig, + }) + const firstResponseTime = Date.now() - initialCallTime + + const streamingResult = createStreamingResult( + providerStartTime, + providerStartTimeISO, + firstResponseTime, + initialCallTime + ) + streamingResult.execution.output.model = model + + streamingResult.stream = createReadableStreamFromGeminiStream( + streamGenerator, + (content: string, usage: GeminiUsage) => { + streamingResult.execution.output.content = content + streamingResult.execution.output.tokens = { + prompt: usage.promptTokenCount, + completion: usage.candidatesTokenCount, + total: usage.totalTokenCount, + } + + const costResult = calculateCost( + model, + usage.promptTokenCount, + usage.candidatesTokenCount + ) + streamingResult.execution.output.cost = costResult + + const streamEndTime = Date.now() + if (streamingResult.execution.output.providerTiming) { + streamingResult.execution.output.providerTiming.endTime = new Date( + streamEndTime + ).toISOString() + streamingResult.execution.output.providerTiming.duration = + streamEndTime - providerStartTime + const segments = streamingResult.execution.output.providerTiming.timeSegments + if (segments?.[0]) { + segments[0].endTime = streamEndTime + segments[0].duration = streamEndTime - providerStartTime + } + } + } + ) + + return streamingResult + } + + // Non-streaming request + const response = await ai.models.generateContent({ model, contents, config: geminiConfig }) + const firstResponseTime = Date.now() - initialCallTime + + // Check for UNEXPECTED_TOOL_CALL + const candidate = response.candidates?.[0] + if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') { + logger.warn('Gemini returned UNEXPECTED_TOOL_CALL - model attempted to call unknown tool') + } + + const initialUsage = convertUsageMetadata(response.usageMetadata) + let state = createInitialState( + contents, + initialUsage, + firstResponseTime, + initialCallTime, + model, + toolConfig + ) + const forcedTools = preparedTools?.forcedTools ?? [] + + let currentResponse = response + let content = '' + + // Tool execution loop + const functionCalls = response.functionCalls + if (functionCalls?.length) { + logger.info(`Received function call from Gemini: ${functionCalls[0].name}`) + + while (state.iterationCount < MAX_TOOL_ITERATIONS) { + const functionCallPart = extractFunctionCallPart(currentResponse.candidates?.[0]) + if (!functionCallPart?.functionCall) { + content = extractTextContent(currentResponse.candidates?.[0]) + break + } + + const functionCall: ParsedFunctionCall = { + name: functionCallPart.functionCall.name ?? '', + args: (functionCallPart.functionCall.args ?? 
{}) as Record, + } + + logger.info( + `Processing function call: ${functionCall.name} (iteration ${state.iterationCount + 1})` + ) + + const { success, state: updatedState } = await executeToolCall( + functionCallPart, + functionCall, + request, + state, + forcedTools, + logger + ) + if (!success) { + content = extractTextContent(currentResponse.candidates?.[0]) + break + } + + state = { ...updatedState, iterationCount: updatedState.iterationCount + 1 } + const nextConfig = buildNextConfig(geminiConfig, state, forcedTools, request, logger) + + // Stream final response if requested + if (request.stream) { + const checkResponse = await ai.models.generateContent({ + model, + contents: state.contents, + config: nextConfig, + }) + state = updateStateWithResponse(state, checkResponse, model, Date.now() - 100, Date.now()) + + if (checkResponse.functionCalls?.length) { + currentResponse = checkResponse + continue + } + + logger.info('No more function calls, streaming final response') + + if (request.responseFormat) { + nextConfig.tools = undefined + nextConfig.toolConfig = undefined + nextConfig.responseMimeType = 'application/json' + nextConfig.responseSchema = cleanSchemaForGemini( + request.responseFormat.schema + ) as Schema + } + + // Capture accumulated cost before streaming + const accumulatedCost = { + input: state.cost.input, + output: state.cost.output, + total: state.cost.total, + } + const accumulatedTokens = { ...state.tokens } + + const streamGenerator = await ai.models.generateContentStream({ + model, + contents: state.contents, + config: nextConfig, + }) + + const streamingResult = createStreamingResult( + providerStartTime, + providerStartTimeISO, + firstResponseTime, + initialCallTime, + state + ) + streamingResult.execution.output.model = model + + streamingResult.stream = createReadableStreamFromGeminiStream( + streamGenerator, + (streamContent: string, usage: GeminiUsage) => { + streamingResult.execution.output.content = streamContent + streamingResult.execution.output.tokens = { + prompt: accumulatedTokens.prompt + usage.promptTokenCount, + completion: accumulatedTokens.completion + usage.candidatesTokenCount, + total: accumulatedTokens.total + usage.totalTokenCount, + } + + const streamCost = calculateCost( + model, + usage.promptTokenCount, + usage.candidatesTokenCount + ) + streamingResult.execution.output.cost = { + input: accumulatedCost.input + streamCost.input, + output: accumulatedCost.output + streamCost.output, + total: accumulatedCost.total + streamCost.total, + pricing: streamCost.pricing, + } + + if (streamingResult.execution.output.providerTiming) { + streamingResult.execution.output.providerTiming.endTime = new Date().toISOString() + streamingResult.execution.output.providerTiming.duration = + Date.now() - providerStartTime + } + } + ) + + return streamingResult + } + + // Non-streaming: get next response + const nextModelStartTime = Date.now() + const nextResponse = await ai.models.generateContent({ + model, + contents: state.contents, + config: nextConfig, + }) + state = updateStateWithResponse(state, nextResponse, model, nextModelStartTime, Date.now()) + currentResponse = nextResponse + } + + if (!content) { + content = extractTextContent(currentResponse.candidates?.[0]) + } + } else { + content = extractTextContent(candidate) + } + + const providerEndTime = Date.now() + + return { + content, + model, + tokens: state.tokens, + toolCalls: state.toolCalls.length ? state.toolCalls : undefined, + toolResults: state.toolResults.length ? 
state.toolResults : undefined, + timing: { + startTime: providerStartTimeISO, + endTime: new Date(providerEndTime).toISOString(), + duration: providerEndTime - providerStartTime, + modelTime: state.modelTime, + toolsTime: state.toolsTime, + firstResponseTime, + iterations: state.iterationCount + 1, + timeSegments: state.timeSegments, + }, + cost: state.cost, + } + } catch (error) { + const providerEndTime = Date.now() + const duration = providerEndTime - providerStartTime + + logger.error('Error in Gemini request:', { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + }) + + const enhancedError = error instanceof Error ? error : new Error(String(error)) + Object.assign(enhancedError, { + timing: { + startTime: providerStartTimeISO, + endTime: new Date(providerEndTime).toISOString(), + duration, + }, + }) + throw enhancedError + } +} diff --git a/apps/sim/providers/gemini/index.ts b/apps/sim/providers/gemini/index.ts new file mode 100644 index 000000000..378aa32c3 --- /dev/null +++ b/apps/sim/providers/gemini/index.ts @@ -0,0 +1,18 @@ +/** + * Shared Gemini execution core + * + * This module provides the shared execution logic for both Google Gemini API + * and Vertex AI providers. The only difference between providers is how the + * GoogleGenAI client is configured (API key vs OAuth). + */ + +export { createGeminiClient } from './client' +export { executeGeminiRequest, type GeminiExecutionConfig } from './core' +export type { + ExecutionState, + ForcedToolResult, + GeminiClientConfig, + GeminiProviderType, + GeminiUsage, + ParsedFunctionCall, +} from './types' diff --git a/apps/sim/providers/gemini/types.ts b/apps/sim/providers/gemini/types.ts new file mode 100644 index 000000000..02592d09b --- /dev/null +++ b/apps/sim/providers/gemini/types.ts @@ -0,0 +1,64 @@ +import type { Content, ToolConfig } from '@google/genai' +import type { FunctionCallResponse, ModelPricing, TimeSegment } from '@/providers/types' + +/** + * Usage metadata from Gemini responses + */ +export interface GeminiUsage { + promptTokenCount: number + candidatesTokenCount: number + totalTokenCount: number +} + +/** + * Parsed function call from Gemini response + */ +export interface ParsedFunctionCall { + name: string + args: Record +} + +/** + * Accumulated state during tool execution loop + */ +export interface ExecutionState { + contents: Content[] + tokens: { prompt: number; completion: number; total: number } + cost: { input: number; output: number; total: number; pricing: ModelPricing } + toolCalls: FunctionCallResponse[] + toolResults: Record[] + iterationCount: number + modelTime: number + toolsTime: number + timeSegments: TimeSegment[] + usedForcedTools: string[] + currentToolConfig: ToolConfig | undefined +} + +/** + * Result from forced tool usage check + */ +export interface ForcedToolResult { + hasUsedForcedTool: boolean + usedForcedTools: string[] + nextToolConfig: ToolConfig | undefined +} + +/** + * Configuration for creating a Gemini client + */ +export interface GeminiClientConfig { + /** For Google Gemini API */ + apiKey?: string + /** For Vertex AI */ + vertexai?: boolean + project?: string + location?: string + /** OAuth access token for Vertex AI */ + accessToken?: string +} + +/** + * Provider type for logging and model lookup + */ +export type GeminiProviderType = 'google' | 'vertex' diff --git a/apps/sim/providers/google/index.ts b/apps/sim/providers/google/index.ts index b17983033..c827480a6 100644 --- 
a/apps/sim/providers/google/index.ts +++ b/apps/sim/providers/google/index.ts @@ -1,235 +1,18 @@ +import { GoogleGenAI } from '@google/genai' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' -import { MAX_TOOL_ITERATIONS } from '@/providers' -import { - cleanSchemaForGemini, - convertToGeminiFormat, - extractFunctionCall, - extractTextContent, -} from '@/providers/google/utils' +import { executeGeminiRequest } from '@/providers/gemini/core' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' -import type { - ProviderConfig, - ProviderRequest, - ProviderResponse, - TimeSegment, -} from '@/providers/types' -import { - calculateCost, - prepareToolExecution, - prepareToolsWithUsageControl, - trackForcedToolUsage, -} from '@/providers/utils' -import { executeTool } from '@/tools' +import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types' const logger = createLogger('GoogleProvider') -interface GeminiStreamUsage { - promptTokenCount: number - candidatesTokenCount: number - totalTokenCount: number -} - -function createReadableStreamFromGeminiStream( - response: Response, - onComplete?: (content: string, usage: GeminiStreamUsage) => void -): ReadableStream { - const reader = response.body?.getReader() - if (!reader) { - throw new Error('Failed to get reader from response body') - } - - return new ReadableStream({ - async start(controller) { - try { - let buffer = '' - let fullContent = '' - let promptTokenCount = 0 - let candidatesTokenCount = 0 - let totalTokenCount = 0 - - const updateUsage = (metadata: any) => { - if (metadata) { - promptTokenCount = metadata.promptTokenCount ?? promptTokenCount - candidatesTokenCount = metadata.candidatesTokenCount ?? candidatesTokenCount - totalTokenCount = metadata.totalTokenCount ?? 
totalTokenCount - } - } - - const buildUsage = (): GeminiStreamUsage => ({ - promptTokenCount, - candidatesTokenCount, - totalTokenCount, - }) - - const complete = () => { - if (onComplete) { - if (promptTokenCount === 0 && candidatesTokenCount === 0) { - logger.warn('Gemini stream completed without usage metadata') - } - onComplete(fullContent, buildUsage()) - } - } - - while (true) { - const { done, value } = await reader.read() - if (done) { - if (buffer.trim()) { - try { - const data = JSON.parse(buffer.trim()) - updateUsage(data.usageMetadata) - const candidate = data.candidates?.[0] - if (candidate?.content?.parts) { - const functionCall = extractFunctionCall(candidate) - if (functionCall) { - logger.debug('Function call detected in final buffer', { - functionName: functionCall.name, - }) - complete() - controller.close() - return - } - const content = extractTextContent(candidate) - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - } catch (e) { - if (buffer.trim().startsWith('[')) { - try { - const dataArray = JSON.parse(buffer.trim()) - if (Array.isArray(dataArray)) { - for (const item of dataArray) { - updateUsage(item.usageMetadata) - const candidate = item.candidates?.[0] - if (candidate?.content?.parts) { - const functionCall = extractFunctionCall(candidate) - if (functionCall) { - logger.debug('Function call detected in array item', { - functionName: functionCall.name, - }) - complete() - controller.close() - return - } - const content = extractTextContent(candidate) - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - } - } - } catch (_) {} - } - } - } - complete() - controller.close() - break - } - - const text = new TextDecoder().decode(value) - buffer += text - - let searchIndex = 0 - while (searchIndex < buffer.length) { - const openBrace = buffer.indexOf('{', searchIndex) - if (openBrace === -1) break - - let braceCount = 0 - let inString = false - let escaped = false - let closeBrace = -1 - - for (let i = openBrace; i < buffer.length; i++) { - const char = buffer[i] - - if (!inString) { - if (char === '"' && !escaped) { - inString = true - } else if (char === '{') { - braceCount++ - } else if (char === '}') { - braceCount-- - if (braceCount === 0) { - closeBrace = i - break - } - } - } else { - if (char === '"' && !escaped) { - inString = false - } - } - - escaped = char === '\\' && !escaped - } - - if (closeBrace !== -1) { - const jsonStr = buffer.substring(openBrace, closeBrace + 1) - - try { - const data = JSON.parse(jsonStr) - updateUsage(data.usageMetadata) - const candidate = data.candidates?.[0] - - if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') { - logger.warn('Gemini returned UNEXPECTED_TOOL_CALL in streaming mode') - const textContent = extractTextContent(candidate) - if (textContent) { - fullContent += textContent - controller.enqueue(new TextEncoder().encode(textContent)) - } - complete() - controller.close() - return - } - - if (candidate?.content?.parts) { - const functionCall = extractFunctionCall(candidate) - if (functionCall) { - logger.debug('Function call detected in stream', { - functionName: functionCall.name, - }) - complete() - controller.close() - return - } - const content = extractTextContent(candidate) - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - } catch (e) { - logger.error('Error parsing JSON from stream', { - error: e instanceof Error ? 
e.message : String(e), - jsonPreview: jsonStr.substring(0, 200), - }) - } - - buffer = buffer.substring(closeBrace + 1) - searchIndex = 0 - } else { - break - } - } - } - } catch (e) { - logger.error('Error reading Google Gemini stream', { - error: e instanceof Error ? e.message : String(e), - }) - controller.error(e) - } - }, - async cancel() { - await reader.cancel() - }, - }) -} - +/** + * Google Gemini provider + * + * Uses the @google/genai SDK with API key authentication. + * Shares core execution logic with Vertex AI provider. + */ export const googleProvider: ProviderConfig = { id: 'google', name: 'Google', @@ -245,854 +28,15 @@ export const googleProvider: ProviderConfig = { throw new Error('API key is required for Google Gemini') } - logger.info('Preparing Google Gemini request', { - model: request.model || 'gemini-2.5-pro', - hasSystemPrompt: !!request.systemPrompt, - hasMessages: !!request.messages?.length, - hasTools: !!request.tools?.length, - toolCount: request.tools?.length || 0, - hasResponseFormat: !!request.responseFormat, - streaming: !!request.stream, + logger.info('Creating Google Gemini client', { model: request.model }) + + const ai = new GoogleGenAI({ apiKey: request.apiKey }) + + return executeGeminiRequest({ + ai, + model: request.model, + request, + providerType: 'google', }) - - const providerStartTime = Date.now() - const providerStartTimeISO = new Date(providerStartTime).toISOString() - - try { - const { contents, tools, systemInstruction } = convertToGeminiFormat(request) - - const requestedModel = request.model || 'gemini-2.5-pro' - - const payload: any = { - contents, - generationConfig: {}, - } - - if (request.temperature !== undefined && request.temperature !== null) { - payload.generationConfig.temperature = request.temperature - } - - if (request.maxTokens !== undefined) { - payload.generationConfig.maxOutputTokens = request.maxTokens - } - - if (systemInstruction) { - payload.systemInstruction = systemInstruction - } - - if (request.responseFormat && !tools?.length) { - const responseFormatSchema = request.responseFormat.schema || request.responseFormat - - const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - - payload.generationConfig.responseMimeType = 'application/json' - payload.generationConfig.responseSchema = cleanSchema - - logger.info('Using Gemini native structured output format', { - hasSchema: !!cleanSchema, - mimeType: 'application/json', - }) - } else if (request.responseFormat && tools?.length) { - logger.warn( - 'Gemini does not support structured output (responseFormat) with function calling (tools). Structured output will be ignored.' - ) - } - - let preparedTools: ReturnType | null = null - - if (tools?.length) { - preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'google') - const { tools: filteredTools, toolConfig } = preparedTools - - if (filteredTools?.length) { - payload.tools = [ - { - functionDeclarations: filteredTools, - }, - ] - - if (toolConfig) { - payload.toolConfig = toolConfig - } - - logger.info('Google Gemini request with tools:', { - toolCount: filteredTools.length, - model: requestedModel, - tools: filteredTools.map((t) => t.name), - hasToolConfig: !!toolConfig, - toolConfig: toolConfig, - }) - } - } - - const initialCallTime = Date.now() - - const shouldStream = request.stream && !tools?.length - - const endpoint = shouldStream - ? 
`https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:streamGenerateContent?key=${request.apiKey}` - : `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:generateContent?key=${request.apiKey}` - - if (request.stream && tools?.length) { - logger.info('Streaming disabled for initial request due to tools presence', { - toolCount: tools.length, - willStreamAfterTools: true, - }) - } - - const response = await fetch(endpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(payload), - }) - - if (!response.ok) { - const responseText = await response.text() - logger.error('Gemini API error details:', { - status: response.status, - statusText: response.statusText, - responseBody: responseText, - }) - throw new Error(`Gemini API error: ${response.status} ${response.statusText}`) - } - - const firstResponseTime = Date.now() - initialCallTime - - if (shouldStream) { - logger.info('Handling Google Gemini streaming response') - - const streamingResult: StreamingExecution = { - stream: null as any, - execution: { - success: true, - output: { - content: '', - model: request.model, - tokens: { - prompt: 0, - completion: 0, - total: 0, - }, - providerTiming: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: firstResponseTime, - modelTime: firstResponseTime, - toolsTime: 0, - firstResponseTime, - iterations: 1, - timeSegments: [ - { - type: 'model', - name: 'Initial streaming response', - startTime: initialCallTime, - endTime: initialCallTime + firstResponseTime, - duration: firstResponseTime, - }, - ], - }, - cost: { input: 0, output: 0, total: 0 }, - }, - logs: [], - metadata: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: firstResponseTime, - }, - isStreaming: true, - }, - } - - streamingResult.stream = createReadableStreamFromGeminiStream( - response, - (content, usage) => { - streamingResult.execution.output.content = content - streamingResult.execution.output.tokens = { - prompt: usage.promptTokenCount, - completion: usage.candidatesTokenCount, - total: usage.totalTokenCount || usage.promptTokenCount + usage.candidatesTokenCount, - } - - const costResult = calculateCost( - request.model, - usage.promptTokenCount, - usage.candidatesTokenCount - ) - streamingResult.execution.output.cost = { - input: costResult.input, - output: costResult.output, - total: costResult.total, - } - - const streamEndTime = Date.now() - const streamEndTimeISO = new Date(streamEndTime).toISOString() - - if (streamingResult.execution.output.providerTiming) { - streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO - streamingResult.execution.output.providerTiming.duration = - streamEndTime - providerStartTime - - if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) { - streamingResult.execution.output.providerTiming.timeSegments[0].endTime = - streamEndTime - streamingResult.execution.output.providerTiming.timeSegments[0].duration = - streamEndTime - providerStartTime - } - } - } - ) - - return streamingResult - } - - let geminiResponse = await response.json() - - if (payload.generationConfig?.responseSchema) { - const candidate = geminiResponse.candidates?.[0] - if (candidate?.content?.parts?.[0]?.text) { - const text = candidate.content.parts[0].text - try { - JSON.parse(text) - logger.info('Successfully received structured JSON output') - } catch (_e) { - logger.warn('Failed to parse structured output as JSON') - } - } 
- } - - let content = '' - let tokens = { - prompt: 0, - completion: 0, - total: 0, - } - const cost = { - input: 0, - output: 0, - total: 0, - } - const toolCalls = [] - const toolResults = [] - let iterationCount = 0 - - const originalToolConfig = preparedTools?.toolConfig - const forcedTools = preparedTools?.forcedTools || [] - let usedForcedTools: string[] = [] - let hasUsedForcedTool = false - let currentToolConfig = originalToolConfig - - const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => { - if (currentToolConfig && forcedTools.length > 0) { - const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }] - const result = trackForcedToolUsage( - toolCallsForTracking, - currentToolConfig, - logger, - 'google', - forcedTools, - usedForcedTools - ) - hasUsedForcedTool = result.hasUsedForcedTool - usedForcedTools = result.usedForcedTools - - if (result.nextToolConfig) { - currentToolConfig = result.nextToolConfig - logger.info('Updated tool config for next iteration', { - hasNextToolConfig: !!currentToolConfig, - usedForcedTools: usedForcedTools, - }) - } - } - } - - let modelTime = firstResponseTime - let toolsTime = 0 - - const timeSegments: TimeSegment[] = [ - { - type: 'model', - name: 'Initial response', - startTime: initialCallTime, - endTime: initialCallTime + firstResponseTime, - duration: firstResponseTime, - }, - ] - - try { - const candidate = geminiResponse.candidates?.[0] - - if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') { - logger.warn( - 'Gemini returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided', - { - finishReason: candidate.finishReason, - hasContent: !!candidate?.content, - hasParts: !!candidate?.content?.parts, - } - ) - content = extractTextContent(candidate) - } - - const functionCall = extractFunctionCall(candidate) - - if (functionCall) { - logger.info(`Received function call from Gemini: ${functionCall.name}`) - - while (iterationCount < MAX_TOOL_ITERATIONS) { - const latestResponse = geminiResponse.candidates?.[0] - const latestFunctionCall = extractFunctionCall(latestResponse) - - if (!latestFunctionCall) { - content = extractTextContent(latestResponse) - break - } - - logger.info( - `Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` - ) - - const toolsStartTime = Date.now() - - try { - const toolName = latestFunctionCall.name - const toolArgs = latestFunctionCall.args || {} - - const tool = request.tools?.find((t) => t.id === toolName) - if (!tool) { - logger.warn(`Tool ${toolName} not found in registry, skipping`) - break - } - - const toolCallStartTime = Date.now() - - const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams, true) - const toolCallEndTime = Date.now() - const toolCallDuration = toolCallEndTime - toolCallStartTime - - timeSegments.push({ - type: 'tool', - name: toolName, - startTime: toolCallStartTime, - endTime: toolCallEndTime, - duration: toolCallDuration, - }) - - let resultContent: any - if (result.success) { - toolResults.push(result.output) - resultContent = result.output - } else { - resultContent = { - error: true, - message: result.error || 'Tool execution failed', - tool: toolName, - } - } - - toolCalls.push({ - name: toolName, - arguments: toolParams, - startTime: new Date(toolCallStartTime).toISOString(), - endTime: new Date(toolCallEndTime).toISOString(), - duration: 
toolCallDuration, - result: resultContent, - success: result.success, - }) - - const simplifiedMessages = [ - ...(contents.filter((m) => m.role === 'user').length > 0 - ? [contents.filter((m) => m.role === 'user')[0]] - : [contents[0]]), - { - role: 'model', - parts: [ - { - functionCall: { - name: latestFunctionCall.name, - args: latestFunctionCall.args, - }, - }, - ], - }, - { - role: 'user', - parts: [ - { - text: `Function ${latestFunctionCall.name} result: ${JSON.stringify(resultContent)}`, - }, - ], - }, - ] - - const thisToolsTime = Date.now() - toolsStartTime - toolsTime += thisToolsTime - - checkForForcedToolUsage(latestFunctionCall) - - const nextModelStartTime = Date.now() - - try { - if (request.stream) { - const streamingPayload = { - ...payload, - contents: simplifiedMessages, - } - - const allForcedToolsUsed = - forcedTools.length > 0 && usedForcedTools.length === forcedTools.length - - if (allForcedToolsUsed && request.responseFormat) { - streamingPayload.tools = undefined - streamingPayload.toolConfig = undefined - - const responseFormatSchema = - request.responseFormat.schema || request.responseFormat - const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - - if (!streamingPayload.generationConfig) { - streamingPayload.generationConfig = {} - } - streamingPayload.generationConfig.responseMimeType = 'application/json' - streamingPayload.generationConfig.responseSchema = cleanSchema - - logger.info('Using structured output for final response after tool execution') - } else { - if (currentToolConfig) { - streamingPayload.toolConfig = currentToolConfig - } else { - streamingPayload.toolConfig = { functionCallingConfig: { mode: 'AUTO' } } - } - } - - const checkPayload = { - ...streamingPayload, - } - checkPayload.stream = undefined - - const checkResponse = await fetch( - `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:generateContent?key=${request.apiKey}`, - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(checkPayload), - } - ) - - if (!checkResponse.ok) { - const errorBody = await checkResponse.text() - logger.error('Error in Gemini check request:', { - status: checkResponse.status, - statusText: checkResponse.statusText, - responseBody: errorBody, - }) - throw new Error( - `Gemini API check error: ${checkResponse.status} ${checkResponse.statusText}` - ) - } - - const checkResult = await checkResponse.json() - const checkCandidate = checkResult.candidates?.[0] - const checkFunctionCall = extractFunctionCall(checkCandidate) - - if (checkFunctionCall) { - logger.info( - 'Function call detected in follow-up, handling in non-streaming mode', - { - functionName: checkFunctionCall.name, - } - ) - - geminiResponse = checkResult - - if (checkResult.usageMetadata) { - tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0 - tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0 - tokens.total += - (checkResult.usageMetadata.promptTokenCount || 0) + - (checkResult.usageMetadata.candidatesTokenCount || 0) - - const iterationCost = calculateCost( - request.model, - checkResult.usageMetadata.promptTokenCount || 0, - checkResult.usageMetadata.candidatesTokenCount || 0 - ) - cost.input += iterationCost.input - cost.output += iterationCost.output - cost.total += iterationCost.total - } - - const nextModelEndTime = Date.now() - const thisModelTime = nextModelEndTime - nextModelStartTime - modelTime += thisModelTime - - timeSegments.push({ - type: 'model', - name: `Model 
response (iteration ${iterationCount + 1})`, - startTime: nextModelStartTime, - endTime: nextModelEndTime, - duration: thisModelTime, - }) - - iterationCount++ - continue - } - logger.info('No function call detected, proceeding with streaming response') - - if (request.responseFormat) { - streamingPayload.tools = undefined - streamingPayload.toolConfig = undefined - - const responseFormatSchema = - request.responseFormat.schema || request.responseFormat - const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - - if (!streamingPayload.generationConfig) { - streamingPayload.generationConfig = {} - } - streamingPayload.generationConfig.responseMimeType = 'application/json' - streamingPayload.generationConfig.responseSchema = cleanSchema - - logger.info( - 'Using structured output for final streaming response after tool execution' - ) - } - - const streamingResponse = await fetch( - `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:streamGenerateContent?key=${request.apiKey}`, - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(streamingPayload), - } - ) - - if (!streamingResponse.ok) { - const errorBody = await streamingResponse.text() - logger.error('Error in Gemini streaming follow-up request:', { - status: streamingResponse.status, - statusText: streamingResponse.statusText, - responseBody: errorBody, - }) - throw new Error( - `Gemini API streaming error: ${streamingResponse.status} ${streamingResponse.statusText}` - ) - } - - const nextModelEndTime = Date.now() - const thisModelTime = nextModelEndTime - nextModelStartTime - modelTime += thisModelTime - - timeSegments.push({ - type: 'model', - name: 'Final streaming response after tool calls', - startTime: nextModelStartTime, - endTime: nextModelEndTime, - duration: thisModelTime, - }) - - const streamingExecution: StreamingExecution = { - stream: null as any, - execution: { - success: true, - output: { - content: '', - model: request.model, - tokens, - toolCalls: - toolCalls.length > 0 - ? { - list: toolCalls, - count: toolCalls.length, - } - : undefined, - toolResults, - providerTiming: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: Date.now() - providerStartTime, - modelTime, - toolsTime, - firstResponseTime, - iterations: iterationCount + 1, - timeSegments, - }, - cost, - }, - logs: [], - metadata: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: Date.now() - providerStartTime, - }, - isStreaming: true, - }, - } - - streamingExecution.stream = createReadableStreamFromGeminiStream( - streamingResponse, - (content, usage) => { - streamingExecution.execution.output.content = content - - const existingTokens = streamingExecution.execution.output.tokens - streamingExecution.execution.output.tokens = { - prompt: (existingTokens?.prompt ?? 0) + usage.promptTokenCount, - completion: (existingTokens?.completion ?? 0) + usage.candidatesTokenCount, - total: - (existingTokens?.total ?? 0) + - (usage.totalTokenCount || - usage.promptTokenCount + usage.candidatesTokenCount), - } - - const streamCost = calculateCost( - request.model, - usage.promptTokenCount, - usage.candidatesTokenCount - ) - const existingCost = streamingExecution.execution.output.cost as any - streamingExecution.execution.output.cost = { - input: (existingCost?.input ?? 0) + streamCost.input, - output: (existingCost?.output ?? 0) + streamCost.output, - total: (existingCost?.total ?? 
0) + streamCost.total, - } - - const streamEndTime = Date.now() - const streamEndTimeISO = new Date(streamEndTime).toISOString() - - if (streamingExecution.execution.output.providerTiming) { - streamingExecution.execution.output.providerTiming.endTime = - streamEndTimeISO - streamingExecution.execution.output.providerTiming.duration = - streamEndTime - providerStartTime - } - } - ) - - return streamingExecution - } - - const nextPayload = { - ...payload, - contents: simplifiedMessages, - } - - const allForcedToolsUsed = - forcedTools.length > 0 && usedForcedTools.length === forcedTools.length - - if (allForcedToolsUsed && request.responseFormat) { - nextPayload.tools = undefined - nextPayload.toolConfig = undefined - - const responseFormatSchema = - request.responseFormat.schema || request.responseFormat - const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - - if (!nextPayload.generationConfig) { - nextPayload.generationConfig = {} - } - nextPayload.generationConfig.responseMimeType = 'application/json' - nextPayload.generationConfig.responseSchema = cleanSchema - - logger.info( - 'Using structured output for final non-streaming response after tool execution' - ) - } else { - if (currentToolConfig) { - nextPayload.toolConfig = currentToolConfig - } - } - - const nextResponse = await fetch( - `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:generateContent?key=${request.apiKey}`, - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(nextPayload), - } - ) - - if (!nextResponse.ok) { - const errorBody = await nextResponse.text() - logger.error('Error in Gemini follow-up request:', { - status: nextResponse.status, - statusText: nextResponse.statusText, - responseBody: errorBody, - iterationCount, - }) - break - } - - geminiResponse = await nextResponse.json() - - const nextModelEndTime = Date.now() - const thisModelTime = nextModelEndTime - nextModelStartTime - - timeSegments.push({ - type: 'model', - name: `Model response (iteration ${iterationCount + 1})`, - startTime: nextModelStartTime, - endTime: nextModelEndTime, - duration: thisModelTime, - }) - - modelTime += thisModelTime - - const nextCandidate = geminiResponse.candidates?.[0] - const nextFunctionCall = extractFunctionCall(nextCandidate) - - if (!nextFunctionCall) { - if (request.responseFormat) { - const finalPayload = { - ...payload, - contents: nextPayload.contents, - tools: undefined, - toolConfig: undefined, - } - - const responseFormatSchema = - request.responseFormat.schema || request.responseFormat - const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - - if (!finalPayload.generationConfig) { - finalPayload.generationConfig = {} - } - finalPayload.generationConfig.responseMimeType = 'application/json' - finalPayload.generationConfig.responseSchema = cleanSchema - - logger.info('Making final request with structured output after tool execution') - - const finalResponse = await fetch( - `https://generativelanguage.googleapis.com/v1beta/models/${requestedModel}:generateContent?key=${request.apiKey}`, - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(finalPayload), - } - ) - - if (finalResponse.ok) { - const finalResult = await finalResponse.json() - const finalCandidate = finalResult.candidates?.[0] - content = extractTextContent(finalCandidate) - - if (finalResult.usageMetadata) { - tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0 - tokens.completion += 
finalResult.usageMetadata.candidatesTokenCount || 0 - tokens.total += - (finalResult.usageMetadata.promptTokenCount || 0) + - (finalResult.usageMetadata.candidatesTokenCount || 0) - - const iterationCost = calculateCost( - request.model, - finalResult.usageMetadata.promptTokenCount || 0, - finalResult.usageMetadata.candidatesTokenCount || 0 - ) - cost.input += iterationCost.input - cost.output += iterationCost.output - cost.total += iterationCost.total - } - } else { - logger.warn( - 'Failed to get structured output, falling back to regular response' - ) - content = extractTextContent(nextCandidate) - } - } else { - content = extractTextContent(nextCandidate) - } - break - } - - iterationCount++ - } catch (error) { - logger.error('Error in Gemini follow-up request:', { - error: error instanceof Error ? error.message : String(error), - iterationCount, - }) - break - } - } catch (error) { - logger.error('Error processing function call:', { - error: error instanceof Error ? error.message : String(error), - functionName: latestFunctionCall?.name || 'unknown', - }) - break - } - } - } else { - content = extractTextContent(candidate) - } - } catch (error) { - logger.error('Error processing Gemini response:', { - error: error instanceof Error ? error.message : String(error), - iterationCount, - }) - - if (!content && toolCalls.length > 0) { - content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.` - } - } - - const providerEndTime = Date.now() - const providerEndTimeISO = new Date(providerEndTime).toISOString() - const totalDuration = providerEndTime - providerStartTime - - if (geminiResponse.usageMetadata) { - tokens = { - prompt: geminiResponse.usageMetadata.promptTokenCount || 0, - completion: geminiResponse.usageMetadata.candidatesTokenCount || 0, - total: - (geminiResponse.usageMetadata.promptTokenCount || 0) + - (geminiResponse.usageMetadata.candidatesTokenCount || 0), - } - } - - return { - content, - model: request.model, - tokens, - toolCalls: toolCalls.length > 0 ? toolCalls : undefined, - toolResults: toolResults.length > 0 ? toolResults : undefined, - timing: { - startTime: providerStartTimeISO, - endTime: providerEndTimeISO, - duration: totalDuration, - modelTime: modelTime, - toolsTime: toolsTime, - firstResponseTime: firstResponseTime, - iterations: iterationCount + 1, - timeSegments: timeSegments, - }, - } - } catch (error) { - const providerEndTime = Date.now() - const providerEndTimeISO = new Date(providerEndTime).toISOString() - const totalDuration = providerEndTime - providerStartTime - - logger.error('Error in Google Gemini request:', { - error: error instanceof Error ? error.message : String(error), - duration: totalDuration, - }) - - const enhancedError = new Error(error instanceof Error ? 
error.message : String(error)) - // @ts-ignore - enhancedError.timing = { - startTime: providerStartTimeISO, - endTime: providerEndTimeISO, - duration: totalDuration, - } - - throw enhancedError - } }, } diff --git a/apps/sim/providers/google/utils.ts b/apps/sim/providers/google/utils.ts index 1d5675f37..c663137b5 100644 --- a/apps/sim/providers/google/utils.ts +++ b/apps/sim/providers/google/utils.ts @@ -1,61 +1,89 @@ -import type { Candidate } from '@google/genai' +import { + type Candidate, + type Content, + type FunctionCall, + FunctionCallingConfigMode, + type GenerateContentResponse, + type GenerateContentResponseUsageMetadata, + type Part, + type Schema, + type SchemaUnion, + ThinkingLevel, + type ToolConfig, + Type, +} from '@google/genai' +import { createLogger } from '@/lib/logs/console/logger' import type { ProviderRequest } from '@/providers/types' +import { trackForcedToolUsage } from '@/providers/utils' + +const logger = createLogger('GoogleUtils') + +/** + * Usage metadata for Google Gemini responses + */ +export interface GeminiUsage { + promptTokenCount: number + candidatesTokenCount: number + totalTokenCount: number +} + +/** + * Parsed function call from Gemini response + */ +export interface ParsedFunctionCall { + name: string + args: Record +} /** * Removes additionalProperties from a schema object (not supported by Gemini) */ -export function cleanSchemaForGemini(schema: any): any { +export function cleanSchemaForGemini(schema: SchemaUnion): SchemaUnion { if (schema === null || schema === undefined) return schema if (typeof schema !== 'object') return schema if (Array.isArray(schema)) { return schema.map((item) => cleanSchemaForGemini(item)) } - const cleanedSchema: any = {} + const cleanedSchema: Record = {} + const schemaObj = schema as Record - for (const key in schema) { + for (const key in schemaObj) { if (key === 'additionalProperties') continue - cleanedSchema[key] = cleanSchemaForGemini(schema[key]) + cleanedSchema[key] = cleanSchemaForGemini(schemaObj[key] as SchemaUnion) } return cleanedSchema } /** - * Extracts text content from a Gemini response candidate, handling structured output + * Extracts text content from a Gemini response candidate. + * Filters out thought parts (model reasoning) from the output. 
*/ export function extractTextContent(candidate: Candidate | undefined): string { if (!candidate?.content?.parts) return '' - if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) { - const text = candidate.content.parts[0].text - if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) { - try { - JSON.parse(text) - return text - } catch (_e) {} - } - } + const textParts = candidate.content.parts.filter( + (part): part is Part & { text: string } => Boolean(part.text) && part.thought !== true + ) - return candidate.content.parts - .filter((part: any) => part.text) - .map((part: any) => part.text) - .join('\n') + if (textParts.length === 0) return '' + if (textParts.length === 1) return textParts[0].text + + return textParts.map((part) => part.text).join('\n') } /** - * Extracts a function call from a Gemini response candidate + * Extracts the first function call from a Gemini response candidate */ -export function extractFunctionCall( - candidate: Candidate | undefined -): { name: string; args: any } | null { +export function extractFunctionCall(candidate: Candidate | undefined): ParsedFunctionCall | null { if (!candidate?.content?.parts) return null for (const part of candidate.content.parts) { if (part.functionCall) { return { name: part.functionCall.name ?? '', - args: part.functionCall.args ?? {}, + args: (part.functionCall.args ?? {}) as Record, } } } @@ -63,16 +91,55 @@ export function extractFunctionCall( return null } +/** + * Extracts the full Part containing the function call (preserves thoughtSignature) + */ +export function extractFunctionCallPart(candidate: Candidate | undefined): Part | null { + if (!candidate?.content?.parts) return null + + for (const part of candidate.content.parts) { + if (part.functionCall) { + return part + } + } + + return null +} + +/** + * Converts usage metadata from SDK response to our format + */ +export function convertUsageMetadata( + usageMetadata: GenerateContentResponseUsageMetadata | undefined +): GeminiUsage { + const promptTokenCount = usageMetadata?.promptTokenCount ?? 0 + const candidatesTokenCount = usageMetadata?.candidatesTokenCount ?? 0 + return { + promptTokenCount, + candidatesTokenCount, + totalTokenCount: usageMetadata?.totalTokenCount ?? 
promptTokenCount + candidatesTokenCount, + } +} + +/** + * Tool definition for Gemini format + */ +export interface GeminiToolDef { + name: string + description: string + parameters: Schema +} + /** * Converts OpenAI-style request format to Gemini format */ export function convertToGeminiFormat(request: ProviderRequest): { - contents: any[] - tools: any[] | undefined - systemInstruction: any | undefined + contents: Content[] + tools: GeminiToolDef[] | undefined + systemInstruction: Content | undefined } { - const contents: any[] = [] - let systemInstruction + const contents: Content[] = [] + let systemInstruction: Content | undefined if (request.systemPrompt) { systemInstruction = { parts: [{ text: request.systemPrompt }] } @@ -82,13 +149,13 @@ export function convertToGeminiFormat(request: ProviderRequest): { contents.push({ role: 'user', parts: [{ text: request.context }] }) } - if (request.messages && request.messages.length > 0) { + if (request.messages?.length) { for (const message of request.messages) { if (message.role === 'system') { if (!systemInstruction) { - systemInstruction = { parts: [{ text: message.content }] } - } else { - systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}` + systemInstruction = { parts: [{ text: message.content ?? '' }] } + } else if (systemInstruction.parts?.[0] && 'text' in systemInstruction.parts[0]) { + systemInstruction.parts[0].text = `${systemInstruction.parts[0].text}\n${message.content}` } } else if (message.role === 'user' || message.role === 'assistant') { const geminiRole = message.role === 'user' ? 'user' : 'model' @@ -97,60 +164,200 @@ export function convertToGeminiFormat(request: ProviderRequest): { contents.push({ role: geminiRole, parts: [{ text: message.content }] }) } - if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) { + if (message.role === 'assistant' && message.tool_calls?.length) { const functionCalls = message.tool_calls.map((toolCall) => ({ functionCall: { name: toolCall.function?.name, - args: JSON.parse(toolCall.function?.arguments || '{}'), + args: JSON.parse(toolCall.function?.arguments || '{}') as Record, }, })) - contents.push({ role: 'model', parts: functionCalls }) } } else if (message.role === 'tool') { + if (!message.name) { + logger.warn('Tool message missing function name, skipping') + continue + } + let responseData: Record + try { + responseData = JSON.parse(message.content ?? '{}') + } catch { + responseData = { output: message.content } + } contents.push({ role: 'user', - parts: [{ text: `Function result: ${message.content}` }], + parts: [ + { + functionResponse: { + id: message.tool_call_id, + name: message.name, + response: responseData, + }, + }, + ], }) } } } - const tools = request.tools?.map((tool) => { + const tools = request.tools?.map((tool): GeminiToolDef => { const toolParameters = { ...(tool.parameters || {}) } if (toolParameters.properties) { const properties = { ...toolParameters.properties } const required = toolParameters.required ? 
[...toolParameters.required] : [] + // Remove default values from properties (not supported by Gemini) for (const key in properties) { - const prop = properties[key] as any - + const prop = properties[key] as Record if (prop.default !== undefined) { const { default: _, ...cleanProp } = prop properties[key] = cleanProp } } - const parameters = { - type: toolParameters.type || 'object', - properties, + const parameters: Schema = { + type: (toolParameters.type as Schema['type']) || Type.OBJECT, + properties: properties as Record, ...(required.length > 0 ? { required } : {}), } return { name: tool.id, description: tool.description || `Execute the ${tool.id} function`, - parameters: cleanSchemaForGemini(parameters), + parameters: cleanSchemaForGemini(parameters) as Schema, } } return { name: tool.id, description: tool.description || `Execute the ${tool.id} function`, - parameters: cleanSchemaForGemini(toolParameters), + parameters: cleanSchemaForGemini(toolParameters) as Schema, } }) return { contents, tools, systemInstruction } } + +/** + * Creates a ReadableStream from a Google Gemini streaming response + */ +export function createReadableStreamFromGeminiStream( + stream: AsyncGenerator, + onComplete?: (content: string, usage: GeminiUsage) => void +): ReadableStream { + let fullContent = '' + let usage: GeminiUsage = { promptTokenCount: 0, candidatesTokenCount: 0, totalTokenCount: 0 } + + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of stream) { + if (chunk.usageMetadata) { + usage = convertUsageMetadata(chunk.usageMetadata) + } + + const text = chunk.text + if (text) { + fullContent += text + controller.enqueue(new TextEncoder().encode(text)) + } + } + + onComplete?.(fullContent, usage) + controller.close() + } catch (error) { + logger.error('Error reading Google Gemini stream', { + error: error instanceof Error ? error.message : String(error), + }) + controller.error(error) + } + }, + }) +} + +/** + * Maps string mode to FunctionCallingConfigMode enum + */ +function mapToFunctionCallingMode(mode: string): FunctionCallingConfigMode { + switch (mode) { + case 'AUTO': + return FunctionCallingConfigMode.AUTO + case 'ANY': + return FunctionCallingConfigMode.ANY + case 'NONE': + return FunctionCallingConfigMode.NONE + default: + return FunctionCallingConfigMode.AUTO + } +} + +/** + * Maps string level to ThinkingLevel enum + */ +export function mapToThinkingLevel(level: string): ThinkingLevel { + switch (level.toLowerCase()) { + case 'minimal': + return ThinkingLevel.MINIMAL + case 'low': + return ThinkingLevel.LOW + case 'medium': + return ThinkingLevel.MEDIUM + case 'high': + return ThinkingLevel.HIGH + default: + return ThinkingLevel.HIGH + } +} + +/** + * Result of checking forced tool usage + */ +export interface ForcedToolResult { + hasUsedForcedTool: boolean + usedForcedTools: string[] + nextToolConfig: ToolConfig | undefined +} + +/** + * Checks for forced tool usage in Google Gemini responses + */ +export function checkForForcedToolUsage( + functionCalls: FunctionCall[] | undefined, + toolConfig: ToolConfig | undefined, + forcedTools: string[], + usedForcedTools: string[] +): ForcedToolResult | null { + if (!functionCalls?.length) return null + + const adaptedToolCalls = functionCalls.map((fc) => ({ + name: fc.name ?? '', + arguments: (fc.args ?? 
{}) as Record, + })) + + const result = trackForcedToolUsage( + adaptedToolCalls, + toolConfig, + logger, + 'google', + forcedTools, + usedForcedTools + ) + + if (!result) return null + + const nextToolConfig: ToolConfig | undefined = result.nextToolConfig?.functionCallingConfig?.mode + ? { + functionCallingConfig: { + mode: mapToFunctionCallingMode(result.nextToolConfig.functionCallingConfig.mode), + allowedFunctionNames: result.nextToolConfig.functionCallingConfig.allowedFunctionNames, + }, + } + : undefined + + return { + hasUsedForcedTool: result.hasUsedForcedTool, + usedForcedTools: result.usedForcedTools, + nextToolConfig, + } +} diff --git a/apps/sim/providers/groq/index.ts b/apps/sim/providers/groq/index.ts index cfdcd4868..58ff64197 100644 --- a/apps/sim/providers/groq/index.ts +++ b/apps/sim/providers/groq/index.ts @@ -69,10 +69,7 @@ export const groqProvider: ProviderConfig = { : undefined const payload: any = { - model: (request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct').replace( - 'groq/', - '' - ), + model: request.model.replace('groq/', ''), messages: allMessages, } @@ -109,7 +106,7 @@ export const groqProvider: ProviderConfig = { toolChoice: payload.tool_choice, forcedToolsCount: forcedTools.length, hasFilteredTools, - model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct', + model: request.model, }) } } @@ -149,7 +146,7 @@ export const groqProvider: ProviderConfig = { success: true, output: { content: '', - model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct', + model: request.model, tokens: { prompt: 0, completion: 0, total: 0 }, toolCalls: undefined, providerTiming: { @@ -393,7 +390,7 @@ export const groqProvider: ProviderConfig = { const streamingPayload = { ...payload, messages: currentMessages, - tool_choice: 'auto', + tool_choice: originalToolChoice || 'auto', stream: true, } @@ -425,7 +422,7 @@ export const groqProvider: ProviderConfig = { success: true, output: { content: '', - model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct', + model: request.model, tokens: { prompt: tokens.prompt, completion: tokens.completion, diff --git a/apps/sim/providers/index.ts b/apps/sim/providers/index.ts index 3dbed8f42..6c4fa15c9 100644 --- a/apps/sim/providers/index.ts +++ b/apps/sim/providers/index.ts @@ -1,11 +1,11 @@ import { getCostMultiplier } from '@/lib/core/config/feature-flags' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' -import type { ProviderRequest, ProviderResponse } from '@/providers/types' +import { getProviderExecutor } from '@/providers/registry' +import type { ProviderId, ProviderRequest, ProviderResponse } from '@/providers/types' import { calculateCost, generateStructuredOutputInstructions, - getProvider, shouldBillModelUsage, supportsTemperature, } from '@/providers/utils' @@ -40,7 +40,7 @@ export async function executeProviderRequest( providerId: string, request: ProviderRequest ): Promise { - const provider = getProvider(providerId) + const provider = await getProviderExecutor(providerId as ProviderId) if (!provider) { throw new Error(`Provider not found: ${providerId}`) } diff --git a/apps/sim/providers/mistral/index.ts b/apps/sim/providers/mistral/index.ts index ebeb6ac48..9cfda86f1 100644 --- a/apps/sim/providers/mistral/index.ts +++ b/apps/sim/providers/mistral/index.ts @@ -36,7 +36,7 @@ export const mistralProvider: ProviderConfig = { request: ProviderRequest ): Promise => { logger.info('Preparing Mistral 
request', { - model: request.model || 'mistral-large-latest', + model: request.model, hasSystemPrompt: !!request.systemPrompt, hasMessages: !!request.messages?.length, hasTools: !!request.tools?.length, @@ -86,7 +86,7 @@ export const mistralProvider: ProviderConfig = { : undefined const payload: any = { - model: request.model || 'mistral-large-latest', + model: request.model, messages: allMessages, } @@ -126,7 +126,7 @@ export const mistralProvider: ProviderConfig = { : toolChoice.type === 'any' ? `force:${toolChoice.any?.name || 'unknown'}` : 'unknown', - model: request.model || 'mistral-large-latest', + model: request.model, }) } } diff --git a/apps/sim/providers/models.ts b/apps/sim/providers/models.ts index 096c29dbc..668e98e23 100644 --- a/apps/sim/providers/models.ts +++ b/apps/sim/providers/models.ts @@ -39,6 +39,10 @@ export interface ModelCapabilities { verbosity?: { values: string[] } + thinking?: { + levels: string[] + default?: string + } } export interface ModelDefinition { @@ -730,6 +734,10 @@ export const PROVIDER_DEFINITIONS: Record = { }, capabilities: { temperature: { min: 0, max: 2 }, + thinking: { + levels: ['low', 'high'], + default: 'high', + }, }, contextWindow: 1000000, }, @@ -743,6 +751,10 @@ export const PROVIDER_DEFINITIONS: Record = { }, capabilities: { temperature: { min: 0, max: 2 }, + thinking: { + levels: ['minimal', 'low', 'medium', 'high'], + default: 'high', + }, }, contextWindow: 1000000, }, @@ -832,6 +844,10 @@ export const PROVIDER_DEFINITIONS: Record = { }, capabilities: { temperature: { min: 0, max: 2 }, + thinking: { + levels: ['low', 'high'], + default: 'high', + }, }, contextWindow: 1000000, }, @@ -845,6 +861,10 @@ export const PROVIDER_DEFINITIONS: Record = { }, capabilities: { temperature: { min: 0, max: 2 }, + thinking: { + levels: ['minimal', 'low', 'medium', 'high'], + default: 'high', + }, }, contextWindow: 1000000, }, @@ -1864,3 +1884,49 @@ export function supportsNativeStructuredOutputs(modelId: string): boolean { } return false } + +/** + * Check if a model supports thinking/reasoning features. + * Returns the thinking capability config if supported, null otherwise. + */ +export function getThinkingCapability( + modelId: string +): { levels: string[]; default?: string } | null { + const normalizedModelId = modelId.toLowerCase() + + for (const provider of Object.values(PROVIDER_DEFINITIONS)) { + for (const model of provider.models) { + if (model.capabilities.thinking) { + const baseModelId = model.id.toLowerCase() + if (normalizedModelId === baseModelId || normalizedModelId.startsWith(`${baseModelId}-`)) { + return model.capabilities.thinking + } + } + } + } + return null +} + +/** + * Get all models that support thinking capability + */ +export function getModelsWithThinking(): string[] { + const models: string[] = [] + for (const provider of Object.values(PROVIDER_DEFINITIONS)) { + for (const model of provider.models) { + if (model.capabilities.thinking) { + models.push(model.id) + } + } + } + return models +} + +/** + * Get the thinking levels for a specific model + * Returns the valid levels for that model, or null if the model doesn't support thinking + */ +export function getThinkingLevelsForModel(modelId: string): string[] | null { + const capability = getThinkingCapability(modelId) + return capability?.levels ?? 
null +} diff --git a/apps/sim/providers/openai/index.ts b/apps/sim/providers/openai/index.ts index ddb207d54..700dc6aa6 100644 --- a/apps/sim/providers/openai/index.ts +++ b/apps/sim/providers/openai/index.ts @@ -33,7 +33,7 @@ export const openaiProvider: ProviderConfig = { request: ProviderRequest ): Promise => { logger.info('Preparing OpenAI request', { - model: request.model || 'gpt-4o', + model: request.model, hasSystemPrompt: !!request.systemPrompt, hasMessages: !!request.messages?.length, hasTools: !!request.tools?.length, @@ -76,7 +76,7 @@ export const openaiProvider: ProviderConfig = { : undefined const payload: any = { - model: request.model || 'gpt-4o', + model: request.model, messages: allMessages, } @@ -121,7 +121,7 @@ export const openaiProvider: ProviderConfig = { : toolChoice.type === 'any' ? `force:${toolChoice.any?.name || 'unknown'}` : 'unknown', - model: request.model || 'gpt-4o', + model: request.model, }) } } diff --git a/apps/sim/providers/openrouter/index.ts b/apps/sim/providers/openrouter/index.ts index 928b82d24..abaf25a96 100644 --- a/apps/sim/providers/openrouter/index.ts +++ b/apps/sim/providers/openrouter/index.ts @@ -78,7 +78,7 @@ export const openRouterProvider: ProviderConfig = { baseURL: 'https://openrouter.ai/api/v1', }) - const requestedModel = (request.model || '').replace(/^openrouter\//, '') + const requestedModel = request.model.replace(/^openrouter\//, '') logger.info('Preparing OpenRouter request', { model: requestedModel, diff --git a/apps/sim/providers/registry.ts b/apps/sim/providers/registry.ts new file mode 100644 index 000000000..4ea790667 --- /dev/null +++ b/apps/sim/providers/registry.ts @@ -0,0 +1,59 @@ +import { createLogger } from '@/lib/logs/console/logger' +import { anthropicProvider } from '@/providers/anthropic' +import { azureOpenAIProvider } from '@/providers/azure-openai' +import { cerebrasProvider } from '@/providers/cerebras' +import { deepseekProvider } from '@/providers/deepseek' +import { googleProvider } from '@/providers/google' +import { groqProvider } from '@/providers/groq' +import { mistralProvider } from '@/providers/mistral' +import { ollamaProvider } from '@/providers/ollama' +import { openaiProvider } from '@/providers/openai' +import { openRouterProvider } from '@/providers/openrouter' +import type { ProviderConfig, ProviderId } from '@/providers/types' +import { vertexProvider } from '@/providers/vertex' +import { vllmProvider } from '@/providers/vllm' +import { xAIProvider } from '@/providers/xai' + +const logger = createLogger('ProviderRegistry') + +const providerRegistry: Record = { + openai: openaiProvider, + anthropic: anthropicProvider, + google: googleProvider, + vertex: vertexProvider, + deepseek: deepseekProvider, + xai: xAIProvider, + cerebras: cerebrasProvider, + groq: groqProvider, + vllm: vllmProvider, + mistral: mistralProvider, + 'azure-openai': azureOpenAIProvider, + openrouter: openRouterProvider, + ollama: ollamaProvider, +} + +export async function getProviderExecutor( + providerId: ProviderId +): Promise { + const provider = providerRegistry[providerId] + if (!provider) { + logger.error(`Provider not found: ${providerId}`) + return undefined + } + return provider +} + +export async function initializeProviders(): Promise { + for (const [id, provider] of Object.entries(providerRegistry)) { + if (provider.initialize) { + try { + await provider.initialize() + logger.info(`Initialized provider: ${id}`) + } catch (error) { + logger.error(`Failed to initialize ${id} provider`, { + error: error 
instanceof Error ? error.message : 'Unknown error', + }) + } + } + } +} diff --git a/apps/sim/providers/types.ts b/apps/sim/providers/types.ts index b49a04d61..9d83ec458 100644 --- a/apps/sim/providers/types.ts +++ b/apps/sim/providers/types.ts @@ -164,6 +164,7 @@ export interface ProviderRequest { vertexLocation?: string reasoningEffort?: string verbosity?: string + thinkingLevel?: string } export const providers: Record = {} diff --git a/apps/sim/providers/utils.ts b/apps/sim/providers/utils.ts index 31acd3cae..9344d3971 100644 --- a/apps/sim/providers/utils.ts +++ b/apps/sim/providers/utils.ts @@ -3,13 +3,6 @@ import type { CompletionUsage } from 'openai/resources/completions' import { getEnv, isTruthy } from '@/lib/core/config/env' import { isHosted } from '@/lib/core/config/feature-flags' import { createLogger, type Logger } from '@/lib/logs/console/logger' -import { anthropicProvider } from '@/providers/anthropic' -import { azureOpenAIProvider } from '@/providers/azure-openai' -import { cerebrasProvider } from '@/providers/cerebras' -import { deepseekProvider } from '@/providers/deepseek' -import { googleProvider } from '@/providers/google' -import { groqProvider } from '@/providers/groq' -import { mistralProvider } from '@/providers/mistral' import { getComputerUseModels, getEmbeddingModelPricing, @@ -20,117 +13,82 @@ import { getModelsWithTemperatureSupport, getModelsWithTempRange01, getModelsWithTempRange02, + getModelsWithThinking, getModelsWithVerbosity, + getProviderDefaultModel as getProviderDefaultModelFromDefinitions, getProviderModels as getProviderModelsFromDefinitions, getProvidersWithToolUsageControl, getReasoningEffortValuesForModel as getReasoningEffortValuesForModelFromDefinitions, + getThinkingLevelsForModel as getThinkingLevelsForModelFromDefinitions, getVerbosityValuesForModel as getVerbosityValuesForModelFromDefinitions, PROVIDER_DEFINITIONS, supportsTemperature as supportsTemperatureFromDefinitions, supportsToolUsageControl as supportsToolUsageControlFromDefinitions, updateOllamaModels as updateOllamaModelsInDefinitions, } from '@/providers/models' -import { ollamaProvider } from '@/providers/ollama' -import { openaiProvider } from '@/providers/openai' -import { openRouterProvider } from '@/providers/openrouter' -import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types' -import { vertexProvider } from '@/providers/vertex' -import { vllmProvider } from '@/providers/vllm' -import { xAIProvider } from '@/providers/xai' +import type { ProviderId, ProviderToolConfig } from '@/providers/types' import { useCustomToolsStore } from '@/stores/custom-tools/store' import { useProvidersStore } from '@/stores/providers/store' const logger = createLogger('ProviderUtils') -export const providers: Record< - ProviderId, - ProviderConfig & { - models: string[] - computerUseModels?: string[] - modelPatterns?: RegExp[] +/** + * Client-safe provider metadata. + * This object contains only model lists and patterns - no executeRequest implementations. + * For server-side execution, use @/providers/registry. + */ +export interface ProviderMetadata { + id: string + name: string + description: string + version: string + models: string[] + defaultModel: string + computerUseModels?: string[] + modelPatterns?: RegExp[] +} + +/** + * Build provider metadata from PROVIDER_DEFINITIONS. + * This is client-safe as it doesn't import any provider implementations. 
+ */ +function buildProviderMetadata(providerId: ProviderId): ProviderMetadata { + const def = PROVIDER_DEFINITIONS[providerId] + return { + id: providerId, + name: def?.name || providerId, + description: def?.description || '', + version: '1.0.0', + models: getProviderModelsFromDefinitions(providerId), + defaultModel: getProviderDefaultModelFromDefinitions(providerId), + modelPatterns: def?.modelPatterns, } -> = { +} + +export const providers: Record = { openai: { - ...openaiProvider, - models: getProviderModelsFromDefinitions('openai'), + ...buildProviderMetadata('openai'), computerUseModels: ['computer-use-preview'], - modelPatterns: PROVIDER_DEFINITIONS.openai.modelPatterns, }, anthropic: { - ...anthropicProvider, - models: getProviderModelsFromDefinitions('anthropic'), + ...buildProviderMetadata('anthropic'), computerUseModels: getComputerUseModels().filter((model) => getProviderModelsFromDefinitions('anthropic').includes(model) ), - modelPatterns: PROVIDER_DEFINITIONS.anthropic.modelPatterns, - }, - google: { - ...googleProvider, - models: getProviderModelsFromDefinitions('google'), - modelPatterns: PROVIDER_DEFINITIONS.google.modelPatterns, - }, - vertex: { - ...vertexProvider, - models: getProviderModelsFromDefinitions('vertex'), - modelPatterns: PROVIDER_DEFINITIONS.vertex.modelPatterns, - }, - deepseek: { - ...deepseekProvider, - models: getProviderModelsFromDefinitions('deepseek'), - modelPatterns: PROVIDER_DEFINITIONS.deepseek.modelPatterns, - }, - xai: { - ...xAIProvider, - models: getProviderModelsFromDefinitions('xai'), - modelPatterns: PROVIDER_DEFINITIONS.xai.modelPatterns, - }, - cerebras: { - ...cerebrasProvider, - models: getProviderModelsFromDefinitions('cerebras'), - modelPatterns: PROVIDER_DEFINITIONS.cerebras.modelPatterns, - }, - groq: { - ...groqProvider, - models: getProviderModelsFromDefinitions('groq'), - modelPatterns: PROVIDER_DEFINITIONS.groq.modelPatterns, - }, - vllm: { - ...vllmProvider, - models: getProviderModelsFromDefinitions('vllm'), - modelPatterns: PROVIDER_DEFINITIONS.vllm.modelPatterns, - }, - mistral: { - ...mistralProvider, - models: getProviderModelsFromDefinitions('mistral'), - modelPatterns: PROVIDER_DEFINITIONS.mistral.modelPatterns, - }, - 'azure-openai': { - ...azureOpenAIProvider, - models: getProviderModelsFromDefinitions('azure-openai'), - modelPatterns: PROVIDER_DEFINITIONS['azure-openai'].modelPatterns, - }, - openrouter: { - ...openRouterProvider, - models: getProviderModelsFromDefinitions('openrouter'), - modelPatterns: PROVIDER_DEFINITIONS.openrouter.modelPatterns, - }, - ollama: { - ...ollamaProvider, - models: getProviderModelsFromDefinitions('ollama'), - modelPatterns: PROVIDER_DEFINITIONS.ollama.modelPatterns, }, + google: buildProviderMetadata('google'), + vertex: buildProviderMetadata('vertex'), + deepseek: buildProviderMetadata('deepseek'), + xai: buildProviderMetadata('xai'), + cerebras: buildProviderMetadata('cerebras'), + groq: buildProviderMetadata('groq'), + vllm: buildProviderMetadata('vllm'), + mistral: buildProviderMetadata('mistral'), + 'azure-openai': buildProviderMetadata('azure-openai'), + openrouter: buildProviderMetadata('openrouter'), + ollama: buildProviderMetadata('ollama'), } -Object.entries(providers).forEach(([id, provider]) => { - if (provider.initialize) { - provider.initialize().catch((error) => { - logger.error(`Failed to initialize ${id} provider`, { - error: error instanceof Error ? 
error.message : 'Unknown error', - }) - }) - } -}) - export function updateOllamaProviderModels(models: string[]): void { updateOllamaModelsInDefinitions(models) providers.ollama.models = getProviderModelsFromDefinitions('ollama') @@ -211,12 +169,12 @@ export function getProviderFromModel(model: string): ProviderId { return 'ollama' } -export function getProvider(id: string): ProviderConfig | undefined { +export function getProvider(id: string): ProviderMetadata | undefined { const providerId = id.split('/')[0] as ProviderId return providers[providerId] } -export function getProviderConfigFromModel(model: string): ProviderConfig | undefined { +export function getProviderConfigFromModel(model: string): ProviderMetadata | undefined { const providerId = getProviderFromModel(model) return providers[providerId] } @@ -929,6 +887,7 @@ export const MODELS_TEMP_RANGE_0_1 = getModelsWithTempRange01() export const MODELS_WITH_TEMPERATURE_SUPPORT = getModelsWithTemperatureSupport() export const MODELS_WITH_REASONING_EFFORT = getModelsWithReasoningEffort() export const MODELS_WITH_VERBOSITY = getModelsWithVerbosity() +export const MODELS_WITH_THINKING = getModelsWithThinking() export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = getProvidersWithToolUsageControl() export function supportsTemperature(model: string): boolean { @@ -963,6 +922,14 @@ export function getVerbosityValuesForModel(model: string): string[] | null { return getVerbosityValuesForModelFromDefinitions(model) } +/** + * Get thinking levels for a specific model + * Returns the valid levels for that model, or null if the model doesn't support thinking + */ +export function getThinkingLevelsForModel(model: string): string[] | null { + return getThinkingLevelsForModelFromDefinitions(model) +} + /** * Prepare tool execution parameters, separating tool parameters from system parameters */ diff --git a/apps/sim/providers/vertex/index.ts b/apps/sim/providers/vertex/index.ts index f33c9d1fc..e926c43d6 100644 --- a/apps/sim/providers/vertex/index.ts +++ b/apps/sim/providers/vertex/index.ts @@ -1,33 +1,23 @@ +import { GoogleGenAI } from '@google/genai' +import { OAuth2Client } from 'google-auth-library' import { env } from '@/lib/core/config/env' import { createLogger } from '@/lib/logs/console/logger' import type { StreamingExecution } from '@/executor/types' -import { MAX_TOOL_ITERATIONS } from '@/providers' -import { - cleanSchemaForGemini, - convertToGeminiFormat, - extractFunctionCall, - extractTextContent, -} from '@/providers/google/utils' +import { executeGeminiRequest } from '@/providers/gemini/core' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' -import type { - ProviderConfig, - ProviderRequest, - ProviderResponse, - TimeSegment, -} from '@/providers/types' -import { - calculateCost, - prepareToolExecution, - prepareToolsWithUsageControl, - trackForcedToolUsage, -} from '@/providers/utils' -import { buildVertexEndpoint, createReadableStreamFromVertexStream } from '@/providers/vertex/utils' -import { executeTool } from '@/tools' +import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types' const logger = createLogger('VertexProvider') /** - * Vertex AI provider configuration + * Vertex AI provider + * + * Uses the @google/genai SDK with Vertex AI backend and OAuth authentication. + * Shares core execution logic with Google Gemini provider. 
+ * + * Authentication: + * - Uses OAuth access token passed via googleAuthOptions.authClient + * - Token refresh is handled at the OAuth layer before calling this provider */ export const vertexProvider: ProviderConfig = { id: 'vertex', @@ -55,869 +45,35 @@ export const vertexProvider: ProviderConfig = { ) } - logger.info('Preparing Vertex AI request', { - model: request.model || 'vertex/gemini-2.5-pro', - hasSystemPrompt: !!request.systemPrompt, - hasMessages: !!request.messages?.length, - hasTools: !!request.tools?.length, - toolCount: request.tools?.length || 0, - hasResponseFormat: !!request.responseFormat, - streaming: !!request.stream, + // Strip 'vertex/' prefix from model name if present + const model = request.model.replace('vertex/', '') + + logger.info('Creating Vertex AI client', { project: vertexProject, location: vertexLocation, + model, }) - const providerStartTime = Date.now() - const providerStartTimeISO = new Date(providerStartTime).toISOString() - - try { - const { contents, tools, systemInstruction } = convertToGeminiFormat(request) - - const requestedModel = (request.model || 'vertex/gemini-2.5-pro').replace('vertex/', '') - - const payload: any = { - contents, - generationConfig: {}, - } - - if (request.temperature !== undefined && request.temperature !== null) { - payload.generationConfig.temperature = request.temperature - } - - if (request.maxTokens !== undefined) { - payload.generationConfig.maxOutputTokens = request.maxTokens - } - - if (systemInstruction) { - payload.systemInstruction = systemInstruction - } - - if (request.responseFormat && !tools?.length) { - const responseFormatSchema = request.responseFormat.schema || request.responseFormat - const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - - payload.generationConfig.responseMimeType = 'application/json' - payload.generationConfig.responseSchema = cleanSchema - - logger.info('Using Vertex AI native structured output format', { - hasSchema: !!cleanSchema, - mimeType: 'application/json', - }) - } else if (request.responseFormat && tools?.length) { - logger.warn( - 'Vertex AI does not support structured output (responseFormat) with function calling (tools). Structured output will be ignored.' 
- ) - } - - let preparedTools: ReturnType | null = null - - if (tools?.length) { - preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'google') - const { tools: filteredTools, toolConfig } = preparedTools - - if (filteredTools?.length) { - payload.tools = [ - { - functionDeclarations: filteredTools, - }, - ] - - if (toolConfig) { - payload.toolConfig = toolConfig - } - - logger.info('Vertex AI request with tools:', { - toolCount: filteredTools.length, - model: requestedModel, - tools: filteredTools.map((t) => t.name), - hasToolConfig: !!toolConfig, - toolConfig: toolConfig, - }) - } - } - - const initialCallTime = Date.now() - const shouldStream = !!(request.stream && !tools?.length) - - const endpoint = buildVertexEndpoint( - vertexProject, - vertexLocation, - requestedModel, - shouldStream - ) - - if (request.stream && tools?.length) { - logger.info('Streaming disabled for initial request due to tools presence', { - toolCount: tools.length, - willStreamAfterTools: true, - }) - } - - const response = await fetch(endpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${request.apiKey}`, - }, - body: JSON.stringify(payload), - }) - - if (!response.ok) { - const responseText = await response.text() - logger.error('Vertex AI API error details:', { - status: response.status, - statusText: response.statusText, - responseBody: responseText, - }) - throw new Error(`Vertex AI API error: ${response.status} ${response.statusText}`) - } - - const firstResponseTime = Date.now() - initialCallTime - - if (shouldStream) { - logger.info('Handling Vertex AI streaming response') - - const streamingResult: StreamingExecution = { - stream: null as any, - execution: { - success: true, - output: { - content: '', - model: request.model, - tokens: { - prompt: 0, - completion: 0, - total: 0, - }, - providerTiming: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: firstResponseTime, - modelTime: firstResponseTime, - toolsTime: 0, - firstResponseTime, - iterations: 1, - timeSegments: [ - { - type: 'model', - name: 'Initial streaming response', - startTime: initialCallTime, - endTime: initialCallTime + firstResponseTime, - duration: firstResponseTime, - }, - ], - }, - }, - logs: [], - metadata: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: firstResponseTime, - }, - isStreaming: true, - }, - } - - streamingResult.stream = createReadableStreamFromVertexStream( - response, - (content, usage) => { - streamingResult.execution.output.content = content - - const streamEndTime = Date.now() - const streamEndTimeISO = new Date(streamEndTime).toISOString() - - if (streamingResult.execution.output.providerTiming) { - streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO - streamingResult.execution.output.providerTiming.duration = - streamEndTime - providerStartTime - - if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) { - streamingResult.execution.output.providerTiming.timeSegments[0].endTime = - streamEndTime - streamingResult.execution.output.providerTiming.timeSegments[0].duration = - streamEndTime - providerStartTime - } - } - - const promptTokens = usage?.promptTokenCount || 0 - const completionTokens = usage?.candidatesTokenCount || 0 - const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens - - streamingResult.execution.output.tokens = { - prompt: promptTokens, - completion: completionTokens, - total: 
totalTokens, - } - - const costResult = calculateCost(request.model, promptTokens, completionTokens) - streamingResult.execution.output.cost = { - input: costResult.input, - output: costResult.output, - total: costResult.total, - } - } - ) - - return streamingResult - } - - let geminiResponse = await response.json() - - if (payload.generationConfig?.responseSchema) { - const candidate = geminiResponse.candidates?.[0] - if (candidate?.content?.parts?.[0]?.text) { - const text = candidate.content.parts[0].text - try { - JSON.parse(text) - logger.info('Successfully received structured JSON output') - } catch (_e) { - logger.warn('Failed to parse structured output as JSON') - } - } - } - - let content = '' - let tokens = { - prompt: 0, - completion: 0, - total: 0, - } - const toolCalls = [] - const toolResults = [] - let iterationCount = 0 - - const originalToolConfig = preparedTools?.toolConfig - const forcedTools = preparedTools?.forcedTools || [] - let usedForcedTools: string[] = [] - let hasUsedForcedTool = false - let currentToolConfig = originalToolConfig - - const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => { - if (currentToolConfig && forcedTools.length > 0) { - const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }] - const result = trackForcedToolUsage( - toolCallsForTracking, - currentToolConfig, - logger, - 'google', - forcedTools, - usedForcedTools - ) - hasUsedForcedTool = result.hasUsedForcedTool - usedForcedTools = result.usedForcedTools - - if (result.nextToolConfig) { - currentToolConfig = result.nextToolConfig - logger.info('Updated tool config for next iteration', { - hasNextToolConfig: !!currentToolConfig, - usedForcedTools: usedForcedTools, - }) - } - } - } - - let modelTime = firstResponseTime - let toolsTime = 0 - - const timeSegments: TimeSegment[] = [ - { - type: 'model', - name: 'Initial response', - startTime: initialCallTime, - endTime: initialCallTime + firstResponseTime, - duration: firstResponseTime, - }, - ] - - try { - const candidate = geminiResponse.candidates?.[0] - - if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') { - logger.warn( - 'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided', - { - finishReason: candidate.finishReason, - hasContent: !!candidate?.content, - hasParts: !!candidate?.content?.parts, - } - ) - content = extractTextContent(candidate) - } - - const functionCall = extractFunctionCall(candidate) - - if (functionCall) { - logger.info(`Received function call from Vertex AI: ${functionCall.name}`) - - while (iterationCount < MAX_TOOL_ITERATIONS) { - const latestResponse = geminiResponse.candidates?.[0] - const latestFunctionCall = extractFunctionCall(latestResponse) - - if (!latestFunctionCall) { - content = extractTextContent(latestResponse) - break - } - - logger.info( - `Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` - ) - - const toolsStartTime = Date.now() - - try { - const toolName = latestFunctionCall.name - const toolArgs = latestFunctionCall.args || {} - - const tool = request.tools?.find((t) => t.id === toolName) - if (!tool) { - logger.warn(`Tool ${toolName} not found in registry, skipping`) - break - } - - const toolCallStartTime = Date.now() - - const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams, true) - const toolCallEndTime = Date.now() - const 
toolCallDuration = toolCallEndTime - toolCallStartTime - - timeSegments.push({ - type: 'tool', - name: toolName, - startTime: toolCallStartTime, - endTime: toolCallEndTime, - duration: toolCallDuration, - }) - - let resultContent: any - if (result.success) { - toolResults.push(result.output) - resultContent = result.output - } else { - resultContent = { - error: true, - message: result.error || 'Tool execution failed', - tool: toolName, - } - } - - toolCalls.push({ - name: toolName, - arguments: toolParams, - startTime: new Date(toolCallStartTime).toISOString(), - endTime: new Date(toolCallEndTime).toISOString(), - duration: toolCallDuration, - result: resultContent, - success: result.success, - }) - - const simplifiedMessages = [ - ...(contents.filter((m) => m.role === 'user').length > 0 - ? [contents.filter((m) => m.role === 'user')[0]] - : [contents[0]]), - { - role: 'model', - parts: [ - { - functionCall: { - name: latestFunctionCall.name, - args: latestFunctionCall.args, - }, - }, - ], - }, - { - role: 'user', - parts: [ - { - text: `Function ${latestFunctionCall.name} result: ${JSON.stringify(resultContent)}`, - }, - ], - }, - ] - - const thisToolsTime = Date.now() - toolsStartTime - toolsTime += thisToolsTime - - checkForForcedToolUsage(latestFunctionCall) - - const nextModelStartTime = Date.now() - - try { - if (request.stream) { - const streamingPayload = { - ...payload, - contents: simplifiedMessages, - } - - const allForcedToolsUsed = - forcedTools.length > 0 && usedForcedTools.length === forcedTools.length - - if (allForcedToolsUsed && request.responseFormat) { - streamingPayload.tools = undefined - streamingPayload.toolConfig = undefined - - const responseFormatSchema = - request.responseFormat.schema || request.responseFormat - const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - - if (!streamingPayload.generationConfig) { - streamingPayload.generationConfig = {} - } - streamingPayload.generationConfig.responseMimeType = 'application/json' - streamingPayload.generationConfig.responseSchema = cleanSchema - - logger.info('Using structured output for final response after tool execution') - } else { - if (currentToolConfig) { - streamingPayload.toolConfig = currentToolConfig - } else { - streamingPayload.toolConfig = { functionCallingConfig: { mode: 'AUTO' } } - } - } - - const checkPayload = { - ...streamingPayload, - } - checkPayload.stream = undefined - - const checkEndpoint = buildVertexEndpoint( - vertexProject, - vertexLocation, - requestedModel, - false - ) - - const checkResponse = await fetch(checkEndpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${request.apiKey}`, - }, - body: JSON.stringify(checkPayload), - }) - - if (!checkResponse.ok) { - const errorBody = await checkResponse.text() - logger.error('Error in Vertex AI check request:', { - status: checkResponse.status, - statusText: checkResponse.statusText, - responseBody: errorBody, - }) - throw new Error( - `Vertex AI API check error: ${checkResponse.status} ${checkResponse.statusText}` - ) - } - - const checkResult = await checkResponse.json() - const checkCandidate = checkResult.candidates?.[0] - const checkFunctionCall = extractFunctionCall(checkCandidate) - - if (checkFunctionCall) { - logger.info( - 'Function call detected in follow-up, handling in non-streaming mode', - { - functionName: checkFunctionCall.name, - } - ) - - geminiResponse = checkResult - - if (checkResult.usageMetadata) { - tokens.prompt += 
checkResult.usageMetadata.promptTokenCount || 0 - tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0 - tokens.total += - (checkResult.usageMetadata.promptTokenCount || 0) + - (checkResult.usageMetadata.candidatesTokenCount || 0) - } - - const nextModelEndTime = Date.now() - const thisModelTime = nextModelEndTime - nextModelStartTime - modelTime += thisModelTime - - timeSegments.push({ - type: 'model', - name: `Model response (iteration ${iterationCount + 1})`, - startTime: nextModelStartTime, - endTime: nextModelEndTime, - duration: thisModelTime, - }) - - iterationCount++ - continue - } - - logger.info('No function call detected, proceeding with streaming response') - - if (request.responseFormat) { - streamingPayload.tools = undefined - streamingPayload.toolConfig = undefined - - const responseFormatSchema = - request.responseFormat.schema || request.responseFormat - const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - - if (!streamingPayload.generationConfig) { - streamingPayload.generationConfig = {} - } - streamingPayload.generationConfig.responseMimeType = 'application/json' - streamingPayload.generationConfig.responseSchema = cleanSchema - - logger.info( - 'Using structured output for final streaming response after tool execution' - ) - } - - const streamEndpoint = buildVertexEndpoint( - vertexProject, - vertexLocation, - requestedModel, - true - ) - - const streamingResponse = await fetch(streamEndpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${request.apiKey}`, - }, - body: JSON.stringify(streamingPayload), - }) - - if (!streamingResponse.ok) { - const errorBody = await streamingResponse.text() - logger.error('Error in Vertex AI streaming follow-up request:', { - status: streamingResponse.status, - statusText: streamingResponse.statusText, - responseBody: errorBody, - }) - throw new Error( - `Vertex AI API streaming error: ${streamingResponse.status} ${streamingResponse.statusText}` - ) - } - - const nextModelEndTime = Date.now() - const thisModelTime = nextModelEndTime - nextModelStartTime - modelTime += thisModelTime - - timeSegments.push({ - type: 'model', - name: 'Final streaming response after tool calls', - startTime: nextModelStartTime, - endTime: nextModelEndTime, - duration: thisModelTime, - }) - - const streamingExecution: StreamingExecution = { - stream: null as any, - execution: { - success: true, - output: { - content: '', - model: request.model, - tokens, - toolCalls: - toolCalls.length > 0 - ? 
{ - list: toolCalls, - count: toolCalls.length, - } - : undefined, - toolResults, - providerTiming: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: Date.now() - providerStartTime, - modelTime, - toolsTime, - firstResponseTime, - iterations: iterationCount + 1, - timeSegments, - }, - }, - logs: [], - metadata: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: Date.now() - providerStartTime, - }, - isStreaming: true, - }, - } - - streamingExecution.stream = createReadableStreamFromVertexStream( - streamingResponse, - (content, usage) => { - streamingExecution.execution.output.content = content - - const streamEndTime = Date.now() - const streamEndTimeISO = new Date(streamEndTime).toISOString() - - if (streamingExecution.execution.output.providerTiming) { - streamingExecution.execution.output.providerTiming.endTime = - streamEndTimeISO - streamingExecution.execution.output.providerTiming.duration = - streamEndTime - providerStartTime - } - - const promptTokens = usage?.promptTokenCount || 0 - const completionTokens = usage?.candidatesTokenCount || 0 - const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens - - const existingTokens = streamingExecution.execution.output.tokens || { - prompt: 0, - completion: 0, - total: 0, - } - - const existingPrompt = existingTokens.prompt || 0 - const existingCompletion = existingTokens.completion || 0 - const existingTotal = existingTokens.total || 0 - - streamingExecution.execution.output.tokens = { - prompt: existingPrompt + promptTokens, - completion: existingCompletion + completionTokens, - total: existingTotal + totalTokens, - } - - const accumulatedCost = calculateCost( - request.model, - existingPrompt, - existingCompletion - ) - const streamCost = calculateCost( - request.model, - promptTokens, - completionTokens - ) - streamingExecution.execution.output.cost = { - input: accumulatedCost.input + streamCost.input, - output: accumulatedCost.output + streamCost.output, - total: accumulatedCost.total + streamCost.total, - } - } - ) - - return streamingExecution - } - - const nextPayload = { - ...payload, - contents: simplifiedMessages, - } - - const allForcedToolsUsed = - forcedTools.length > 0 && usedForcedTools.length === forcedTools.length - - if (allForcedToolsUsed && request.responseFormat) { - nextPayload.tools = undefined - nextPayload.toolConfig = undefined - - const responseFormatSchema = - request.responseFormat.schema || request.responseFormat - const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - - if (!nextPayload.generationConfig) { - nextPayload.generationConfig = {} - } - nextPayload.generationConfig.responseMimeType = 'application/json' - nextPayload.generationConfig.responseSchema = cleanSchema - - logger.info( - 'Using structured output for final non-streaming response after tool execution' - ) - } else { - if (currentToolConfig) { - nextPayload.toolConfig = currentToolConfig - } - } - - const nextEndpoint = buildVertexEndpoint( - vertexProject, - vertexLocation, - requestedModel, - false - ) - - const nextResponse = await fetch(nextEndpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${request.apiKey}`, - }, - body: JSON.stringify(nextPayload), - }) - - if (!nextResponse.ok) { - const errorBody = await nextResponse.text() - logger.error('Error in Vertex AI follow-up request:', { - status: nextResponse.status, - statusText: nextResponse.statusText, - responseBody: errorBody, - 
iterationCount, - }) - break - } - - geminiResponse = await nextResponse.json() - - const nextModelEndTime = Date.now() - const thisModelTime = nextModelEndTime - nextModelStartTime - - timeSegments.push({ - type: 'model', - name: `Model response (iteration ${iterationCount + 1})`, - startTime: nextModelStartTime, - endTime: nextModelEndTime, - duration: thisModelTime, - }) - - modelTime += thisModelTime - - const nextCandidate = geminiResponse.candidates?.[0] - const nextFunctionCall = extractFunctionCall(nextCandidate) - - if (!nextFunctionCall) { - if (request.responseFormat) { - const finalPayload = { - ...payload, - contents: nextPayload.contents, - tools: undefined, - toolConfig: undefined, - } - - const responseFormatSchema = - request.responseFormat.schema || request.responseFormat - const cleanSchema = cleanSchemaForGemini(responseFormatSchema) - - if (!finalPayload.generationConfig) { - finalPayload.generationConfig = {} - } - finalPayload.generationConfig.responseMimeType = 'application/json' - finalPayload.generationConfig.responseSchema = cleanSchema - - logger.info('Making final request with structured output after tool execution') - - const finalEndpoint = buildVertexEndpoint( - vertexProject, - vertexLocation, - requestedModel, - false - ) - - const finalResponse = await fetch(finalEndpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${request.apiKey}`, - }, - body: JSON.stringify(finalPayload), - }) - - if (finalResponse.ok) { - const finalResult = await finalResponse.json() - const finalCandidate = finalResult.candidates?.[0] - content = extractTextContent(finalCandidate) - - if (finalResult.usageMetadata) { - tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0 - tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0 - tokens.total += - (finalResult.usageMetadata.promptTokenCount || 0) + - (finalResult.usageMetadata.candidatesTokenCount || 0) - } - } else { - logger.warn( - 'Failed to get structured output, falling back to regular response' - ) - content = extractTextContent(nextCandidate) - } - } else { - content = extractTextContent(nextCandidate) - } - break - } - - iterationCount++ - } catch (error) { - logger.error('Error in Vertex AI follow-up request:', { - error: error instanceof Error ? error.message : String(error), - iterationCount, - }) - break - } - } catch (error) { - logger.error('Error processing function call:', { - error: error instanceof Error ? error.message : String(error), - functionName: latestFunctionCall?.name || 'unknown', - }) - break - } - } - } else { - content = extractTextContent(candidate) - } - } catch (error) { - logger.error('Error processing Vertex AI response:', { - error: error instanceof Error ? error.message : String(error), - iterationCount, - }) - - if (!content && toolCalls.length > 0) { - content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. 
Results are available in the tool results.`
-        }
-      }
-
-      const providerEndTime = Date.now()
-      const providerEndTimeISO = new Date(providerEndTime).toISOString()
-      const totalDuration = providerEndTime - providerStartTime
-
-      if (geminiResponse.usageMetadata) {
-        tokens = {
-          prompt: geminiResponse.usageMetadata.promptTokenCount || 0,
-          completion: geminiResponse.usageMetadata.candidatesTokenCount || 0,
-          total:
-            (geminiResponse.usageMetadata.promptTokenCount || 0) +
-            (geminiResponse.usageMetadata.candidatesTokenCount || 0),
-        }
-      }
-
-      return {
-        content,
-        model: request.model,
-        tokens,
-        toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
-        toolResults: toolResults.length > 0 ? toolResults : undefined,
-        timing: {
-          startTime: providerStartTimeISO,
-          endTime: providerEndTimeISO,
-          duration: totalDuration,
-          modelTime: modelTime,
-          toolsTime: toolsTime,
-          firstResponseTime: firstResponseTime,
-          iterations: iterationCount + 1,
-          timeSegments: timeSegments,
-        },
-      }
-    } catch (error) {
-      const providerEndTime = Date.now()
-      const providerEndTimeISO = new Date(providerEndTime).toISOString()
-      const totalDuration = providerEndTime - providerStartTime
-
-      logger.error('Error in Vertex AI request:', {
-        error: error instanceof Error ? error.message : String(error),
-        duration: totalDuration,
-      })
-
-      const enhancedError = new Error(error instanceof Error ? error.message : String(error))
-      // @ts-ignore
-      enhancedError.timing = {
-        startTime: providerStartTimeISO,
-        endTime: providerEndTimeISO,
-        duration: totalDuration,
-      }
-
-      throw enhancedError
-    }
+    // Create an OAuth2Client and set the access token
+    // This allows us to use an OAuth access token with the SDK
+    const authClient = new OAuth2Client()
+    authClient.setCredentials({ access_token: request.apiKey })
+
+    // Create client with Vertex AI configuration
+    const ai = new GoogleGenAI({
+      vertexai: true,
+      project: vertexProject,
+      location: vertexLocation,
+      googleAuthOptions: {
+        authClient,
+      },
+    })
+
+    return executeGeminiRequest({
+      ai,
+      model,
+      request,
+      providerType: 'vertex',
+    })
   },
 }
diff --git a/apps/sim/providers/vertex/utils.ts b/apps/sim/providers/vertex/utils.ts
deleted file mode 100644
index 3c7289108..000000000
--- a/apps/sim/providers/vertex/utils.ts
+++ /dev/null
@@ -1,231 +0,0 @@
-import { createLogger } from '@/lib/logs/console/logger'
-import { extractFunctionCall, extractTextContent } from '@/providers/google/utils'
-
-const logger = createLogger('VertexUtils')
-
-/**
- * Creates a ReadableStream from Vertex AI's Gemini stream response
- */
-export function createReadableStreamFromVertexStream(
-  response: Response,
-  onComplete?: (
-    content: string,
-    usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
-  ) => void
-): ReadableStream {
-  const reader = response.body?.getReader()
-  if (!reader) {
-    throw new Error('Failed to get reader from response body')
-  }
-
-  return new ReadableStream({
-    async start(controller) {
-      try {
-        let buffer = ''
-        let fullContent = ''
-        let usageData: {
-          promptTokenCount?: number
-          candidatesTokenCount?: number
-          totalTokenCount?: number
-        } | null = null
-
-        while (true) {
-          const { done, value } = await reader.read()
-          if (done) {
-            if (buffer.trim()) {
-              try {
-                const data = JSON.parse(buffer.trim())
-                if (data.usageMetadata) {
-                  usageData = data.usageMetadata
-                }
-                const candidate = data.candidates?.[0]
-                if (candidate?.content?.parts) {
-                  const functionCall = extractFunctionCall(candidate)
-                  if (functionCall) {
-
logger.debug( - 'Function call detected in final buffer, ending stream to execute tool', - { - functionName: functionCall.name, - } - ) - if (onComplete) onComplete(fullContent, usageData || undefined) - controller.close() - return - } - const content = extractTextContent(candidate) - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - } catch (e) { - if (buffer.trim().startsWith('[')) { - try { - const dataArray = JSON.parse(buffer.trim()) - if (Array.isArray(dataArray)) { - for (const item of dataArray) { - if (item.usageMetadata) { - usageData = item.usageMetadata - } - const candidate = item.candidates?.[0] - if (candidate?.content?.parts) { - const functionCall = extractFunctionCall(candidate) - if (functionCall) { - logger.debug( - 'Function call detected in array item, ending stream to execute tool', - { - functionName: functionCall.name, - } - ) - if (onComplete) onComplete(fullContent, usageData || undefined) - controller.close() - return - } - const content = extractTextContent(candidate) - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - } - } - } catch (arrayError) {} - } - } - } - if (onComplete) onComplete(fullContent, usageData || undefined) - controller.close() - break - } - - const text = new TextDecoder().decode(value) - buffer += text - - let searchIndex = 0 - while (searchIndex < buffer.length) { - const openBrace = buffer.indexOf('{', searchIndex) - if (openBrace === -1) break - - let braceCount = 0 - let inString = false - let escaped = false - let closeBrace = -1 - - for (let i = openBrace; i < buffer.length; i++) { - const char = buffer[i] - - if (!inString) { - if (char === '"' && !escaped) { - inString = true - } else if (char === '{') { - braceCount++ - } else if (char === '}') { - braceCount-- - if (braceCount === 0) { - closeBrace = i - break - } - } - } else { - if (char === '"' && !escaped) { - inString = false - } - } - - escaped = char === '\\' && !escaped - } - - if (closeBrace !== -1) { - const jsonStr = buffer.substring(openBrace, closeBrace + 1) - - try { - const data = JSON.parse(jsonStr) - - if (data.usageMetadata) { - usageData = data.usageMetadata - } - - const candidate = data.candidates?.[0] - - if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') { - logger.warn( - 'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided', - { - finishReason: candidate.finishReason, - hasContent: !!candidate?.content, - hasParts: !!candidate?.content?.parts, - } - ) - const textContent = extractTextContent(candidate) - if (textContent) { - fullContent += textContent - controller.enqueue(new TextEncoder().encode(textContent)) - } - if (onComplete) onComplete(fullContent, usageData || undefined) - controller.close() - return - } - - if (candidate?.content?.parts) { - const functionCall = extractFunctionCall(candidate) - if (functionCall) { - logger.debug( - 'Function call detected in stream, ending stream to execute tool', - { - functionName: functionCall.name, - } - ) - if (onComplete) onComplete(fullContent, usageData || undefined) - controller.close() - return - } - const content = extractTextContent(candidate) - if (content) { - fullContent += content - controller.enqueue(new TextEncoder().encode(content)) - } - } - } catch (e) { - logger.error('Error parsing JSON from stream', { - error: e instanceof Error ? 
e.message : String(e),
-                  jsonPreview: jsonStr.substring(0, 200),
-                })
-              }
-
-              buffer = buffer.substring(closeBrace + 1)
-              searchIndex = 0
-            } else {
-              break
-            }
-          }
-        }
-      } catch (e) {
-        logger.error('Error reading Vertex AI stream', {
-          error: e instanceof Error ? e.message : String(e),
-        })
-        controller.error(e)
-      }
-    },
-    async cancel() {
-      await reader.cancel()
-    },
-  })
-}
-
-/**
- * Build Vertex AI endpoint URL
- */
-export function buildVertexEndpoint(
-  project: string,
-  location: string,
-  model: string,
-  isStreaming: boolean
-): string {
-  const action = isStreaming ? 'streamGenerateContent' : 'generateContent'
-
-  if (location === 'global') {
-    return `https://aiplatform.googleapis.com/v1/projects/${project}/locations/global/publishers/google/models/${model}:${action}`
-  }
-
-  return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/google/models/${model}:${action}`
-}
diff --git a/apps/sim/providers/vllm/index.ts b/apps/sim/providers/vllm/index.ts
index 118040be7..6fe2734de 100644
--- a/apps/sim/providers/vllm/index.ts
+++ b/apps/sim/providers/vllm/index.ts
@@ -130,7 +130,7 @@ export const vllmProvider: ProviderConfig = {
       : undefined
 
     const payload: any = {
-      model: (request.model || getProviderDefaultModel('vllm')).replace(/^vllm\//, ''),
+      model: request.model.replace(/^vllm\//, ''),
       messages: allMessages,
     }
 
diff --git a/apps/sim/providers/xai/index.ts b/apps/sim/providers/xai/index.ts
index 5e29fccd4..d57f28281 100644
--- a/apps/sim/providers/xai/index.ts
+++ b/apps/sim/providers/xai/index.ts
@@ -48,7 +48,7 @@ export const xAIProvider: ProviderConfig = {
       hasTools: !!request.tools?.length,
       toolCount: request.tools?.length || 0,
       hasResponseFormat: !!request.responseFormat,
-      model: request.model || 'grok-3-latest',
+      model: request.model,
       streaming: !!request.stream,
     })
 
@@ -87,7 +87,7 @@ export const xAIProvider: ProviderConfig = {
       )
     }
     const basePayload: any = {
-      model: request.model || 'grok-3-latest',
+      model: request.model,
       messages: allMessages,
     }
 
@@ -139,7 +139,7 @@ export const xAIProvider: ProviderConfig = {
         success: true,
         output: {
           content: '',
-          model: request.model || 'grok-3-latest',
+          model: request.model,
           tokens: { prompt: 0, completion: 0, total: 0 },
           toolCalls: undefined,
           providerTiming: {
@@ -505,7 +505,7 @@ export const xAIProvider: ProviderConfig = {
         success: true,
         output: {
           content: '',
-          model: request.model || 'grok-3-latest',
+          model: request.model,
           tokens: {
             prompt: tokens.prompt,
             completion: tokens.completion,