From 89c1085950200fb89eb0d52853ce2d043a2f0e8e Mon Sep 17 00:00:00 2001 From: Waleed Date: Tue, 23 Dec 2025 13:11:56 -0800 Subject: [PATCH] improvement(vertex): added vertex to all LLM-based blocks, fixed refresh (#2555) * improvement(vertex): added vertex to all LLM-based blocks, fixed refresh * fix build --- apps/sim/app/api/providers/route.ts | 37 +++++++++++++- apps/sim/blocks/blocks/agent.ts | 2 + apps/sim/blocks/blocks/evaluator.ts | 34 +++++++++++-- apps/sim/blocks/blocks/router.ts | 35 ++++++++++++-- apps/sim/blocks/blocks/translate.ts | 36 ++++++++++++-- .../executor/handlers/agent/agent-handler.ts | 1 - .../handlers/evaluator/evaluator-handler.ts | 47 +++++++++++++++++- .../handlers/router/router-handler.ts | 48 ++++++++++++++++++- apps/sim/lib/oauth/oauth.ts | 18 ++++++- apps/sim/tools/llm/chat.ts | 8 ++++ 10 files changed, 248 insertions(+), 18 deletions(-) diff --git a/apps/sim/app/api/providers/route.ts b/apps/sim/app/api/providers/route.ts index 9a52f0bd7..04910ed1c 100644 --- a/apps/sim/app/api/providers/route.ts +++ b/apps/sim/app/api/providers/route.ts @@ -1,6 +1,10 @@ +import { db } from '@sim/db' +import { account } from '@sim/db/schema' +import { eq } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { generateRequestId } from '@/lib/core/utils/request' import { createLogger } from '@/lib/logs/console/logger' +import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils' import type { StreamingExecution } from '@/executor/types' import { executeProviderRequest } from '@/providers' import { getApiKey } from '@/providers/utils' @@ -37,6 +41,7 @@ export async function POST(request: NextRequest) { azureApiVersion, vertexProject, vertexLocation, + vertexCredential, responseFormat, workflowId, workspaceId, @@ -62,6 +67,7 @@ export async function POST(request: NextRequest) { hasAzureApiVersion: !!azureApiVersion, hasVertexProject: !!vertexProject, hasVertexLocation: !!vertexLocation, + hasVertexCredential: !!vertexCredential, hasResponseFormat: !!responseFormat, workflowId, stream: !!stream, @@ -76,13 +82,18 @@ export async function POST(request: NextRequest) { let finalApiKey: string try { - finalApiKey = getApiKey(provider, model, apiKey) + if (provider === 'vertex' && vertexCredential) { + finalApiKey = await resolveVertexCredential(requestId, vertexCredential) + } else { + finalApiKey = getApiKey(provider, model, apiKey) + } } catch (error) { logger.error(`[${requestId}] Failed to get API key:`, { provider, model, error: error instanceof Error ? error.message : String(error), hasProvidedApiKey: !!apiKey, + hasVertexCredential: !!vertexCredential, }) return NextResponse.json( { error: error instanceof Error ? error.message : 'API key error' }, @@ -324,3 +335,27 @@ function sanitizeObject(obj: any): any { return result } + +/** + * Resolves a Vertex AI OAuth credential to an access token + */ +async function resolveVertexCredential(requestId: string, credentialId: string): Promise { + logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`) + + const credential = await db.query.account.findFirst({ + where: eq(account.id, credentialId), + }) + + if (!credential) { + throw new Error(`Vertex AI credential not found: ${credentialId}`) + } + + const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId) + + if (!accessToken) { + throw new Error('Failed to get Vertex AI access token') + } + + logger.info(`[${requestId}] Successfully resolved Vertex AI credential`) + return accessToken +} diff --git a/apps/sim/blocks/blocks/agent.ts b/apps/sim/blocks/blocks/agent.ts index 16227a290..75cee0200 100644 --- a/apps/sim/blocks/blocks/agent.ts +++ b/apps/sim/blocks/blocks/agent.ts @@ -310,6 +310,7 @@ export const AgentBlock: BlockConfig = { type: 'short-input', placeholder: 'your-gcp-project-id', connectionDroppable: false, + required: true, condition: { field: 'model', value: providers.vertex.models, @@ -321,6 +322,7 @@ export const AgentBlock: BlockConfig = { type: 'short-input', placeholder: 'us-central1', connectionDroppable: false, + required: true, condition: { field: 'model', value: providers.vertex.models, diff --git a/apps/sim/blocks/blocks/evaluator.ts b/apps/sim/blocks/blocks/evaluator.ts index 63ea9c74c..47e3b895f 100644 --- a/apps/sim/blocks/blocks/evaluator.ts +++ b/apps/sim/blocks/blocks/evaluator.ts @@ -18,6 +18,10 @@ const getCurrentOllamaModels = () => { return useProvidersStore.getState().providers.ollama.models } +const getCurrentVLLMModels = () => { + return useProvidersStore.getState().providers.vllm.models +} + interface Metric { name: string description: string @@ -196,6 +200,19 @@ export const EvaluatorBlock: BlockConfig = { }) }, }, + { + id: 'vertexCredential', + title: 'Google Cloud Account', + type: 'oauth-input', + serviceId: 'vertex-ai', + requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'], + placeholder: 'Select Google Cloud account', + required: true, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, { id: 'apiKey', title: 'API Key', @@ -204,16 +221,21 @@ export const EvaluatorBlock: BlockConfig = { password: true, connectionDroppable: false, required: true, + // Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth) condition: isHosted ? { field: 'model', - value: getHostedModels(), + value: [...getHostedModels(), ...providers.vertex.models], not: true, // Show for all models EXCEPT those listed } : () => ({ field: 'model', - value: getCurrentOllamaModels(), - not: true, // Show for all models EXCEPT Ollama models + value: [ + ...getCurrentOllamaModels(), + ...getCurrentVLLMModels(), + ...providers.vertex.models, + ], + not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models }), }, { @@ -245,6 +267,7 @@ export const EvaluatorBlock: BlockConfig = { type: 'short-input', placeholder: 'your-gcp-project-id', connectionDroppable: false, + required: true, condition: { field: 'model', value: providers.vertex.models, @@ -256,6 +279,7 @@ export const EvaluatorBlock: BlockConfig = { type: 'short-input', placeholder: 'us-central1', connectionDroppable: false, + required: true, condition: { field: 'model', value: providers.vertex.models, @@ -386,6 +410,10 @@ export const EvaluatorBlock: BlockConfig = { type: 'string' as ParamType, description: 'Google Cloud location for Vertex AI', }, + vertexCredential: { + type: 'string' as ParamType, + description: 'Google Cloud OAuth credential ID for Vertex AI', + }, temperature: { type: 'number' as ParamType, description: 'Response randomness level (low for consistent evaluation)', diff --git a/apps/sim/blocks/blocks/router.ts b/apps/sim/blocks/blocks/router.ts index 0c6006a43..1549baa54 100644 --- a/apps/sim/blocks/blocks/router.ts +++ b/apps/sim/blocks/blocks/router.ts @@ -15,6 +15,10 @@ const getCurrentOllamaModels = () => { return useProvidersStore.getState().providers.ollama.models } +const getCurrentVLLMModels = () => { + return useProvidersStore.getState().providers.vllm.models +} + interface RouterResponse extends ToolResponse { output: { prompt: string @@ -144,6 +148,19 @@ export const RouterBlock: BlockConfig = { }) }, }, + { + id: 'vertexCredential', + title: 'Google Cloud Account', + type: 'oauth-input', + serviceId: 'vertex-ai', + requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'], + placeholder: 'Select Google Cloud account', + required: true, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, { id: 'apiKey', title: 'API Key', @@ -152,17 +169,21 @@ export const RouterBlock: BlockConfig = { password: true, connectionDroppable: false, required: true, - // Hide API key for hosted models and Ollama models + // Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth) condition: isHosted ? { field: 'model', - value: getHostedModels(), + value: [...getHostedModels(), ...providers.vertex.models], not: true, // Show for all models EXCEPT those listed } : () => ({ field: 'model', - value: getCurrentOllamaModels(), - not: true, // Show for all models EXCEPT Ollama models + value: [ + ...getCurrentOllamaModels(), + ...getCurrentVLLMModels(), + ...providers.vertex.models, + ], + not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models }), }, { @@ -194,6 +215,7 @@ export const RouterBlock: BlockConfig = { type: 'short-input', placeholder: 'your-gcp-project-id', connectionDroppable: false, + required: true, condition: { field: 'model', value: providers.vertex.models, @@ -205,6 +227,7 @@ export const RouterBlock: BlockConfig = { type: 'short-input', placeholder: 'us-central1', connectionDroppable: false, + required: true, condition: { field: 'model', value: providers.vertex.models, @@ -259,6 +282,10 @@ export const RouterBlock: BlockConfig = { azureApiVersion: { type: 'string', description: 'Azure API version' }, vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' }, vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' }, + vertexCredential: { + type: 'string', + description: 'Google Cloud OAuth credential ID for Vertex AI', + }, temperature: { type: 'number', description: 'Response randomness level (low for consistent routing)', diff --git a/apps/sim/blocks/blocks/translate.ts b/apps/sim/blocks/blocks/translate.ts index 1ecfc7a20..44c646608 100644 --- a/apps/sim/blocks/blocks/translate.ts +++ b/apps/sim/blocks/blocks/translate.ts @@ -8,6 +8,10 @@ const getCurrentOllamaModels = () => { return useProvidersStore.getState().providers.ollama.models } +const getCurrentVLLMModels = () => { + return useProvidersStore.getState().providers.vllm.models +} + const getTranslationPrompt = (targetLanguage: string) => `Translate the following text into ${targetLanguage || 'English'}. Output ONLY the translated text with no additional commentary, explanations, or notes.` @@ -55,6 +59,19 @@ export const TranslateBlock: BlockConfig = { }) }, }, + { + id: 'vertexCredential', + title: 'Google Cloud Account', + type: 'oauth-input', + serviceId: 'vertex-ai', + requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'], + placeholder: 'Select Google Cloud account', + required: true, + condition: { + field: 'model', + value: providers.vertex.models, + }, + }, { id: 'apiKey', title: 'API Key', @@ -63,17 +80,21 @@ export const TranslateBlock: BlockConfig = { password: true, connectionDroppable: false, required: true, - // Hide API key for hosted models and Ollama models + // Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth) condition: isHosted ? { field: 'model', - value: getHostedModels(), + value: [...getHostedModels(), ...providers.vertex.models], not: true, // Show for all models EXCEPT those listed } : () => ({ field: 'model', - value: getCurrentOllamaModels(), - not: true, // Show for all models EXCEPT Ollama models + value: [ + ...getCurrentOllamaModels(), + ...getCurrentVLLMModels(), + ...providers.vertex.models, + ], + not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models }), }, { @@ -105,6 +126,7 @@ export const TranslateBlock: BlockConfig = { type: 'short-input', placeholder: 'your-gcp-project-id', connectionDroppable: false, + required: true, condition: { field: 'model', value: providers.vertex.models, @@ -116,6 +138,7 @@ export const TranslateBlock: BlockConfig = { type: 'short-input', placeholder: 'us-central1', connectionDroppable: false, + required: true, condition: { field: 'model', value: providers.vertex.models, @@ -144,6 +167,7 @@ export const TranslateBlock: BlockConfig = { azureApiVersion: params.azureApiVersion, vertexProject: params.vertexProject, vertexLocation: params.vertexLocation, + vertexCredential: params.vertexCredential, }), }, }, @@ -155,6 +179,10 @@ export const TranslateBlock: BlockConfig = { azureApiVersion: { type: 'string', description: 'Azure API version' }, vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' }, vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' }, + vertexCredential: { + type: 'string', + description: 'Google Cloud OAuth credential ID for Vertex AI', + }, systemPrompt: { type: 'string', description: 'Translation instructions' }, }, outputs: { diff --git a/apps/sim/executor/handlers/agent/agent-handler.ts b/apps/sim/executor/handlers/agent/agent-handler.ts index 2f4a36332..292a154f0 100644 --- a/apps/sim/executor/handlers/agent/agent-handler.ts +++ b/apps/sim/executor/handlers/agent/agent-handler.ts @@ -1001,7 +1001,6 @@ export class AgentBlockHandler implements BlockHandler { ) { let finalApiKey: string - // For Vertex AI, resolve OAuth credential to access token if (providerId === 'vertex' && providerRequest.vertexCredential) { finalApiKey = await this.resolveVertexCredential( providerRequest.vertexCredential, diff --git a/apps/sim/executor/handlers/evaluator/evaluator-handler.ts b/apps/sim/executor/handlers/evaluator/evaluator-handler.ts index 694cf885c..ed370c468 100644 --- a/apps/sim/executor/handlers/evaluator/evaluator-handler.ts +++ b/apps/sim/executor/handlers/evaluator/evaluator-handler.ts @@ -1,4 +1,8 @@ +import { db } from '@sim/db' +import { account } from '@sim/db/schema' +import { eq } from 'drizzle-orm' import { createLogger } from '@/lib/logs/console/logger' +import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils' import type { BlockOutput } from '@/blocks/types' import { BlockType, DEFAULTS, EVALUATOR, HTTP } from '@/executor/constants' import type { BlockHandler, ExecutionContext } from '@/executor/types' @@ -25,9 +29,17 @@ export class EvaluatorBlockHandler implements BlockHandler { const evaluatorConfig = { model: inputs.model || EVALUATOR.DEFAULT_MODEL, apiKey: inputs.apiKey, + vertexProject: inputs.vertexProject, + vertexLocation: inputs.vertexLocation, + vertexCredential: inputs.vertexCredential, } const providerId = getProviderFromModel(evaluatorConfig.model) + let finalApiKey = evaluatorConfig.apiKey + if (providerId === 'vertex' && evaluatorConfig.vertexCredential) { + finalApiKey = await this.resolveVertexCredential(evaluatorConfig.vertexCredential) + } + const processedContent = this.processContent(inputs.content) let systemPromptObj: { systemPrompt: string; responseFormat: any } = { @@ -87,7 +99,7 @@ export class EvaluatorBlockHandler implements BlockHandler { try { const url = buildAPIUrl('/api/providers') - const providerRequest = { + const providerRequest: Record = { provider: providerId, model: evaluatorConfig.model, systemPrompt: systemPromptObj.systemPrompt, @@ -101,10 +113,15 @@ export class EvaluatorBlockHandler implements BlockHandler { ]), temperature: EVALUATOR.DEFAULT_TEMPERATURE, - apiKey: evaluatorConfig.apiKey, + apiKey: finalApiKey, workflowId: ctx.workflowId, } + if (providerId === 'vertex') { + providerRequest.vertexProject = evaluatorConfig.vertexProject + providerRequest.vertexLocation = evaluatorConfig.vertexLocation + } + const response = await fetch(url.toString(), { method: 'POST', headers: { @@ -250,4 +267,30 @@ export class EvaluatorBlockHandler implements BlockHandler { logger.warn(`Metric "${metricName}" not found in LLM response`) return DEFAULTS.EXECUTION_TIME } + + /** + * Resolves a Vertex AI OAuth credential to an access token + */ + private async resolveVertexCredential(credentialId: string): Promise { + const requestId = `vertex-evaluator-${Date.now()}` + + logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`) + + const credential = await db.query.account.findFirst({ + where: eq(account.id, credentialId), + }) + + if (!credential) { + throw new Error(`Vertex AI credential not found: ${credentialId}`) + } + + const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId) + + if (!accessToken) { + throw new Error('Failed to get Vertex AI access token') + } + + logger.info(`[${requestId}] Successfully resolved Vertex AI credential`) + return accessToken + } } diff --git a/apps/sim/executor/handlers/router/router-handler.ts b/apps/sim/executor/handlers/router/router-handler.ts index 59c5e8291..8b52d6217 100644 --- a/apps/sim/executor/handlers/router/router-handler.ts +++ b/apps/sim/executor/handlers/router/router-handler.ts @@ -1,5 +1,9 @@ +import { db } from '@sim/db' +import { account } from '@sim/db/schema' +import { eq } from 'drizzle-orm' import { getBaseUrl } from '@/lib/core/utils/urls' import { createLogger } from '@/lib/logs/console/logger' +import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils' import { generateRouterPrompt } from '@/blocks/blocks/router' import type { BlockOutput } from '@/blocks/types' import { BlockType, DEFAULTS, HTTP, isAgentBlockType, ROUTER } from '@/executor/constants' @@ -30,6 +34,9 @@ export class RouterBlockHandler implements BlockHandler { prompt: inputs.prompt, model: inputs.model || ROUTER.DEFAULT_MODEL, apiKey: inputs.apiKey, + vertexProject: inputs.vertexProject, + vertexLocation: inputs.vertexLocation, + vertexCredential: inputs.vertexCredential, } const providerId = getProviderFromModel(routerConfig.model) @@ -39,16 +46,27 @@ export class RouterBlockHandler implements BlockHandler { const messages = [{ role: 'user', content: routerConfig.prompt }] const systemPrompt = generateRouterPrompt(routerConfig.prompt, targetBlocks) - const providerRequest = { + + let finalApiKey = routerConfig.apiKey + if (providerId === 'vertex' && routerConfig.vertexCredential) { + finalApiKey = await this.resolveVertexCredential(routerConfig.vertexCredential) + } + + const providerRequest: Record = { provider: providerId, model: routerConfig.model, systemPrompt: systemPrompt, context: JSON.stringify(messages), temperature: ROUTER.INFERENCE_TEMPERATURE, - apiKey: routerConfig.apiKey, + apiKey: finalApiKey, workflowId: ctx.workflowId, } + if (providerId === 'vertex') { + providerRequest.vertexProject = routerConfig.vertexProject + providerRequest.vertexLocation = routerConfig.vertexLocation + } + const response = await fetch(url.toString(), { method: 'POST', headers: { @@ -152,4 +170,30 @@ export class RouterBlockHandler implements BlockHandler { } }) } + + /** + * Resolves a Vertex AI OAuth credential to an access token + */ + private async resolveVertexCredential(credentialId: string): Promise { + const requestId = `vertex-router-${Date.now()}` + + logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`) + + const credential = await db.query.account.findFirst({ + where: eq(account.id, credentialId), + }) + + if (!credential) { + throw new Error(`Vertex AI credential not found: ${credentialId}`) + } + + const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId) + + if (!accessToken) { + throw new Error('Failed to get Vertex AI access token') + } + + logger.info(`[${requestId}] Successfully resolved Vertex AI credential`) + return accessToken + } } diff --git a/apps/sim/lib/oauth/oauth.ts b/apps/sim/lib/oauth/oauth.ts index 5c1e99b93..e7a10ac06 100644 --- a/apps/sim/lib/oauth/oauth.ts +++ b/apps/sim/lib/oauth/oauth.ts @@ -1107,12 +1107,28 @@ function buildAuthRequest( * @param refreshToken The refresh token to use * @returns Object containing the new access token and expiration time in seconds, or null if refresh failed */ +function getBaseProviderForService(providerId: string): string { + if (providerId in OAUTH_PROVIDERS) { + return providerId + } + + for (const [baseProvider, config] of Object.entries(OAUTH_PROVIDERS)) { + for (const service of Object.values(config.services)) { + if (service.providerId === providerId) { + return baseProvider + } + } + } + + throw new Error(`Unknown OAuth provider: ${providerId}`) +} + export async function refreshOAuthToken( providerId: string, refreshToken: string ): Promise<{ accessToken: string; expiresIn: number; refreshToken: string } | null> { try { - const provider = providerId.split('-')[0] + const provider = getBaseProviderForService(providerId) const config = getProviderAuthConfig(provider) diff --git a/apps/sim/tools/llm/chat.ts b/apps/sim/tools/llm/chat.ts index 7af74232d..5f1bb3b2f 100644 --- a/apps/sim/tools/llm/chat.ts +++ b/apps/sim/tools/llm/chat.ts @@ -15,6 +15,7 @@ interface LLMChatParams { azureApiVersion?: string vertexProject?: string vertexLocation?: string + vertexCredential?: string } interface LLMChatResponse extends ToolResponse { @@ -91,6 +92,12 @@ export const llmChatTool: ToolConfig = { visibility: 'hidden', description: 'Google Cloud location for Vertex AI (defaults to us-central1)', }, + vertexCredential: { + type: 'string', + required: false, + visibility: 'hidden', + description: 'Google Cloud OAuth credential ID for Vertex AI', + }, }, request: { @@ -114,6 +121,7 @@ export const llmChatTool: ToolConfig = { azureApiVersion: params.azureApiVersion, vertexProject: params.vertexProject, vertexLocation: params.vertexLocation, + vertexCredential: params.vertexCredential, } }, },