improvement(vertex): added vertex to all LLM-based blocks, fixed refresh (#2555)

* improvement(vertex): added vertex to all LLM-based blocks, fixed refresh

* fix build
This commit is contained in:
Waleed
2025-12-23 13:11:56 -08:00
committed by GitHub
parent 4e09c389e8
commit 89c1085950
10 changed files with 248 additions and 18 deletions

View File

@@ -1,6 +1,10 @@
import { db } from '@sim/db'
import { account } from '@sim/db/schema'
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { generateRequestId } from '@/lib/core/utils/request'
import { createLogger } from '@/lib/logs/console/logger'
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
import type { StreamingExecution } from '@/executor/types'
import { executeProviderRequest } from '@/providers'
import { getApiKey } from '@/providers/utils'
@@ -37,6 +41,7 @@ export async function POST(request: NextRequest) {
azureApiVersion,
vertexProject,
vertexLocation,
vertexCredential,
responseFormat,
workflowId,
workspaceId,
@@ -62,6 +67,7 @@ export async function POST(request: NextRequest) {
hasAzureApiVersion: !!azureApiVersion,
hasVertexProject: !!vertexProject,
hasVertexLocation: !!vertexLocation,
hasVertexCredential: !!vertexCredential,
hasResponseFormat: !!responseFormat,
workflowId,
stream: !!stream,
@@ -76,13 +82,18 @@ export async function POST(request: NextRequest) {
let finalApiKey: string
try {
finalApiKey = getApiKey(provider, model, apiKey)
if (provider === 'vertex' && vertexCredential) {
finalApiKey = await resolveVertexCredential(requestId, vertexCredential)
} else {
finalApiKey = getApiKey(provider, model, apiKey)
}
} catch (error) {
logger.error(`[${requestId}] Failed to get API key:`, {
provider,
model,
error: error instanceof Error ? error.message : String(error),
hasProvidedApiKey: !!apiKey,
hasVertexCredential: !!vertexCredential,
})
return NextResponse.json(
{ error: error instanceof Error ? error.message : 'API key error' },
@@ -324,3 +335,27 @@ function sanitizeObject(obj: any): any {
return result
}
/**
 * Resolves a Vertex AI OAuth credential ID into a usable access token.
 *
 * Looks the credential up in the `account` table, refreshes it when needed
 * via `refreshTokenIfNeeded`, and returns the resulting access token.
 *
 * @param requestId - Correlation ID used to tag log lines for this request.
 * @param credentialId - Primary key of the stored OAuth account record.
 * @returns The (possibly refreshed) OAuth access token.
 * @throws Error when no matching credential row exists, or when no access
 *         token could be obtained after the refresh attempt.
 */
async function resolveVertexCredential(requestId: string, credentialId: string): Promise<string> {
  logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`)

  const row = await db.query.account.findFirst({
    where: eq(account.id, credentialId),
  })
  if (!row) {
    throw new Error(`Vertex AI credential not found: ${credentialId}`)
  }

  const refreshed = await refreshTokenIfNeeded(requestId, row, credentialId)
  if (!refreshed.accessToken) {
    throw new Error('Failed to get Vertex AI access token')
  }

  logger.info(`[${requestId}] Successfully resolved Vertex AI credential`)
  return refreshed.accessToken
}

View File

@@ -310,6 +310,7 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
@@ -321,6 +322,7 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
required: true,
condition: {
field: 'model',
value: providers.vertex.models,

View File

@@ -18,6 +18,10 @@ const getCurrentOllamaModels = () => {
return useProvidersStore.getState().providers.ollama.models
}
const getCurrentVLLMModels = () => {
return useProvidersStore.getState().providers.vllm.models
}
interface Metric {
name: string
description: string
@@ -196,6 +200,19 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
})
},
},
{
id: 'vertexCredential',
title: 'Google Cloud Account',
type: 'oauth-input',
serviceId: 'vertex-ai',
requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'],
placeholder: 'Select Google Cloud account',
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'apiKey',
title: 'API Key',
@@ -204,16 +221,21 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
password: true,
connectionDroppable: false,
required: true,
// Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth)
condition: isHosted
? {
field: 'model',
value: getHostedModels(),
value: [...getHostedModels(), ...providers.vertex.models],
not: true, // Show for all models EXCEPT those listed
}
: () => ({
field: 'model',
value: getCurrentOllamaModels(),
not: true, // Show for all models EXCEPT Ollama models
value: [
...getCurrentOllamaModels(),
...getCurrentVLLMModels(),
...providers.vertex.models,
],
not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models
}),
},
{
@@ -245,6 +267,7 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
@@ -256,6 +279,7 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
@@ -386,6 +410,10 @@ export const EvaluatorBlock: BlockConfig<EvaluatorResponse> = {
type: 'string' as ParamType,
description: 'Google Cloud location for Vertex AI',
},
vertexCredential: {
type: 'string' as ParamType,
description: 'Google Cloud OAuth credential ID for Vertex AI',
},
temperature: {
type: 'number' as ParamType,
description: 'Response randomness level (low for consistent evaluation)',

View File

@@ -15,6 +15,10 @@ const getCurrentOllamaModels = () => {
return useProvidersStore.getState().providers.ollama.models
}
const getCurrentVLLMModels = () => {
return useProvidersStore.getState().providers.vllm.models
}
interface RouterResponse extends ToolResponse {
output: {
prompt: string
@@ -144,6 +148,19 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
})
},
},
{
id: 'vertexCredential',
title: 'Google Cloud Account',
type: 'oauth-input',
serviceId: 'vertex-ai',
requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'],
placeholder: 'Select Google Cloud account',
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'apiKey',
title: 'API Key',
@@ -152,17 +169,21 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
password: true,
connectionDroppable: false,
required: true,
// Hide API key for hosted models and Ollama models
// Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth)
condition: isHosted
? {
field: 'model',
value: getHostedModels(),
value: [...getHostedModels(), ...providers.vertex.models],
not: true, // Show for all models EXCEPT those listed
}
: () => ({
field: 'model',
value: getCurrentOllamaModels(),
not: true, // Show for all models EXCEPT Ollama models
value: [
...getCurrentOllamaModels(),
...getCurrentVLLMModels(),
...providers.vertex.models,
],
not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models
}),
},
{
@@ -194,6 +215,7 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
@@ -205,6 +227,7 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
@@ -259,6 +282,10 @@ export const RouterBlock: BlockConfig<RouterResponse> = {
azureApiVersion: { type: 'string', description: 'Azure API version' },
vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' },
vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' },
vertexCredential: {
type: 'string',
description: 'Google Cloud OAuth credential ID for Vertex AI',
},
temperature: {
type: 'number',
description: 'Response randomness level (low for consistent routing)',

View File

@@ -8,6 +8,10 @@ const getCurrentOllamaModels = () => {
return useProvidersStore.getState().providers.ollama.models
}
const getCurrentVLLMModels = () => {
return useProvidersStore.getState().providers.vllm.models
}
const getTranslationPrompt = (targetLanguage: string) =>
`Translate the following text into ${targetLanguage || 'English'}. Output ONLY the translated text with no additional commentary, explanations, or notes.`
@@ -55,6 +59,19 @@ export const TranslateBlock: BlockConfig = {
})
},
},
{
id: 'vertexCredential',
title: 'Google Cloud Account',
type: 'oauth-input',
serviceId: 'vertex-ai',
requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'],
placeholder: 'Select Google Cloud account',
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'apiKey',
title: 'API Key',
@@ -63,17 +80,21 @@ export const TranslateBlock: BlockConfig = {
password: true,
connectionDroppable: false,
required: true,
// Hide API key for hosted models and Ollama models
// Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth)
condition: isHosted
? {
field: 'model',
value: getHostedModels(),
value: [...getHostedModels(), ...providers.vertex.models],
not: true, // Show for all models EXCEPT those listed
}
: () => ({
field: 'model',
value: getCurrentOllamaModels(),
not: true, // Show for all models EXCEPT Ollama models
value: [
...getCurrentOllamaModels(),
...getCurrentVLLMModels(),
...providers.vertex.models,
],
not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models
}),
},
{
@@ -105,6 +126,7 @@ export const TranslateBlock: BlockConfig = {
type: 'short-input',
placeholder: 'your-gcp-project-id',
connectionDroppable: false,
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
@@ -116,6 +138,7 @@ export const TranslateBlock: BlockConfig = {
type: 'short-input',
placeholder: 'us-central1',
connectionDroppable: false,
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
@@ -144,6 +167,7 @@ export const TranslateBlock: BlockConfig = {
azureApiVersion: params.azureApiVersion,
vertexProject: params.vertexProject,
vertexLocation: params.vertexLocation,
vertexCredential: params.vertexCredential,
}),
},
},
@@ -155,6 +179,10 @@ export const TranslateBlock: BlockConfig = {
azureApiVersion: { type: 'string', description: 'Azure API version' },
vertexProject: { type: 'string', description: 'Google Cloud project ID for Vertex AI' },
vertexLocation: { type: 'string', description: 'Google Cloud location for Vertex AI' },
vertexCredential: {
type: 'string',
description: 'Google Cloud OAuth credential ID for Vertex AI',
},
systemPrompt: { type: 'string', description: 'Translation instructions' },
},
outputs: {

View File

@@ -1001,7 +1001,6 @@ export class AgentBlockHandler implements BlockHandler {
) {
let finalApiKey: string
// For Vertex AI, resolve OAuth credential to access token
if (providerId === 'vertex' && providerRequest.vertexCredential) {
finalApiKey = await this.resolveVertexCredential(
providerRequest.vertexCredential,

View File

@@ -1,4 +1,8 @@
import { db } from '@sim/db'
import { account } from '@sim/db/schema'
import { eq } from 'drizzle-orm'
import { createLogger } from '@/lib/logs/console/logger'
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
import type { BlockOutput } from '@/blocks/types'
import { BlockType, DEFAULTS, EVALUATOR, HTTP } from '@/executor/constants'
import type { BlockHandler, ExecutionContext } from '@/executor/types'
@@ -25,9 +29,17 @@ export class EvaluatorBlockHandler implements BlockHandler {
const evaluatorConfig = {
model: inputs.model || EVALUATOR.DEFAULT_MODEL,
apiKey: inputs.apiKey,
vertexProject: inputs.vertexProject,
vertexLocation: inputs.vertexLocation,
vertexCredential: inputs.vertexCredential,
}
const providerId = getProviderFromModel(evaluatorConfig.model)
let finalApiKey = evaluatorConfig.apiKey
if (providerId === 'vertex' && evaluatorConfig.vertexCredential) {
finalApiKey = await this.resolveVertexCredential(evaluatorConfig.vertexCredential)
}
const processedContent = this.processContent(inputs.content)
let systemPromptObj: { systemPrompt: string; responseFormat: any } = {
@@ -87,7 +99,7 @@ export class EvaluatorBlockHandler implements BlockHandler {
try {
const url = buildAPIUrl('/api/providers')
const providerRequest = {
const providerRequest: Record<string, any> = {
provider: providerId,
model: evaluatorConfig.model,
systemPrompt: systemPromptObj.systemPrompt,
@@ -101,10 +113,15 @@ export class EvaluatorBlockHandler implements BlockHandler {
]),
temperature: EVALUATOR.DEFAULT_TEMPERATURE,
apiKey: evaluatorConfig.apiKey,
apiKey: finalApiKey,
workflowId: ctx.workflowId,
}
if (providerId === 'vertex') {
providerRequest.vertexProject = evaluatorConfig.vertexProject
providerRequest.vertexLocation = evaluatorConfig.vertexLocation
}
const response = await fetch(url.toString(), {
method: 'POST',
headers: {
@@ -250,4 +267,30 @@ export class EvaluatorBlockHandler implements BlockHandler {
logger.warn(`Metric "${metricName}" not found in LLM response`)
return DEFAULTS.EXECUTION_TIME
}
/**
* Resolves a Vertex AI OAuth credential to an access token
*/
private async resolveVertexCredential(credentialId: string): Promise<string> {
const requestId = `vertex-evaluator-${Date.now()}`
logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`)
const credential = await db.query.account.findFirst({
where: eq(account.id, credentialId),
})
if (!credential) {
throw new Error(`Vertex AI credential not found: ${credentialId}`)
}
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
if (!accessToken) {
throw new Error('Failed to get Vertex AI access token')
}
logger.info(`[${requestId}] Successfully resolved Vertex AI credential`)
return accessToken
}
}

View File

@@ -1,5 +1,9 @@
import { db } from '@sim/db'
import { account } from '@sim/db/schema'
import { eq } from 'drizzle-orm'
import { getBaseUrl } from '@/lib/core/utils/urls'
import { createLogger } from '@/lib/logs/console/logger'
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
import { generateRouterPrompt } from '@/blocks/blocks/router'
import type { BlockOutput } from '@/blocks/types'
import { BlockType, DEFAULTS, HTTP, isAgentBlockType, ROUTER } from '@/executor/constants'
@@ -30,6 +34,9 @@ export class RouterBlockHandler implements BlockHandler {
prompt: inputs.prompt,
model: inputs.model || ROUTER.DEFAULT_MODEL,
apiKey: inputs.apiKey,
vertexProject: inputs.vertexProject,
vertexLocation: inputs.vertexLocation,
vertexCredential: inputs.vertexCredential,
}
const providerId = getProviderFromModel(routerConfig.model)
@@ -39,16 +46,27 @@ export class RouterBlockHandler implements BlockHandler {
const messages = [{ role: 'user', content: routerConfig.prompt }]
const systemPrompt = generateRouterPrompt(routerConfig.prompt, targetBlocks)
const providerRequest = {
let finalApiKey = routerConfig.apiKey
if (providerId === 'vertex' && routerConfig.vertexCredential) {
finalApiKey = await this.resolveVertexCredential(routerConfig.vertexCredential)
}
const providerRequest: Record<string, any> = {
provider: providerId,
model: routerConfig.model,
systemPrompt: systemPrompt,
context: JSON.stringify(messages),
temperature: ROUTER.INFERENCE_TEMPERATURE,
apiKey: routerConfig.apiKey,
apiKey: finalApiKey,
workflowId: ctx.workflowId,
}
if (providerId === 'vertex') {
providerRequest.vertexProject = routerConfig.vertexProject
providerRequest.vertexLocation = routerConfig.vertexLocation
}
const response = await fetch(url.toString(), {
method: 'POST',
headers: {
@@ -152,4 +170,30 @@ export class RouterBlockHandler implements BlockHandler {
}
})
}
/**
* Resolves a Vertex AI OAuth credential to an access token
*/
private async resolveVertexCredential(credentialId: string): Promise<string> {
const requestId = `vertex-router-${Date.now()}`
logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`)
const credential = await db.query.account.findFirst({
where: eq(account.id, credentialId),
})
if (!credential) {
throw new Error(`Vertex AI credential not found: ${credentialId}`)
}
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
if (!accessToken) {
throw new Error('Failed to get Vertex AI access token')
}
logger.info(`[${requestId}] Successfully resolved Vertex AI credential`)
return accessToken
}
}

View File

@@ -1107,12 +1107,28 @@ function buildAuthRequest(
* @param refreshToken The refresh token to use
* @returns Object containing the new access token and expiration time in seconds, or null if refresh failed
*/
/**
 * Maps an OAuth provider ID back to its base provider key.
 *
 * If `providerId` is itself a key of `OAUTH_PROVIDERS`, it is returned
 * unchanged. Otherwise each base provider's services are scanned for a
 * service whose `providerId` matches (e.g. a per-service ID like
 * `vertex-ai` resolving to its parent provider).
 *
 * @param providerId - Base provider key or service-level provider ID.
 * @returns The matching base provider key.
 * @throws Error when no base provider or service matches `providerId`.
 */
function getBaseProviderForService(providerId: string): string {
  if (providerId in OAUTH_PROVIDERS) {
    return providerId
  }

  for (const [baseProvider, config] of Object.entries(OAUTH_PROVIDERS)) {
    const matches = Object.values(config.services).some(
      (service) => service.providerId === providerId
    )
    if (matches) {
      return baseProvider
    }
  }

  throw new Error(`Unknown OAuth provider: ${providerId}`)
}
export async function refreshOAuthToken(
providerId: string,
refreshToken: string
): Promise<{ accessToken: string; expiresIn: number; refreshToken: string } | null> {
try {
const provider = providerId.split('-')[0]
const provider = getBaseProviderForService(providerId)
const config = getProviderAuthConfig(provider)

View File

@@ -15,6 +15,7 @@ interface LLMChatParams {
azureApiVersion?: string
vertexProject?: string
vertexLocation?: string
vertexCredential?: string
}
interface LLMChatResponse extends ToolResponse {
@@ -91,6 +92,12 @@ export const llmChatTool: ToolConfig<LLMChatParams, LLMChatResponse> = {
visibility: 'hidden',
description: 'Google Cloud location for Vertex AI (defaults to us-central1)',
},
vertexCredential: {
type: 'string',
required: false,
visibility: 'hidden',
description: 'Google Cloud OAuth credential ID for Vertex AI',
},
},
request: {
@@ -114,6 +121,7 @@ export const llmChatTool: ToolConfig<LLMChatParams, LLMChatResponse> = {
azureApiVersion: params.azureApiVersion,
vertexProject: params.vertexProject,
vertexLocation: params.vertexLocation,
vertexCredential: params.vertexCredential,
}
},
},