Mirror of https://github.com/simstudioai/sim.git (synced 2026-01-07 22:24:06 -05:00)
fix(search): removed full text param from built-in search, anthropic provider streaming fix (#2542)
* fix(search): removed full text param from built-in search, anthropic provider streaming fix
* rewrite gemini provider with official sdk + add thinking capability
* vertex gemini consolidation
* never silently use different model
* pass oauth client through the googleAuthOptions param directly
* make server side provider registry
* remove comments
* take oauth selector below model selector

Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
@@ -56,7 +56,7 @@ export async function POST(request: NextRequest) {
query: validated.query,
type: 'auto',
useAutoprompt: true,
text: true,
highlights: true,
apiKey: env.EXA_API_KEY,
})

@@ -77,7 +77,7 @@ export async function POST(request: NextRequest) {
const results = (result.output.results || []).map((r: any, index: number) => ({
title: r.title || '',
link: r.url || '',
snippet: r.text || '',
snippet: Array.isArray(r.highlights) ? r.highlights.join(' ... ') : '',
date: r.publishedDate || undefined,
position: index + 1,
}))

@@ -43,6 +43,8 @@ const SCOPE_DESCRIPTIONS: Record<string, string> = {
'https://www.googleapis.com/auth/admin.directory.group.readonly': 'View Google Workspace groups',
'https://www.googleapis.com/auth/admin.directory.group.member.readonly':
'View Google Workspace group memberships',
'https://www.googleapis.com/auth/cloud-platform':
'Full access to Google Cloud resources for Vertex AI',
'read:confluence-content.all': 'Read all Confluence content',
'read:confluence-space.summary': 'Read Confluence space information',
'read:space:confluence': 'View Confluence spaces',

@@ -9,8 +9,10 @@ import {
|
||||
getMaxTemperature,
|
||||
getProviderIcon,
|
||||
getReasoningEffortValuesForModel,
|
||||
getThinkingLevelsForModel,
|
||||
getVerbosityValuesForModel,
|
||||
MODELS_WITH_REASONING_EFFORT,
|
||||
MODELS_WITH_THINKING,
|
||||
MODELS_WITH_VERBOSITY,
|
||||
providers,
|
||||
supportsTemperature,
|
||||
@@ -108,7 +110,19 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
|
||||
})
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
id: 'vertexCredential',
|
||||
title: 'Google Cloud Account',
|
||||
type: 'oauth-input',
|
||||
serviceId: 'vertex-ai',
|
||||
requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'],
|
||||
placeholder: 'Select Google Cloud account',
|
||||
required: true,
|
||||
condition: {
|
||||
field: 'model',
|
||||
value: providers.vertex.models,
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'reasoningEffort',
|
||||
title: 'Reasoning Effort',
|
||||
@@ -215,6 +229,57 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
|
||||
value: MODELS_WITH_VERBOSITY,
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'thinkingLevel',
|
||||
title: 'Thinking Level',
|
||||
type: 'dropdown',
|
||||
placeholder: 'Select thinking level...',
|
||||
options: [
|
||||
{ label: 'minimal', id: 'minimal' },
|
||||
{ label: 'low', id: 'low' },
|
||||
{ label: 'medium', id: 'medium' },
|
||||
{ label: 'high', id: 'high' },
|
||||
],
|
||||
dependsOn: ['model'],
|
||||
fetchOptions: async (blockId: string) => {
|
||||
const { useSubBlockStore } = await import('@/stores/workflows/subblock/store')
|
||||
const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store')
|
||||
|
||||
const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId
|
||||
if (!activeWorkflowId) {
|
||||
return [
|
||||
{ label: 'low', id: 'low' },
|
||||
{ label: 'high', id: 'high' },
|
||||
]
|
||||
}
|
||||
|
||||
const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId]
|
||||
const blockValues = workflowValues?.[blockId]
|
||||
const modelValue = blockValues?.model as string
|
||||
|
||||
if (!modelValue) {
|
||||
return [
|
||||
{ label: 'low', id: 'low' },
|
||||
{ label: 'high', id: 'high' },
|
||||
]
|
||||
}
|
||||
|
||||
const validOptions = getThinkingLevelsForModel(modelValue)
|
||||
if (!validOptions) {
|
||||
return [
|
||||
{ label: 'low', id: 'low' },
|
||||
{ label: 'high', id: 'high' },
|
||||
]
|
||||
}
|
||||
|
||||
return validOptions.map((opt) => ({ label: opt, id: opt }))
|
||||
},
|
||||
value: () => 'high',
|
||||
condition: {
|
||||
field: 'model',
|
||||
value: MODELS_WITH_THINKING,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
id: 'azureEndpoint',
|
||||
@@ -275,17 +340,21 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
|
||||
password: true,
|
||||
connectionDroppable: false,
|
||||
required: true,
|
||||
// Hide API key for hosted models, Ollama models, and vLLM models
|
||||
// Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth)
|
||||
condition: isHosted
|
||||
? {
|
||||
field: 'model',
|
||||
value: getHostedModels(),
|
||||
value: [...getHostedModels(), ...providers.vertex.models],
|
||||
not: true, // Show for all models EXCEPT those listed
|
||||
}
|
||||
: () => ({
|
||||
field: 'model',
|
||||
value: [...getCurrentOllamaModels(), ...getCurrentVLLMModels()],
|
||||
not: true, // Show for all models EXCEPT Ollama and vLLM models
|
||||
value: [
|
||||
...getCurrentOllamaModels(),
|
||||
...getCurrentVLLMModels(),
|
||||
...providers.vertex.models,
|
||||
],
|
||||
not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models
|
||||
}),
|
||||
},
|
||||
{
|
||||
@@ -609,6 +678,7 @@ Example 3 (Array Input):
|
||||
temperature: { type: 'number', description: 'Response randomness level' },
|
||||
reasoningEffort: { type: 'string', description: 'Reasoning effort level for GPT-5 models' },
|
||||
verbosity: { type: 'string', description: 'Verbosity level for GPT-5 models' },
|
||||
thinkingLevel: { type: 'string', description: 'Thinking level for Gemini 3 models' },
|
||||
tools: { type: 'json', description: 'Available tools configuration' },
|
||||
},
|
||||
outputs: {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { db } from '@sim/db'
|
||||
import { mcpServers } from '@sim/db/schema'
|
||||
import { account, mcpServers } from '@sim/db/schema'
|
||||
import { and, eq, inArray, isNull } from 'drizzle-orm'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { createMcpToolId } from '@/lib/mcp/utils'
|
||||
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
|
||||
import { getAllBlocks } from '@/blocks'
|
||||
import type { BlockOutput } from '@/blocks/types'
|
||||
import { AGENT, BlockType, DEFAULTS, HTTP } from '@/executor/constants'
|
||||
@@ -919,6 +920,7 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
azureApiVersion: inputs.azureApiVersion,
|
||||
vertexProject: inputs.vertexProject,
|
||||
vertexLocation: inputs.vertexLocation,
|
||||
vertexCredential: inputs.vertexCredential,
|
||||
responseFormat,
|
||||
workflowId: ctx.workflowId,
|
||||
workspaceId: ctx.workspaceId,
|
||||
@@ -997,7 +999,17 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
responseFormat: any,
|
||||
providerStartTime: number
|
||||
) {
|
||||
const finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey)
|
||||
let finalApiKey: string
|
||||
|
||||
// For Vertex AI, resolve OAuth credential to access token
|
||||
if (providerId === 'vertex' && providerRequest.vertexCredential) {
|
||||
finalApiKey = await this.resolveVertexCredential(
|
||||
providerRequest.vertexCredential,
|
||||
ctx.workflowId
|
||||
)
|
||||
} else {
|
||||
finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey)
|
||||
}
|
||||
|
||||
const { blockData, blockNameMapping } = collectBlockData(ctx)
|
||||
|
||||
@@ -1024,7 +1036,6 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
blockNameMapping,
|
||||
})
|
||||
|
||||
this.logExecutionSuccess(providerId, model, ctx, block, providerStartTime, response)
|
||||
return this.processProviderResponse(response, block, responseFormat)
|
||||
}
|
||||
|
||||
@@ -1049,15 +1060,6 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
throw new Error(errorMessage)
|
||||
}
|
||||
|
||||
this.logExecutionSuccess(
|
||||
providerRequest.provider,
|
||||
providerRequest.model,
|
||||
ctx,
|
||||
block,
|
||||
providerStartTime,
|
||||
'HTTP response'
|
||||
)
|
||||
|
||||
const contentType = response.headers.get('Content-Type')
|
||||
if (contentType?.includes(HTTP.CONTENT_TYPE.EVENT_STREAM)) {
|
||||
return this.handleStreamingResponse(response, block, ctx, inputs)
|
||||
@@ -1117,21 +1119,33 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
}
|
||||
}
|
||||
|
||||
private logExecutionSuccess(
|
||||
provider: string,
|
||||
model: string,
|
||||
ctx: ExecutionContext,
|
||||
block: SerializedBlock,
|
||||
startTime: number,
|
||||
response: any
|
||||
) {
|
||||
const executionTime = Date.now() - startTime
|
||||
const responseType =
|
||||
response instanceof ReadableStream
|
||||
? 'stream'
|
||||
: response && typeof response === 'object' && 'stream' in response
|
||||
? 'streaming-execution'
|
||||
: 'json'
|
||||
/**
|
||||
* Resolves a Vertex AI OAuth credential to an access token
|
||||
*/
|
||||
private async resolveVertexCredential(credentialId: string, workflowId: string): Promise<string> {
|
||||
const requestId = `vertex-${Date.now()}`
|
||||
|
||||
logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`)
|
||||
|
||||
// Get the credential - we need to find the owner
|
||||
// Since we're in a workflow context, we can query the credential directly
|
||||
const credential = await db.query.account.findFirst({
|
||||
where: eq(account.id, credentialId),
|
||||
})
|
||||
|
||||
if (!credential) {
|
||||
throw new Error(`Vertex AI credential not found: ${credentialId}`)
|
||||
}
|
||||
|
||||
// Refresh the token if needed
|
||||
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
|
||||
|
||||
if (!accessToken) {
|
||||
throw new Error('Failed to get Vertex AI access token')
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Successfully resolved Vertex AI credential`)
|
||||
return accessToken
|
||||
}
|
||||
|
||||
private handleExecutionError(
|
||||
|
||||
@@ -21,6 +21,7 @@ export interface AgentInputs {
|
||||
azureApiVersion?: string
|
||||
vertexProject?: string
|
||||
vertexLocation?: string
|
||||
vertexCredential?: string
|
||||
reasoningEffort?: string
|
||||
verbosity?: string
|
||||
}
|
||||
|
||||
@@ -579,6 +579,21 @@ export const auth = betterAuth({
|
||||
redirectURI: `${getBaseUrl()}/api/auth/oauth2/callback/google-groups`,
|
||||
},
|
||||
|
||||
{
|
||||
providerId: 'vertex-ai',
|
||||
clientId: env.GOOGLE_CLIENT_ID as string,
|
||||
clientSecret: env.GOOGLE_CLIENT_SECRET as string,
|
||||
discoveryUrl: 'https://accounts.google.com/.well-known/openid-configuration',
|
||||
accessType: 'offline',
|
||||
scopes: [
|
||||
'https://www.googleapis.com/auth/userinfo.email',
|
||||
'https://www.googleapis.com/auth/userinfo.profile',
|
||||
'https://www.googleapis.com/auth/cloud-platform',
|
||||
],
|
||||
prompt: 'consent',
|
||||
redirectURI: `${getBaseUrl()}/api/auth/oauth2/callback/vertex-ai`,
|
||||
},
|
||||
|
||||
{
|
||||
providerId: 'microsoft-teams',
|
||||
clientId: env.MICROSOFT_CLIENT_ID as string,
|
||||
|
||||
@@ -41,7 +41,7 @@ function filterUserFile(data: any): any {
|
||||
const DISPLAY_FILTERS = [filterUserFile]
|
||||
|
||||
export function filterForDisplay(data: any): any {
|
||||
const seen = new WeakSet()
|
||||
const seen = new Set<object>()
|
||||
return filterForDisplayInternal(data, seen, 0)
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ function getObjectType(data: unknown): string {
|
||||
return Object.prototype.toString.call(data).slice(8, -1)
|
||||
}
|
||||
|
||||
function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: number): any {
|
||||
function filterForDisplayInternal(data: any, seen: Set<object>, depth: number): any {
|
||||
try {
|
||||
if (data === null || data === undefined) {
|
||||
return data
|
||||
@@ -93,6 +93,7 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
|
||||
return '[Unknown Type]'
|
||||
}
|
||||
|
||||
// True circular reference: object is an ancestor in the current path
|
||||
if (seen.has(data)) {
|
||||
return '[Circular Reference]'
|
||||
}
|
||||
@@ -131,18 +132,24 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
|
||||
return `[ArrayBuffer: ${(data as ArrayBuffer).byteLength} bytes]`
|
||||
|
||||
case 'Map': {
|
||||
seen.add(data)
|
||||
const obj: Record<string, any> = {}
|
||||
for (const [key, value] of (data as Map<any, any>).entries()) {
|
||||
const keyStr = typeof key === 'string' ? key : String(key)
|
||||
obj[keyStr] = filterForDisplayInternal(value, seen, depth + 1)
|
||||
}
|
||||
seen.delete(data)
|
||||
return obj
|
||||
}
|
||||
|
||||
case 'Set':
|
||||
return Array.from(data as Set<any>).map((item) =>
|
||||
case 'Set': {
|
||||
seen.add(data)
|
||||
const result = Array.from(data as Set<any>).map((item) =>
|
||||
filterForDisplayInternal(item, seen, depth + 1)
|
||||
)
|
||||
seen.delete(data)
|
||||
return result
|
||||
}
|
||||
|
||||
case 'WeakMap':
|
||||
return '[WeakMap]'
|
||||
@@ -161,17 +168,22 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
|
||||
return `[${objectType}: ${(data as ArrayBufferView).byteLength} bytes]`
|
||||
}
|
||||
|
||||
// Add to current path before processing children
|
||||
seen.add(data)
|
||||
|
||||
for (const filterFn of DISPLAY_FILTERS) {
|
||||
const result = filterFn(data)
|
||||
if (result !== data) {
|
||||
return filterForDisplayInternal(result, seen, depth + 1)
|
||||
const filtered = filterFn(data)
|
||||
if (filtered !== data) {
|
||||
const result = filterForDisplayInternal(filtered, seen, depth + 1)
|
||||
seen.delete(data)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
if (Array.isArray(data)) {
|
||||
return data.map((item) => filterForDisplayInternal(item, seen, depth + 1))
|
||||
const result = data.map((item) => filterForDisplayInternal(item, seen, depth + 1))
|
||||
seen.delete(data)
|
||||
return result
|
||||
}
|
||||
|
||||
const result: Record<string, any> = {}
|
||||
@@ -182,6 +194,8 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
|
||||
result[key] = '[Error accessing property]'
|
||||
}
|
||||
}
|
||||
// Remove from current path after processing children
|
||||
seen.delete(data)
|
||||
return result
|
||||
} catch {
|
||||
return '[Unserializable]'
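
The change above replaces the WeakSet of every visited object with a Set that tracks only the current ancestor path: each container is added before its children are filtered and removed afterwards, so an object shared by two siblings is no longer reported as circular. A minimal TypeScript sketch of that path-tracking idea (simplified; not the project's actual filter, names are illustrative):

function summarize(data: unknown, path = new Set<object>()): unknown {
  if (data === null || typeof data !== 'object') return data
  // Only objects that are ancestors of the current node count as circular.
  if (path.has(data)) return '[Circular Reference]'

  path.add(data)
  const result = Array.isArray(data)
    ? data.map((item) => summarize(item, path))
    : Object.fromEntries(
        Object.entries(data as Record<string, unknown>).map(([k, v]) => [k, summarize(v, path)])
      )
  // Remove after the children are done, so siblings sharing this object are not flagged.
  path.delete(data)
  return result
}

With this shape, passing the same object twice inside one array yields both copies, while a truly self-referencing object still comes back as '[Circular Reference]'.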
@@ -32,6 +32,7 @@ import {
|
||||
SlackIcon,
|
||||
SpotifyIcon,
|
||||
TrelloIcon,
|
||||
VertexIcon,
|
||||
WealthboxIcon,
|
||||
WebflowIcon,
|
||||
WordpressIcon,
|
||||
@@ -80,6 +81,7 @@ export type OAuthService =
|
||||
| 'google-vault'
|
||||
| 'google-forms'
|
||||
| 'google-groups'
|
||||
| 'vertex-ai'
|
||||
| 'github'
|
||||
| 'x'
|
||||
| 'confluence'
|
||||
@@ -237,6 +239,16 @@ export const OAUTH_PROVIDERS: Record<string, OAuthProviderConfig> = {
|
||||
],
|
||||
scopeHints: ['admin.directory.group'],
|
||||
},
|
||||
'vertex-ai': {
|
||||
id: 'vertex-ai',
|
||||
name: 'Vertex AI',
|
||||
description: 'Access Google Cloud Vertex AI for Gemini models with OAuth.',
|
||||
providerId: 'vertex-ai',
|
||||
icon: (props) => VertexIcon(props),
|
||||
baseProviderIcon: (props) => VertexIcon(props),
|
||||
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
|
||||
scopeHints: ['cloud-platform', 'vertex', 'aiplatform'],
|
||||
},
|
||||
},
|
||||
defaultService: 'gmail',
|
||||
},
|
||||
@@ -1099,6 +1111,12 @@ export function parseProvider(provider: OAuthProvider): ProviderConfig {
|
||||
featureType: 'microsoft-planner',
|
||||
}
|
||||
}
|
||||
if (provider === 'vertex-ai') {
|
||||
return {
|
||||
baseProvider: 'google',
|
||||
featureType: 'vertex-ai',
|
||||
}
|
||||
}
|
||||
|
||||
// Handle compound providers (e.g., 'google-email' -> { baseProvider: 'google', featureType: 'email' })
|
||||
const [base, feature] = provider.split('-')
|
||||
|
||||
@@ -58,7 +58,7 @@ export const anthropicProvider: ProviderConfig = {
|
||||
throw new Error('API key is required for Anthropic')
|
||||
}
|
||||
|
||||
const modelId = request.model || 'claude-3-7-sonnet-20250219'
|
||||
const modelId = request.model
|
||||
const useNativeStructuredOutputs = !!(
|
||||
request.responseFormat && supportsNativeStructuredOutputs(modelId)
|
||||
)
|
||||
@@ -174,7 +174,7 @@ export const anthropicProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
const payload: any = {
|
||||
model: request.model || 'claude-3-7-sonnet-20250219',
|
||||
model: request.model,
|
||||
messages,
|
||||
system: systemPrompt,
|
||||
max_tokens: Number.parseInt(String(request.maxTokens)) || 1024,
|
||||
@@ -561,37 +561,93 @@ export const anthropicProvider: ProviderConfig = {
|
||||
throw error
|
||||
}
|
||||
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
|
||||
|
||||
return {
|
||||
content,
|
||||
model: request.model || 'claude-3-7-sonnet-20250219',
|
||||
tokens,
|
||||
toolCalls:
|
||||
toolCalls.length > 0
|
||||
? toolCalls.map((tc) => ({
|
||||
name: tc.name,
|
||||
arguments: tc.arguments as Record<string, any>,
|
||||
startTime: tc.startTime,
|
||||
endTime: tc.endTime,
|
||||
duration: tc.duration,
|
||||
result: tc.result,
|
||||
}))
|
||||
: undefined,
|
||||
toolResults: toolResults.length > 0 ? toolResults : undefined,
|
||||
timing: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
duration: totalDuration,
|
||||
modelTime: modelTime,
|
||||
toolsTime: toolsTime,
|
||||
firstResponseTime: firstResponseTime,
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
const streamingPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
stream: true,
|
||||
tool_choice: undefined,
|
||||
}
|
||||
|
||||
const streamResponse: any = await anthropic.messages.create(streamingPayload)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromAnthropicStream(
|
||||
streamResponse,
|
||||
(streamContent, usage) => {
|
||||
streamingResult.execution.output.content = streamContent
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: tokens.prompt + usage.input_tokens,
|
||||
completion: tokens.completion + usage.output_tokens,
|
||||
total: tokens.total + usage.input_tokens + usage.output_tokens,
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.input_tokens,
|
||||
usage.output_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
if (streamingResult.execution.output.providerTiming) {
|
||||
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
completion: tokens.completion,
|
||||
total: tokens.total,
|
||||
},
|
||||
toolCalls:
|
||||
toolCalls.length > 0
|
||||
? {
|
||||
list: toolCalls,
|
||||
count: toolCalls.length,
|
||||
}
|
||||
: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
modelTime: modelTime,
|
||||
toolsTime: toolsTime,
|
||||
firstResponseTime: firstResponseTime,
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
},
|
||||
isStreaming: true,
|
||||
},
|
||||
}
|
||||
|
||||
return streamingResult as StreamingExecution
|
||||
} catch (error) {
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
@@ -934,7 +990,7 @@ export const anthropicProvider: ProviderConfig = {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model || 'claude-3-7-sonnet-20250219',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
completion: tokens.completion,
|
||||
@@ -978,7 +1034,7 @@ export const anthropicProvider: ProviderConfig = {
|
||||
|
||||
return {
|
||||
content,
|
||||
model: request.model || 'claude-3-7-sonnet-20250219',
|
||||
model: request.model,
|
||||
tokens,
|
||||
toolCalls:
|
||||
toolCalls.length > 0
|
||||
|
||||
@@ -39,7 +39,7 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
request: ProviderRequest
|
||||
): Promise<ProviderResponse | StreamingExecution> => {
|
||||
logger.info('Preparing Azure OpenAI request', {
|
||||
model: request.model || 'azure/gpt-4o',
|
||||
model: request.model,
|
||||
hasSystemPrompt: !!request.systemPrompt,
|
||||
hasMessages: !!request.messages?.length,
|
||||
hasTools: !!request.tools?.length,
|
||||
@@ -95,7 +95,7 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
}))
|
||||
: undefined
|
||||
|
||||
const deploymentName = (request.model || 'azure/gpt-4o').replace('azure/', '')
|
||||
const deploymentName = request.model.replace('azure/', '')
|
||||
const payload: any = {
|
||||
model: deploymentName,
|
||||
messages: allMessages,
|
||||
|
||||
@@ -73,7 +73,7 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
: undefined
|
||||
|
||||
const payload: any = {
|
||||
model: (request.model || 'cerebras/llama-3.3-70b').replace('cerebras/', ''),
|
||||
model: request.model.replace('cerebras/', ''),
|
||||
messages: allMessages,
|
||||
}
|
||||
if (request.temperature !== undefined) payload.temperature = request.temperature
|
||||
@@ -145,7 +145,7 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model || 'cerebras/llama-3.3-70b',
|
||||
model: request.model,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
@@ -470,7 +470,7 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model || 'cerebras/llama-3.3-70b',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
completion: tokens.completion,
|
||||
|
||||
@@ -105,7 +105,7 @@ export const deepseekProvider: ProviderConfig = {
|
||||
: toolChoice.type === 'any'
|
||||
? `force:${toolChoice.any?.name || 'unknown'}`
|
||||
: 'unknown',
|
||||
model: request.model || 'deepseek-v3',
|
||||
model: request.model,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -145,7 +145,7 @@ export const deepseekProvider: ProviderConfig = {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model || 'deepseek-chat',
|
||||
model: request.model,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
@@ -469,7 +469,7 @@ export const deepseekProvider: ProviderConfig = {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model || 'deepseek-chat',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
completion: tokens.completion,
|
||||
|
||||
apps/sim/providers/gemini/client.ts (new file, 58 lines)
@@ -0,0 +1,58 @@
|
||||
import { GoogleGenAI } from '@google/genai'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { GeminiClientConfig } from './types'
|
||||
|
||||
const logger = createLogger('GeminiClient')
|
||||
|
||||
/**
|
||||
* Creates a GoogleGenAI client configured for either Google Gemini API or Vertex AI
|
||||
*
|
||||
* For Google Gemini API:
|
||||
* - Uses API key authentication
|
||||
*
|
||||
* For Vertex AI:
|
||||
* - Uses OAuth access token via HTTP Authorization header
|
||||
* - Requires project and location
|
||||
*/
|
||||
export function createGeminiClient(config: GeminiClientConfig): GoogleGenAI {
|
||||
if (config.vertexai) {
|
||||
if (!config.project) {
|
||||
throw new Error('Vertex AI requires a project ID')
|
||||
}
|
||||
if (!config.accessToken) {
|
||||
throw new Error('Vertex AI requires an access token')
|
||||
}
|
||||
|
||||
const location = config.location ?? 'us-central1'
|
||||
|
||||
logger.info('Creating Vertex AI client', {
|
||||
project: config.project,
|
||||
location,
|
||||
hasAccessToken: !!config.accessToken,
|
||||
})
|
||||
|
||||
// Create client with Vertex AI configuration
|
||||
// Use httpOptions.headers to pass the access token directly
|
||||
return new GoogleGenAI({
|
||||
vertexai: true,
|
||||
project: config.project,
|
||||
location,
|
||||
httpOptions: {
|
||||
headers: {
|
||||
Authorization: `Bearer ${config.accessToken}`,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Google Gemini API with API key
|
||||
if (!config.apiKey) {
|
||||
throw new Error('Google Gemini API requires an API key')
|
||||
}
|
||||
|
||||
logger.info('Creating Google Gemini client')
|
||||
|
||||
return new GoogleGenAI({
|
||||
apiKey: config.apiKey,
|
||||
})
|
||||
}
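
A small usage sketch for the factory above; the API key and project values are placeholders, and the import path assumes the gemini/index.ts barrel shown later in this diff:

import { createGeminiClient } from '@/providers/gemini'

// Gemini API mode: API-key auth (key source here is a placeholder).
const gemini = createGeminiClient({ apiKey: process.env.GEMINI_API_KEY ?? '' })

// Vertex AI mode: OAuth access token plus project; location falls back to 'us-central1'.
const vertex = createGeminiClient({
  vertexai: true,
  project: 'my-gcp-project', // placeholder project ID
  accessToken: '<oauth-access-token>', // e.g. resolved via the agent handler's resolveVertexCredential
})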
apps/sim/providers/gemini/core.ts (new file, 680 lines)
@@ -0,0 +1,680 @@
|
||||
import {
|
||||
type Content,
|
||||
FunctionCallingConfigMode,
|
||||
type FunctionDeclaration,
|
||||
type GenerateContentConfig,
|
||||
type GenerateContentResponse,
|
||||
type GoogleGenAI,
|
||||
type Part,
|
||||
type Schema,
|
||||
type ThinkingConfig,
|
||||
type ToolConfig,
|
||||
} from '@google/genai'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import {
|
||||
checkForForcedToolUsage,
|
||||
cleanSchemaForGemini,
|
||||
convertToGeminiFormat,
|
||||
convertUsageMetadata,
|
||||
createReadableStreamFromGeminiStream,
|
||||
extractFunctionCallPart,
|
||||
extractTextContent,
|
||||
mapToThinkingLevel,
|
||||
} from '@/providers/google/utils'
|
||||
import { getThinkingCapability } from '@/providers/models'
|
||||
import type { FunctionCallResponse, ProviderRequest, ProviderResponse } from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
import type { ExecutionState, GeminiProviderType, GeminiUsage, ParsedFunctionCall } from './types'
|
||||
|
||||
/**
|
||||
* Creates initial execution state
|
||||
*/
|
||||
function createInitialState(
|
||||
contents: Content[],
|
||||
initialUsage: GeminiUsage,
|
||||
firstResponseTime: number,
|
||||
initialCallTime: number,
|
||||
model: string,
|
||||
toolConfig: ToolConfig | undefined
|
||||
): ExecutionState {
|
||||
const initialCost = calculateCost(
|
||||
model,
|
||||
initialUsage.promptTokenCount,
|
||||
initialUsage.candidatesTokenCount
|
||||
)
|
||||
|
||||
return {
|
||||
contents,
|
||||
tokens: {
|
||||
prompt: initialUsage.promptTokenCount,
|
||||
completion: initialUsage.candidatesTokenCount,
|
||||
total: initialUsage.totalTokenCount,
|
||||
},
|
||||
cost: initialCost,
|
||||
toolCalls: [],
|
||||
toolResults: [],
|
||||
iterationCount: 0,
|
||||
modelTime: firstResponseTime,
|
||||
toolsTime: 0,
|
||||
timeSegments: [
|
||||
{
|
||||
type: 'model',
|
||||
name: 'Initial response',
|
||||
startTime: initialCallTime,
|
||||
endTime: initialCallTime + firstResponseTime,
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
],
|
||||
usedForcedTools: [],
|
||||
currentToolConfig: toolConfig,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes a tool call and updates state
|
||||
*/
|
||||
async function executeToolCall(
|
||||
functionCallPart: Part,
|
||||
functionCall: ParsedFunctionCall,
|
||||
request: ProviderRequest,
|
||||
state: ExecutionState,
|
||||
forcedTools: string[],
|
||||
logger: ReturnType<typeof createLogger>
|
||||
): Promise<{ success: boolean; state: ExecutionState }> {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = functionCall.name
|
||||
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) {
|
||||
logger.warn(`Tool ${toolName} not found in registry, skipping`)
|
||||
return { success: false, state }
|
||||
}
|
||||
|
||||
try {
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, functionCall.args, request)
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const duration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
const resultContent: Record<string, unknown> = result.success
|
||||
? (result.output as Record<string, unknown>)
|
||||
: { error: true, message: result.error || 'Tool execution failed', tool: toolName }
|
||||
|
||||
const toolCall: FunctionCallResponse = {
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration,
|
||||
result: resultContent,
|
||||
}
|
||||
|
||||
const updatedContents: Content[] = [
|
||||
...state.contents,
|
||||
{
|
||||
role: 'model',
|
||||
parts: [functionCallPart],
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: functionCall.name,
|
||||
response: resultContent,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const forcedToolCheck = checkForForcedToolUsage(
|
||||
[{ name: functionCall.name, args: functionCall.args }],
|
||||
state.currentToolConfig,
|
||||
forcedTools,
|
||||
state.usedForcedTools
|
||||
)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
state: {
|
||||
...state,
|
||||
contents: updatedContents,
|
||||
toolCalls: [...state.toolCalls, toolCall],
|
||||
toolResults: result.success
|
||||
? [...state.toolResults, result.output as Record<string, unknown>]
|
||||
: state.toolResults,
|
||||
toolsTime: state.toolsTime + duration,
|
||||
timeSegments: [
|
||||
...state.timeSegments,
|
||||
{
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration,
|
||||
},
|
||||
],
|
||||
usedForcedTools: forcedToolCheck?.usedForcedTools ?? state.usedForcedTools,
|
||||
currentToolConfig: forcedToolCheck?.nextToolConfig ?? state.currentToolConfig,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error processing function call:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
functionName: toolName,
|
||||
})
|
||||
return { success: false, state }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates state with model response metadata
|
||||
*/
|
||||
function updateStateWithResponse(
|
||||
state: ExecutionState,
|
||||
response: GenerateContentResponse,
|
||||
model: string,
|
||||
startTime: number,
|
||||
endTime: number
|
||||
): ExecutionState {
|
||||
const usage = convertUsageMetadata(response.usageMetadata)
|
||||
const cost = calculateCost(model, usage.promptTokenCount, usage.candidatesTokenCount)
|
||||
const duration = endTime - startTime
|
||||
|
||||
return {
|
||||
...state,
|
||||
tokens: {
|
||||
prompt: state.tokens.prompt + usage.promptTokenCount,
|
||||
completion: state.tokens.completion + usage.candidatesTokenCount,
|
||||
total: state.tokens.total + usage.totalTokenCount,
|
||||
},
|
||||
cost: {
|
||||
input: state.cost.input + cost.input,
|
||||
output: state.cost.output + cost.output,
|
||||
total: state.cost.total + cost.total,
|
||||
pricing: cost.pricing, // Use latest pricing
|
||||
},
|
||||
modelTime: state.modelTime + duration,
|
||||
timeSegments: [
|
||||
...state.timeSegments,
|
||||
{
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${state.iterationCount + 1})`,
|
||||
startTime,
|
||||
endTime,
|
||||
duration,
|
||||
},
|
||||
],
|
||||
iterationCount: state.iterationCount + 1,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds config for next iteration
|
||||
*/
|
||||
function buildNextConfig(
|
||||
baseConfig: GenerateContentConfig,
|
||||
state: ExecutionState,
|
||||
forcedTools: string[],
|
||||
request: ProviderRequest,
|
||||
logger: ReturnType<typeof createLogger>
|
||||
): GenerateContentConfig {
|
||||
const nextConfig = { ...baseConfig }
|
||||
const allForcedToolsUsed =
|
||||
forcedTools.length > 0 && state.usedForcedTools.length === forcedTools.length
|
||||
|
||||
if (allForcedToolsUsed && request.responseFormat) {
|
||||
nextConfig.tools = undefined
|
||||
nextConfig.toolConfig = undefined
|
||||
nextConfig.responseMimeType = 'application/json'
|
||||
nextConfig.responseSchema = cleanSchemaForGemini(request.responseFormat.schema) as Schema
|
||||
logger.info('Using structured output for final response after tool execution')
|
||||
} else if (state.currentToolConfig) {
|
||||
nextConfig.toolConfig = state.currentToolConfig
|
||||
} else {
|
||||
nextConfig.toolConfig = { functionCallingConfig: { mode: FunctionCallingConfigMode.AUTO } }
|
||||
}
|
||||
|
||||
return nextConfig
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates streaming execution result template
|
||||
*/
|
||||
function createStreamingResult(
|
||||
providerStartTime: number,
|
||||
providerStartTimeISO: string,
|
||||
firstResponseTime: number,
|
||||
initialCallTime: number,
|
||||
state?: ExecutionState
|
||||
): StreamingExecution {
|
||||
return {
|
||||
stream: undefined as unknown as ReadableStream<Uint8Array>,
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: '',
|
||||
tokens: state?.tokens ?? { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: state?.toolCalls.length
|
||||
? { list: state.toolCalls, count: state.toolCalls.length }
|
||||
: undefined,
|
||||
toolResults: state?.toolResults,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
modelTime: state?.modelTime ?? firstResponseTime,
|
||||
toolsTime: state?.toolsTime ?? 0,
|
||||
firstResponseTime,
|
||||
iterations: (state?.iterationCount ?? 0) + 1,
|
||||
timeSegments: state?.timeSegments ?? [
|
||||
{
|
||||
type: 'model',
|
||||
name: 'Initial streaming response',
|
||||
startTime: initialCallTime,
|
||||
endTime: initialCallTime + firstResponseTime,
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
],
|
||||
},
|
||||
cost: state?.cost ?? {
|
||||
input: 0,
|
||||
output: 0,
|
||||
total: 0,
|
||||
pricing: { input: 0, output: 0, updatedAt: new Date().toISOString() },
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
},
|
||||
isStreaming: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for executing a Gemini request
|
||||
*/
|
||||
export interface GeminiExecutionConfig {
|
||||
ai: GoogleGenAI
|
||||
model: string
|
||||
request: ProviderRequest
|
||||
providerType: GeminiProviderType
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes a request using the Gemini API
|
||||
*
|
||||
* This is the shared core logic for both Google and Vertex AI providers.
|
||||
* The only difference is how the GoogleGenAI client is configured.
|
||||
*/
|
||||
export async function executeGeminiRequest(
|
||||
config: GeminiExecutionConfig
|
||||
): Promise<ProviderResponse | StreamingExecution> {
|
||||
const { ai, model, request, providerType } = config
|
||||
const logger = createLogger(providerType === 'google' ? 'GoogleProvider' : 'VertexProvider')
|
||||
|
||||
logger.info(`Preparing ${providerType} Gemini request`, {
|
||||
model,
|
||||
hasSystemPrompt: !!request.systemPrompt,
|
||||
hasMessages: !!request.messages?.length,
|
||||
hasTools: !!request.tools?.length,
|
||||
toolCount: request.tools?.length ?? 0,
|
||||
hasResponseFormat: !!request.responseFormat,
|
||||
streaming: !!request.stream,
|
||||
})
|
||||
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
|
||||
|
||||
// Build configuration
|
||||
const geminiConfig: GenerateContentConfig = {}
|
||||
|
||||
if (request.temperature !== undefined) {
|
||||
geminiConfig.temperature = request.temperature
|
||||
}
|
||||
if (request.maxTokens !== undefined) {
|
||||
geminiConfig.maxOutputTokens = request.maxTokens
|
||||
}
|
||||
if (systemInstruction) {
|
||||
geminiConfig.systemInstruction = systemInstruction
|
||||
}
|
||||
|
||||
// Handle response format (only when no tools)
|
||||
if (request.responseFormat && !tools?.length) {
|
||||
geminiConfig.responseMimeType = 'application/json'
|
||||
geminiConfig.responseSchema = cleanSchemaForGemini(request.responseFormat.schema) as Schema
|
||||
logger.info('Using Gemini native structured output format')
|
||||
} else if (request.responseFormat && tools?.length) {
|
||||
logger.warn('Gemini does not support responseFormat with tools. Structured output ignored.')
|
||||
}
|
||||
|
||||
// Configure thinking for models that support it
|
||||
const thinkingCapability = getThinkingCapability(model)
|
||||
if (thinkingCapability) {
|
||||
const level = request.thinkingLevel ?? thinkingCapability.default ?? 'high'
|
||||
const thinkingConfig: ThinkingConfig = {
|
||||
includeThoughts: false,
|
||||
thinkingLevel: mapToThinkingLevel(level),
|
||||
}
|
||||
geminiConfig.thinkingConfig = thinkingConfig
|
||||
}
|
||||
|
||||
// Prepare tools
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
let toolConfig: ToolConfig | undefined
|
||||
|
||||
if (tools?.length) {
|
||||
const functionDeclarations: FunctionDeclaration[] = tools.map((t) => ({
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.parameters,
|
||||
}))
|
||||
|
||||
preparedTools = prepareToolsWithUsageControl(
|
||||
functionDeclarations,
|
||||
request.tools,
|
||||
logger,
|
||||
'google'
|
||||
)
|
||||
const { tools: filteredTools, toolConfig: preparedToolConfig } = preparedTools
|
||||
|
||||
if (filteredTools?.length) {
|
||||
geminiConfig.tools = [{ functionDeclarations: filteredTools as FunctionDeclaration[] }]
|
||||
|
||||
if (preparedToolConfig) {
|
||||
toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode:
|
||||
{
|
||||
AUTO: FunctionCallingConfigMode.AUTO,
|
||||
ANY: FunctionCallingConfigMode.ANY,
|
||||
NONE: FunctionCallingConfigMode.NONE,
|
||||
}[preparedToolConfig.functionCallingConfig.mode] ?? FunctionCallingConfigMode.AUTO,
|
||||
allowedFunctionNames: preparedToolConfig.functionCallingConfig.allowedFunctionNames,
|
||||
},
|
||||
}
|
||||
geminiConfig.toolConfig = toolConfig
|
||||
}
|
||||
|
||||
logger.info('Gemini request with tools:', {
|
||||
toolCount: filteredTools.length,
|
||||
model,
|
||||
tools: filteredTools.map((t) => (t as FunctionDeclaration).name),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const initialCallTime = Date.now()
|
||||
const shouldStream = request.stream && !tools?.length
|
||||
|
||||
// Streaming without tools
|
||||
if (shouldStream) {
|
||||
logger.info('Handling Gemini streaming response')
|
||||
|
||||
const streamGenerator = await ai.models.generateContentStream({
|
||||
model,
|
||||
contents,
|
||||
config: geminiConfig,
|
||||
})
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
const streamingResult = createStreamingResult(
|
||||
providerStartTime,
|
||||
providerStartTimeISO,
|
||||
firstResponseTime,
|
||||
initialCallTime
|
||||
)
|
||||
streamingResult.execution.output.model = model
|
||||
|
||||
streamingResult.stream = createReadableStreamFromGeminiStream(
|
||||
streamGenerator,
|
||||
(content: string, usage: GeminiUsage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.promptTokenCount,
|
||||
completion: usage.candidatesTokenCount,
|
||||
total: usage.totalTokenCount,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
model,
|
||||
usage.promptTokenCount,
|
||||
usage.candidatesTokenCount
|
||||
)
|
||||
streamingResult.execution.output.cost = costResult
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
if (streamingResult.execution.output.providerTiming) {
|
||||
streamingResult.execution.output.providerTiming.endTime = new Date(
|
||||
streamEndTime
|
||||
).toISOString()
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
const segments = streamingResult.execution.output.providerTiming.timeSegments
|
||||
if (segments?.[0]) {
|
||||
segments[0].endTime = streamEndTime
|
||||
segments[0].duration = streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return streamingResult
|
||||
}
|
||||
|
||||
// Non-streaming request
|
||||
const response = await ai.models.generateContent({ model, contents, config: geminiConfig })
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
// Check for UNEXPECTED_TOOL_CALL
|
||||
const candidate = response.candidates?.[0]
|
||||
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
|
||||
logger.warn('Gemini returned UNEXPECTED_TOOL_CALL - model attempted to call unknown tool')
|
||||
}
|
||||
|
||||
const initialUsage = convertUsageMetadata(response.usageMetadata)
|
||||
let state = createInitialState(
|
||||
contents,
|
||||
initialUsage,
|
||||
firstResponseTime,
|
||||
initialCallTime,
|
||||
model,
|
||||
toolConfig
|
||||
)
|
||||
const forcedTools = preparedTools?.forcedTools ?? []
|
||||
|
||||
let currentResponse = response
|
||||
let content = ''
|
||||
|
||||
// Tool execution loop
|
||||
const functionCalls = response.functionCalls
|
||||
if (functionCalls?.length) {
|
||||
logger.info(`Received function call from Gemini: ${functionCalls[0].name}`)
|
||||
|
||||
while (state.iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
const functionCallPart = extractFunctionCallPart(currentResponse.candidates?.[0])
|
||||
if (!functionCallPart?.functionCall) {
|
||||
content = extractTextContent(currentResponse.candidates?.[0])
|
||||
break
|
||||
}
|
||||
|
||||
const functionCall: ParsedFunctionCall = {
|
||||
name: functionCallPart.functionCall.name ?? '',
|
||||
args: (functionCallPart.functionCall.args ?? {}) as Record<string, unknown>,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing function call: ${functionCall.name} (iteration ${state.iterationCount + 1})`
|
||||
)
|
||||
|
||||
const { success, state: updatedState } = await executeToolCall(
|
||||
functionCallPart,
|
||||
functionCall,
|
||||
request,
|
||||
state,
|
||||
forcedTools,
|
||||
logger
|
||||
)
|
||||
if (!success) {
|
||||
content = extractTextContent(currentResponse.candidates?.[0])
|
||||
break
|
||||
}
|
||||
|
||||
state = { ...updatedState, iterationCount: updatedState.iterationCount + 1 }
|
||||
const nextConfig = buildNextConfig(geminiConfig, state, forcedTools, request, logger)
|
||||
|
||||
// Stream final response if requested
|
||||
if (request.stream) {
|
||||
const checkResponse = await ai.models.generateContent({
|
||||
model,
|
||||
contents: state.contents,
|
||||
config: nextConfig,
|
||||
})
|
||||
state = updateStateWithResponse(state, checkResponse, model, Date.now() - 100, Date.now())
|
||||
|
||||
if (checkResponse.functionCalls?.length) {
|
||||
currentResponse = checkResponse
|
||||
continue
|
||||
}
|
||||
|
||||
logger.info('No more function calls, streaming final response')
|
||||
|
||||
if (request.responseFormat) {
|
||||
nextConfig.tools = undefined
|
||||
nextConfig.toolConfig = undefined
|
||||
nextConfig.responseMimeType = 'application/json'
|
||||
nextConfig.responseSchema = cleanSchemaForGemini(
|
||||
request.responseFormat.schema
|
||||
) as Schema
|
||||
}
|
||||
|
||||
// Capture accumulated cost before streaming
|
||||
const accumulatedCost = {
|
||||
input: state.cost.input,
|
||||
output: state.cost.output,
|
||||
total: state.cost.total,
|
||||
}
|
||||
const accumulatedTokens = { ...state.tokens }
|
||||
|
||||
const streamGenerator = await ai.models.generateContentStream({
|
||||
model,
|
||||
contents: state.contents,
|
||||
config: nextConfig,
|
||||
})
|
||||
|
||||
const streamingResult = createStreamingResult(
|
||||
providerStartTime,
|
||||
providerStartTimeISO,
|
||||
firstResponseTime,
|
||||
initialCallTime,
|
||||
state
|
||||
)
|
||||
streamingResult.execution.output.model = model
|
||||
|
||||
streamingResult.stream = createReadableStreamFromGeminiStream(
|
||||
streamGenerator,
|
||||
(streamContent: string, usage: GeminiUsage) => {
|
||||
streamingResult.execution.output.content = streamContent
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: accumulatedTokens.prompt + usage.promptTokenCount,
|
||||
completion: accumulatedTokens.completion + usage.candidatesTokenCount,
|
||||
total: accumulatedTokens.total + usage.totalTokenCount,
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(
|
||||
model,
|
||||
usage.promptTokenCount,
|
||||
usage.candidatesTokenCount
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
pricing: streamCost.pricing,
|
||||
}
|
||||
|
||||
if (streamingResult.execution.output.providerTiming) {
|
||||
streamingResult.execution.output.providerTiming.endTime = new Date().toISOString()
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
Date.now() - providerStartTime
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return streamingResult
|
||||
}
|
||||
|
||||
// Non-streaming: get next response
|
||||
const nextModelStartTime = Date.now()
|
||||
const nextResponse = await ai.models.generateContent({
|
||||
model,
|
||||
contents: state.contents,
|
||||
config: nextConfig,
|
||||
})
|
||||
state = updateStateWithResponse(state, nextResponse, model, nextModelStartTime, Date.now())
|
||||
currentResponse = nextResponse
|
||||
}
|
||||
|
||||
if (!content) {
|
||||
content = extractTextContent(currentResponse.candidates?.[0])
|
||||
}
|
||||
} else {
|
||||
content = extractTextContent(candidate)
|
||||
}
|
||||
|
||||
const providerEndTime = Date.now()
|
||||
|
||||
return {
|
||||
content,
|
||||
model,
|
||||
tokens: state.tokens,
|
||||
toolCalls: state.toolCalls.length ? state.toolCalls : undefined,
|
||||
toolResults: state.toolResults.length ? state.toolResults : undefined,
|
||||
timing: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date(providerEndTime).toISOString(),
|
||||
duration: providerEndTime - providerStartTime,
|
||||
modelTime: state.modelTime,
|
||||
toolsTime: state.toolsTime,
|
||||
firstResponseTime,
|
||||
iterations: state.iterationCount + 1,
|
||||
timeSegments: state.timeSegments,
|
||||
},
|
||||
cost: state.cost,
|
||||
}
|
||||
} catch (error) {
|
||||
const providerEndTime = Date.now()
|
||||
const duration = providerEndTime - providerStartTime
|
||||
|
||||
logger.error('Error in Gemini request:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
stack: error instanceof Error ? error.stack : undefined,
|
||||
})
|
||||
|
||||
const enhancedError = error instanceof Error ? error : new Error(String(error))
|
||||
Object.assign(enhancedError, {
|
||||
timing: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date(providerEndTime).toISOString(),
|
||||
duration,
|
||||
},
|
||||
})
|
||||
throw enhancedError
|
||||
}
|
||||
}
|
||||
apps/sim/providers/gemini/index.ts (new file, 18 lines)
@@ -0,0 +1,18 @@
|
||||
/**
|
||||
* Shared Gemini execution core
|
||||
*
|
||||
* This module provides the shared execution logic for both Google Gemini API
|
||||
* and Vertex AI providers. The only difference between providers is how the
|
||||
* GoogleGenAI client is configured (API key vs OAuth).
|
||||
*/
|
||||
|
||||
export { createGeminiClient } from './client'
|
||||
export { executeGeminiRequest, type GeminiExecutionConfig } from './core'
|
||||
export type {
|
||||
ExecutionState,
|
||||
ForcedToolResult,
|
||||
GeminiClientConfig,
|
||||
GeminiProviderType,
|
||||
GeminiUsage,
|
||||
ParsedFunctionCall,
|
||||
} from './types'
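
As the module comment notes, the two providers differ only in how the client is built. A hedged sketch of how a provider could wire these exports together (the real google and vertex providers may differ in detail; request.model is assumed to be a plain string):

import { createGeminiClient, executeGeminiRequest } from '@/providers/gemini'
import type { ProviderRequest, ProviderResponse } from '@/providers/types'
import type { StreamingExecution } from '@/executor/types'

async function runGoogleGemini(
  request: ProviderRequest,
  apiKey: string
): Promise<ProviderResponse | StreamingExecution> {
  const ai = createGeminiClient({ apiKey })
  // providerType only selects the logger name and model lookup; the execution path is shared.
  return executeGeminiRequest({ ai, model: request.model, request, providerType: 'google' })
}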
apps/sim/providers/gemini/types.ts (new file, 64 lines)
@@ -0,0 +1,64 @@
|
||||
import type { Content, ToolConfig } from '@google/genai'
|
||||
import type { FunctionCallResponse, ModelPricing, TimeSegment } from '@/providers/types'
|
||||
|
||||
/**
|
||||
* Usage metadata from Gemini responses
|
||||
*/
|
||||
export interface GeminiUsage {
|
||||
promptTokenCount: number
|
||||
candidatesTokenCount: number
|
||||
totalTokenCount: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Parsed function call from Gemini response
|
||||
*/
|
||||
export interface ParsedFunctionCall {
|
||||
name: string
|
||||
args: Record<string, unknown>
|
||||
}
|
||||
|
||||
/**
|
||||
* Accumulated state during tool execution loop
|
||||
*/
|
||||
export interface ExecutionState {
|
||||
contents: Content[]
|
||||
tokens: { prompt: number; completion: number; total: number }
|
||||
cost: { input: number; output: number; total: number; pricing: ModelPricing }
|
||||
toolCalls: FunctionCallResponse[]
|
||||
toolResults: Record<string, unknown>[]
|
||||
iterationCount: number
|
||||
modelTime: number
|
||||
toolsTime: number
|
||||
timeSegments: TimeSegment[]
|
||||
usedForcedTools: string[]
|
||||
currentToolConfig: ToolConfig | undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Result from forced tool usage check
|
||||
*/
|
||||
export interface ForcedToolResult {
|
||||
hasUsedForcedTool: boolean
|
||||
usedForcedTools: string[]
|
||||
nextToolConfig: ToolConfig | undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for creating a Gemini client
|
||||
*/
|
||||
export interface GeminiClientConfig {
|
||||
/** For Google Gemini API */
|
||||
apiKey?: string
|
||||
/** For Vertex AI */
|
||||
vertexai?: boolean
|
||||
project?: string
|
||||
location?: string
|
||||
/** OAuth access token for Vertex AI */
|
||||
accessToken?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider type for logging and model lookup
|
||||
*/
|
||||
export type GeminiProviderType = 'google' | 'vertex'
|
||||
File diff suppressed because it is too large
@@ -1,61 +1,89 @@
|
||||
import type { Candidate } from '@google/genai'
|
||||
import {
|
||||
type Candidate,
|
||||
type Content,
|
||||
type FunctionCall,
|
||||
FunctionCallingConfigMode,
|
||||
type GenerateContentResponse,
|
||||
type GenerateContentResponseUsageMetadata,
|
||||
type Part,
|
||||
type Schema,
|
||||
type SchemaUnion,
|
||||
ThinkingLevel,
|
||||
type ToolConfig,
|
||||
Type,
|
||||
} from '@google/genai'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { ProviderRequest } from '@/providers/types'
|
||||
import { trackForcedToolUsage } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('GoogleUtils')
|
||||
|
||||
/**
|
||||
* Usage metadata for Google Gemini responses
|
||||
*/
|
||||
export interface GeminiUsage {
|
||||
promptTokenCount: number
|
||||
candidatesTokenCount: number
|
||||
totalTokenCount: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Parsed function call from Gemini response
|
||||
*/
|
||||
export interface ParsedFunctionCall {
|
||||
name: string
|
||||
args: Record<string, unknown>
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes additionalProperties from a schema object (not supported by Gemini)
|
||||
*/
|
||||
export function cleanSchemaForGemini(schema: any): any {
|
||||
export function cleanSchemaForGemini(schema: SchemaUnion): SchemaUnion {
|
||||
if (schema === null || schema === undefined) return schema
|
||||
if (typeof schema !== 'object') return schema
|
||||
if (Array.isArray(schema)) {
|
||||
return schema.map((item) => cleanSchemaForGemini(item))
|
||||
}
|
||||
|
||||
const cleanedSchema: any = {}
|
||||
const cleanedSchema: Record<string, unknown> = {}
|
||||
const schemaObj = schema as Record<string, unknown>
|
||||
|
||||
for (const key in schema) {
|
||||
for (const key in schemaObj) {
|
||||
if (key === 'additionalProperties') continue
|
||||
cleanedSchema[key] = cleanSchemaForGemini(schema[key])
|
||||
cleanedSchema[key] = cleanSchemaForGemini(schemaObj[key] as SchemaUnion)
|
||||
}
|
||||
|
||||
return cleanedSchema
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts text content from a Gemini response candidate, handling structured output
|
||||
* Extracts text content from a Gemini response candidate.
|
||||
* Filters out thought parts (model reasoning) from the output.
|
||||
*/
|
||||
export function extractTextContent(candidate: Candidate | undefined): string {
|
||||
if (!candidate?.content?.parts) return ''
|
||||
|
||||
if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) {
|
||||
const text = candidate.content.parts[0].text
|
||||
if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) {
|
||||
try {
|
||||
JSON.parse(text)
|
||||
return text
|
||||
} catch (_e) {}
|
||||
}
|
||||
}
|
||||
const textParts = candidate.content.parts.filter(
|
||||
(part): part is Part & { text: string } => Boolean(part.text) && part.thought !== true
|
||||
)
|
||||
|
||||
return candidate.content.parts
|
||||
.filter((part: any) => part.text)
|
||||
.map((part: any) => part.text)
|
||||
.join('\n')
|
||||
if (textParts.length === 0) return ''
|
||||
if (textParts.length === 1) return textParts[0].text
|
||||
|
||||
return textParts.map((part) => part.text).join('\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts a function call from a Gemini response candidate
|
||||
* Extracts the first function call from a Gemini response candidate
|
||||
*/
|
||||
export function extractFunctionCall(
|
||||
candidate: Candidate | undefined
|
||||
): { name: string; args: any } | null {
|
||||
export function extractFunctionCall(candidate: Candidate | undefined): ParsedFunctionCall | null {
|
||||
if (!candidate?.content?.parts) return null
|
||||
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.functionCall) {
|
||||
return {
|
||||
name: part.functionCall.name ?? '',
|
||||
args: part.functionCall.args ?? {},
|
||||
args: (part.functionCall.args ?? {}) as Record<string, unknown>,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,16 +91,55 @@ export function extractFunctionCall(
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the full Part containing the function call (preserves thoughtSignature)
|
||||
*/
|
||||
export function extractFunctionCallPart(candidate: Candidate | undefined): Part | null {
|
||||
if (!candidate?.content?.parts) return null
|
||||
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.functionCall) {
|
||||
return part
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
 * Converts usage metadata from SDK response to our format
 */
export function convertUsageMetadata(
  usageMetadata: GenerateContentResponseUsageMetadata | undefined
): GeminiUsage {
  const promptTokenCount = usageMetadata?.promptTokenCount ?? 0
  const candidatesTokenCount = usageMetadata?.candidatesTokenCount ?? 0
  return {
    promptTokenCount,
    candidatesTokenCount,
    totalTokenCount: usageMetadata?.totalTokenCount ?? promptTokenCount + candidatesTokenCount,
  }
}

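A small usage sketch (illustrative; `response` and the `tokens` accumulator are assumed variables) showing how the SDK's usage metadata folds into the provider's token accounting:

// Illustrative sketch only.
const usage = convertUsageMetadata(response.usageMetadata)
tokens.prompt += usage.promptTokenCount
tokens.completion += usage.candidatesTokenCount
tokens.total += usage.totalTokenCount
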
/**
 * Tool definition for Gemini format
 */
export interface GeminiToolDef {
  name: string
  description: string
  parameters: Schema
}

/**
|
||||
* Converts OpenAI-style request format to Gemini format
|
||||
*/
|
||||
export function convertToGeminiFormat(request: ProviderRequest): {
|
||||
contents: any[]
|
||||
tools: any[] | undefined
|
||||
systemInstruction: any | undefined
|
||||
contents: Content[]
|
||||
tools: GeminiToolDef[] | undefined
|
||||
systemInstruction: Content | undefined
|
||||
} {
|
||||
const contents: any[] = []
|
||||
let systemInstruction
|
||||
const contents: Content[] = []
|
||||
let systemInstruction: Content | undefined
|
||||
|
||||
if (request.systemPrompt) {
|
||||
systemInstruction = { parts: [{ text: request.systemPrompt }] }
|
||||
@@ -82,13 +149,13 @@ export function convertToGeminiFormat(request: ProviderRequest): {
|
||||
contents.push({ role: 'user', parts: [{ text: request.context }] })
|
||||
}
|
||||
|
||||
if (request.messages && request.messages.length > 0) {
|
||||
if (request.messages?.length) {
|
||||
for (const message of request.messages) {
|
||||
if (message.role === 'system') {
|
||||
if (!systemInstruction) {
|
||||
systemInstruction = { parts: [{ text: message.content }] }
|
||||
} else {
|
||||
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}`
|
||||
systemInstruction = { parts: [{ text: message.content ?? '' }] }
|
||||
} else if (systemInstruction.parts?.[0] && 'text' in systemInstruction.parts[0]) {
|
||||
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text}\n${message.content}`
|
||||
}
|
||||
} else if (message.role === 'user' || message.role === 'assistant') {
|
||||
const geminiRole = message.role === 'user' ? 'user' : 'model'
|
||||
@@ -97,60 +164,200 @@ export function convertToGeminiFormat(request: ProviderRequest): {
|
||||
contents.push({ role: geminiRole, parts: [{ text: message.content }] })
|
||||
}
|
||||
|
||||
if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) {
|
||||
if (message.role === 'assistant' && message.tool_calls?.length) {
|
||||
const functionCalls = message.tool_calls.map((toolCall) => ({
|
||||
functionCall: {
|
||||
name: toolCall.function?.name,
|
||||
args: JSON.parse(toolCall.function?.arguments || '{}'),
|
||||
args: JSON.parse(toolCall.function?.arguments || '{}') as Record<string, unknown>,
|
||||
},
|
||||
}))
|
||||
|
||||
contents.push({ role: 'model', parts: functionCalls })
|
||||
}
|
||||
} else if (message.role === 'tool') {
|
||||
if (!message.name) {
|
||||
logger.warn('Tool message missing function name, skipping')
|
||||
continue
|
||||
}
|
||||
let responseData: Record<string, unknown>
|
||||
try {
|
||||
responseData = JSON.parse(message.content ?? '{}')
|
||||
} catch {
|
||||
responseData = { output: message.content }
|
||||
}
|
||||
contents.push({
|
||||
role: 'user',
|
||||
parts: [{ text: `Function result: ${message.content}` }],
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
id: message.tool_call_id,
|
||||
name: message.name,
|
||||
response: responseData,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const tools = request.tools?.map((tool) => {
|
||||
const tools = request.tools?.map((tool): GeminiToolDef => {
|
||||
const toolParameters = { ...(tool.parameters || {}) }
|
||||
|
||||
if (toolParameters.properties) {
|
||||
const properties = { ...toolParameters.properties }
|
||||
const required = toolParameters.required ? [...toolParameters.required] : []
|
||||
|
||||
// Remove default values from properties (not supported by Gemini)
|
||||
for (const key in properties) {
|
||||
const prop = properties[key] as any
|
||||
|
||||
const prop = properties[key] as Record<string, unknown>
|
||||
if (prop.default !== undefined) {
|
||||
const { default: _, ...cleanProp } = prop
|
||||
properties[key] = cleanProp
|
||||
}
|
||||
}
|
||||
|
||||
const parameters = {
|
||||
type: toolParameters.type || 'object',
|
||||
properties,
|
||||
const parameters: Schema = {
|
||||
type: (toolParameters.type as Schema['type']) || Type.OBJECT,
|
||||
properties: properties as Record<string, Schema>,
|
||||
...(required.length > 0 ? { required } : {}),
|
||||
}
|
||||
|
||||
return {
|
||||
name: tool.id,
|
||||
description: tool.description || `Execute the ${tool.id} function`,
|
||||
parameters: cleanSchemaForGemini(parameters),
|
||||
parameters: cleanSchemaForGemini(parameters) as Schema,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
name: tool.id,
|
||||
description: tool.description || `Execute the ${tool.id} function`,
|
||||
parameters: cleanSchemaForGemini(toolParameters),
|
||||
parameters: cleanSchemaForGemini(toolParameters) as Schema,
|
||||
}
|
||||
})
|
||||
|
||||
return { contents, tools, systemInstruction }
|
||||
}
|
||||
|
||||
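As a rough sketch of the conversion (hypothetical request values; the cast is only to keep the example short):

// Illustrative sketch: OpenAI-style request in, Gemini contents/tools/systemInstruction out.
const { contents, tools, systemInstruction } = convertToGeminiFormat({
  model: 'gemini-2.5-pro',
  systemPrompt: 'You are a helpful assistant.',
  messages: [{ role: 'user', content: 'What is the weather in Paris?' }],
  tools: [
    {
      id: 'get_weather',
      description: 'Look up the weather for a city',
      parameters: { type: 'object', properties: { city: { type: 'string' } } },
    },
  ],
} as unknown as ProviderRequest)
// contents -> [{ role: 'user', parts: [{ text: '...' }] }], tools -> one GeminiToolDef, systemInstruction set.
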
/**
|
||||
* Creates a ReadableStream from a Google Gemini streaming response
|
||||
*/
|
||||
export function createReadableStreamFromGeminiStream(
|
||||
stream: AsyncGenerator<GenerateContentResponse>,
|
||||
onComplete?: (content: string, usage: GeminiUsage) => void
|
||||
): ReadableStream<Uint8Array> {
|
||||
let fullContent = ''
|
||||
let usage: GeminiUsage = { promptTokenCount: 0, candidatesTokenCount: 0, totalTokenCount: 0 }
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of stream) {
|
||||
if (chunk.usageMetadata) {
|
||||
usage = convertUsageMetadata(chunk.usageMetadata)
|
||||
}
|
||||
|
||||
const text = chunk.text
|
||||
if (text) {
|
||||
fullContent += text
|
||||
controller.enqueue(new TextEncoder().encode(text))
|
||||
}
|
||||
}
|
||||
|
||||
onComplete?.(fullContent, usage)
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
logger.error('Error reading Google Gemini stream', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
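A minimal consumption sketch (assumes `stream` is the async generator returned by the SDK's streaming call):

// Illustrative sketch only.
const readable = createReadableStreamFromGeminiStream(stream, (content, usage) => {
  logger.info('Stream finished', { chars: content.length, totalTokens: usage.totalTokenCount })
})
const reader = readable.getReader()
const decoder = new TextDecoder()
while (true) {
  const { done, value } = await reader.read()
  if (done) break
  process.stdout.write(decoder.decode(value))
}
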
/**
 * Maps string mode to FunctionCallingConfigMode enum
 */
function mapToFunctionCallingMode(mode: string): FunctionCallingConfigMode {
  switch (mode) {
    case 'AUTO':
      return FunctionCallingConfigMode.AUTO
    case 'ANY':
      return FunctionCallingConfigMode.ANY
    case 'NONE':
      return FunctionCallingConfigMode.NONE
    default:
      return FunctionCallingConfigMode.AUTO
  }
}

/**
 * Maps string level to ThinkingLevel enum
 */
export function mapToThinkingLevel(level: string): ThinkingLevel {
  switch (level.toLowerCase()) {
    case 'minimal':
      return ThinkingLevel.MINIMAL
    case 'low':
      return ThinkingLevel.LOW
    case 'medium':
      return ThinkingLevel.MEDIUM
    case 'high':
      return ThinkingLevel.HIGH
    default:
      return ThinkingLevel.HIGH
  }
}

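A small sketch of how a block-level thinkingLevel might be validated and mapped before being sent to the SDK (illustrative; `request` is an assumed variable, getThinkingLevelsForModel comes from @/providers/models, and the fallback to HIGH mirrors the default above):

// Illustrative sketch only.
const requested = request.thinkingLevel ?? 'high'
const allowed = getThinkingLevelsForModel(request.model)
const level = allowed?.includes(requested) ? mapToThinkingLevel(requested) : ThinkingLevel.HIGH
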
/**
|
||||
* Result of checking forced tool usage
|
||||
*/
|
||||
export interface ForcedToolResult {
|
||||
hasUsedForcedTool: boolean
|
||||
usedForcedTools: string[]
|
||||
nextToolConfig: ToolConfig | undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks for forced tool usage in Google Gemini responses
|
||||
*/
|
||||
export function checkForForcedToolUsage(
|
||||
functionCalls: FunctionCall[] | undefined,
|
||||
toolConfig: ToolConfig | undefined,
|
||||
forcedTools: string[],
|
||||
usedForcedTools: string[]
|
||||
): ForcedToolResult | null {
|
||||
if (!functionCalls?.length) return null
|
||||
|
||||
const adaptedToolCalls = functionCalls.map((fc) => ({
|
||||
name: fc.name ?? '',
|
||||
arguments: (fc.args ?? {}) as Record<string, unknown>,
|
||||
}))
|
||||
|
||||
const result = trackForcedToolUsage(
|
||||
adaptedToolCalls,
|
||||
toolConfig,
|
||||
logger,
|
||||
'google',
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
|
||||
if (!result) return null
|
||||
|
||||
const nextToolConfig: ToolConfig | undefined = result.nextToolConfig?.functionCallingConfig?.mode
|
||||
? {
|
||||
functionCallingConfig: {
|
||||
mode: mapToFunctionCallingMode(result.nextToolConfig.functionCallingConfig.mode),
|
||||
allowedFunctionNames: result.nextToolConfig.functionCallingConfig.allowedFunctionNames,
|
||||
},
|
||||
}
|
||||
: undefined
|
||||
|
||||
return {
|
||||
hasUsedForcedTool: result.hasUsedForcedTool,
|
||||
usedForcedTools: result.usedForcedTools,
|
||||
nextToolConfig,
|
||||
}
|
||||
}
|
||||
|
||||
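Illustrative loop-body sketch of how the result could roll the tool config forward between iterations (the surrounding state variables are assumptions):

// Illustrative sketch only.
const forced = checkForForcedToolUsage(functionCalls, currentToolConfig, forcedTools, usedForcedTools)
if (forced) {
  hasUsedForcedTool = forced.hasUsedForcedTool
  usedForcedTools = forced.usedForcedTools
  currentToolConfig = forced.nextToolConfig
}
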
@@ -69,10 +69,7 @@ export const groqProvider: ProviderConfig = {
|
||||
: undefined
|
||||
|
||||
const payload: any = {
|
||||
model: (request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct').replace(
|
||||
'groq/',
|
||||
''
|
||||
),
|
||||
model: request.model.replace('groq/', ''),
|
||||
messages: allMessages,
|
||||
}
|
||||
|
||||
@@ -109,7 +106,7 @@ export const groqProvider: ProviderConfig = {
|
||||
toolChoice: payload.tool_choice,
|
||||
forcedToolsCount: forcedTools.length,
|
||||
hasFilteredTools,
|
||||
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
|
||||
model: request.model,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -149,7 +146,7 @@ export const groqProvider: ProviderConfig = {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
|
||||
model: request.model,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
@@ -393,7 +390,7 @@ export const groqProvider: ProviderConfig = {
|
||||
const streamingPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto',
|
||||
tool_choice: originalToolChoice || 'auto',
|
||||
stream: true,
|
||||
}
|
||||
|
||||
@@ -425,7 +422,7 @@ export const groqProvider: ProviderConfig = {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
completion: tokens.completion,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { getCostMultiplier } from '@/lib/core/config/feature-flags'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import type { ProviderRequest, ProviderResponse } from '@/providers/types'
|
||||
import { getProviderExecutor } from '@/providers/registry'
|
||||
import type { ProviderId, ProviderRequest, ProviderResponse } from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
generateStructuredOutputInstructions,
|
||||
getProvider,
|
||||
shouldBillModelUsage,
|
||||
supportsTemperature,
|
||||
} from '@/providers/utils'
|
||||
@@ -40,7 +40,7 @@ export async function executeProviderRequest(
|
||||
providerId: string,
|
||||
request: ProviderRequest
|
||||
): Promise<ProviderResponse | ReadableStream | StreamingExecution> {
|
||||
const provider = getProvider(providerId)
|
||||
const provider = await getProviderExecutor(providerId as ProviderId)
|
||||
if (!provider) {
|
||||
throw new Error(`Provider not found: ${providerId}`)
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ export const mistralProvider: ProviderConfig = {
|
||||
request: ProviderRequest
|
||||
): Promise<ProviderResponse | StreamingExecution> => {
|
||||
logger.info('Preparing Mistral request', {
|
||||
model: request.model || 'mistral-large-latest',
|
||||
model: request.model,
|
||||
hasSystemPrompt: !!request.systemPrompt,
|
||||
hasMessages: !!request.messages?.length,
|
||||
hasTools: !!request.tools?.length,
|
||||
@@ -86,7 +86,7 @@ export const mistralProvider: ProviderConfig = {
|
||||
: undefined
|
||||
|
||||
const payload: any = {
|
||||
model: request.model || 'mistral-large-latest',
|
||||
model: request.model,
|
||||
messages: allMessages,
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ export const mistralProvider: ProviderConfig = {
|
||||
: toolChoice.type === 'any'
|
||||
? `force:${toolChoice.any?.name || 'unknown'}`
|
||||
: 'unknown',
|
||||
model: request.model || 'mistral-large-latest',
|
||||
model: request.model,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,6 +39,10 @@ export interface ModelCapabilities {
|
||||
verbosity?: {
|
||||
values: string[]
|
||||
}
|
||||
thinking?: {
|
||||
levels: string[]
|
||||
default?: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface ModelDefinition {
|
||||
@@ -730,6 +734,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
thinking: {
|
||||
levels: ['low', 'high'],
|
||||
default: 'high',
|
||||
},
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
@@ -743,6 +751,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
thinking: {
|
||||
levels: ['minimal', 'low', 'medium', 'high'],
|
||||
default: 'high',
|
||||
},
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
@@ -832,6 +844,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
thinking: {
|
||||
levels: ['low', 'high'],
|
||||
default: 'high',
|
||||
},
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
@@ -845,6 +861,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 2 },
|
||||
thinking: {
|
||||
levels: ['minimal', 'low', 'medium', 'high'],
|
||||
default: 'high',
|
||||
},
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
},
|
||||
@@ -1864,3 +1884,49 @@ export function supportsNativeStructuredOutputs(modelId: string): boolean {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
 * Check if a model supports thinking/reasoning features.
 * Returns the thinking capability config if supported, null otherwise.
 */
export function getThinkingCapability(
  modelId: string
): { levels: string[]; default?: string } | null {
  const normalizedModelId = modelId.toLowerCase()

  for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
    for (const model of provider.models) {
      if (model.capabilities.thinking) {
        const baseModelId = model.id.toLowerCase()
        if (normalizedModelId === baseModelId || normalizedModelId.startsWith(`${baseModelId}-`)) {
          return model.capabilities.thinking
        }
      }
    }
  }
  return null
}

/**
 * Get all models that support thinking capability
 */
export function getModelsWithThinking(): string[] {
  const models: string[] = []
  for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
    for (const model of provider.models) {
      if (model.capabilities.thinking) {
        models.push(model.id)
      }
    }
  }
  return models
}

/**
 * Get the thinking levels for a specific model
 * Returns the valid levels for that model, or null if the model doesn't support thinking
 */
export function getThinkingLevelsForModel(modelId: string): string[] | null {
  const capability = getThinkingCapability(modelId)
  return capability?.levels ?? null
}

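For example (model IDs here are illustrative; actual support depends on the thinking entries in PROVIDER_DEFINITIONS above):

// Illustrative sketch only.
getThinkingLevelsForModel('some-thinking-model')   // e.g. ['minimal', 'low', 'medium', 'high'] when declared
getThinkingLevelsForModel('gpt-4o')                // null - no thinking capability declared
getModelsWithThinking()                            // every model id that declares a thinking capability
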
@@ -33,7 +33,7 @@ export const openaiProvider: ProviderConfig = {
|
||||
request: ProviderRequest
|
||||
): Promise<ProviderResponse | StreamingExecution> => {
|
||||
logger.info('Preparing OpenAI request', {
|
||||
model: request.model || 'gpt-4o',
|
||||
model: request.model,
|
||||
hasSystemPrompt: !!request.systemPrompt,
|
||||
hasMessages: !!request.messages?.length,
|
||||
hasTools: !!request.tools?.length,
|
||||
@@ -76,7 +76,7 @@ export const openaiProvider: ProviderConfig = {
|
||||
: undefined
|
||||
|
||||
const payload: any = {
|
||||
model: request.model || 'gpt-4o',
|
||||
model: request.model,
|
||||
messages: allMessages,
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ export const openaiProvider: ProviderConfig = {
|
||||
: toolChoice.type === 'any'
|
||||
? `force:${toolChoice.any?.name || 'unknown'}`
|
||||
: 'unknown',
|
||||
model: request.model || 'gpt-4o',
|
||||
model: request.model,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ export const openRouterProvider: ProviderConfig = {
|
||||
baseURL: 'https://openrouter.ai/api/v1',
|
||||
})
|
||||
|
||||
const requestedModel = (request.model || '').replace(/^openrouter\//, '')
|
||||
const requestedModel = request.model.replace(/^openrouter\//, '')
|
||||
|
||||
logger.info('Preparing OpenRouter request', {
|
||||
model: requestedModel,
|
||||
|
||||
59
apps/sim/providers/registry.ts
Normal file
@@ -0,0 +1,59 @@
import { createLogger } from '@/lib/logs/console/logger'
import { anthropicProvider } from '@/providers/anthropic'
import { azureOpenAIProvider } from '@/providers/azure-openai'
import { cerebrasProvider } from '@/providers/cerebras'
import { deepseekProvider } from '@/providers/deepseek'
import { googleProvider } from '@/providers/google'
import { groqProvider } from '@/providers/groq'
import { mistralProvider } from '@/providers/mistral'
import { ollamaProvider } from '@/providers/ollama'
import { openaiProvider } from '@/providers/openai'
import { openRouterProvider } from '@/providers/openrouter'
import type { ProviderConfig, ProviderId } from '@/providers/types'
import { vertexProvider } from '@/providers/vertex'
import { vllmProvider } from '@/providers/vllm'
import { xAIProvider } from '@/providers/xai'

const logger = createLogger('ProviderRegistry')

const providerRegistry: Record<ProviderId, ProviderConfig> = {
  openai: openaiProvider,
  anthropic: anthropicProvider,
  google: googleProvider,
  vertex: vertexProvider,
  deepseek: deepseekProvider,
  xai: xAIProvider,
  cerebras: cerebrasProvider,
  groq: groqProvider,
  vllm: vllmProvider,
  mistral: mistralProvider,
  'azure-openai': azureOpenAIProvider,
  openrouter: openRouterProvider,
  ollama: ollamaProvider,
}

export async function getProviderExecutor(
  providerId: ProviderId
): Promise<ProviderConfig | undefined> {
  const provider = providerRegistry[providerId]
  if (!provider) {
    logger.error(`Provider not found: ${providerId}`)
    return undefined
  }
  return provider
}

export async function initializeProviders(): Promise<void> {
  for (const [id, provider] of Object.entries(providerRegistry)) {
    if (provider.initialize) {
      try {
        await provider.initialize()
        logger.info(`Initialized provider: ${id}`)
      } catch (error) {
        logger.error(`Failed to initialize ${id} provider`, {
          error: error instanceof Error ? error.message : 'Unknown error',
        })
      }
    }
  }
}

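A server-side usage sketch (illustrative; the executeRequest call mentioned in the comment is assumed from the ProviderConfig interface and is normally reached through executeProviderRequest):

// Illustrative sketch only.
await initializeProviders()
const provider = await getProviderExecutor('vertex')
if (!provider) {
  throw new Error('Provider not found: vertex')
}
// The executor's executeRequest(request) is then invoked by executeProviderRequest.
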
@@ -164,6 +164,7 @@ export interface ProviderRequest {
|
||||
vertexLocation?: string
|
||||
reasoningEffort?: string
|
||||
verbosity?: string
|
||||
thinkingLevel?: string
|
||||
}
|
||||
|
||||
export const providers: Record<string, ProviderConfig> = {}
|
||||
|
||||
@@ -3,13 +3,6 @@ import type { CompletionUsage } from 'openai/resources/completions'
|
||||
import { getEnv, isTruthy } from '@/lib/core/config/env'
|
||||
import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import { createLogger, type Logger } from '@/lib/logs/console/logger'
|
||||
import { anthropicProvider } from '@/providers/anthropic'
|
||||
import { azureOpenAIProvider } from '@/providers/azure-openai'
|
||||
import { cerebrasProvider } from '@/providers/cerebras'
|
||||
import { deepseekProvider } from '@/providers/deepseek'
|
||||
import { googleProvider } from '@/providers/google'
|
||||
import { groqProvider } from '@/providers/groq'
|
||||
import { mistralProvider } from '@/providers/mistral'
|
||||
import {
|
||||
getComputerUseModels,
|
||||
getEmbeddingModelPricing,
|
||||
@@ -20,117 +13,82 @@ import {
|
||||
getModelsWithTemperatureSupport,
|
||||
getModelsWithTempRange01,
|
||||
getModelsWithTempRange02,
|
||||
getModelsWithThinking,
|
||||
getModelsWithVerbosity,
|
||||
getProviderDefaultModel as getProviderDefaultModelFromDefinitions,
|
||||
getProviderModels as getProviderModelsFromDefinitions,
|
||||
getProvidersWithToolUsageControl,
|
||||
getReasoningEffortValuesForModel as getReasoningEffortValuesForModelFromDefinitions,
|
||||
getThinkingLevelsForModel as getThinkingLevelsForModelFromDefinitions,
|
||||
getVerbosityValuesForModel as getVerbosityValuesForModelFromDefinitions,
|
||||
PROVIDER_DEFINITIONS,
|
||||
supportsTemperature as supportsTemperatureFromDefinitions,
|
||||
supportsToolUsageControl as supportsToolUsageControlFromDefinitions,
|
||||
updateOllamaModels as updateOllamaModelsInDefinitions,
|
||||
} from '@/providers/models'
|
||||
import { ollamaProvider } from '@/providers/ollama'
|
||||
import { openaiProvider } from '@/providers/openai'
|
||||
import { openRouterProvider } from '@/providers/openrouter'
|
||||
import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types'
|
||||
import { vertexProvider } from '@/providers/vertex'
|
||||
import { vllmProvider } from '@/providers/vllm'
|
||||
import { xAIProvider } from '@/providers/xai'
|
||||
import type { ProviderId, ProviderToolConfig } from '@/providers/types'
|
||||
import { useCustomToolsStore } from '@/stores/custom-tools/store'
|
||||
import { useProvidersStore } from '@/stores/providers/store'
|
||||
|
||||
const logger = createLogger('ProviderUtils')
|
||||
|
||||
export const providers: Record<
|
||||
ProviderId,
|
||||
ProviderConfig & {
|
||||
models: string[]
|
||||
computerUseModels?: string[]
|
||||
modelPatterns?: RegExp[]
|
||||
/**
|
||||
* Client-safe provider metadata.
|
||||
* This object contains only model lists and patterns - no executeRequest implementations.
|
||||
* For server-side execution, use @/providers/registry.
|
||||
*/
|
||||
export interface ProviderMetadata {
|
||||
id: string
|
||||
name: string
|
||||
description: string
|
||||
version: string
|
||||
models: string[]
|
||||
defaultModel: string
|
||||
computerUseModels?: string[]
|
||||
modelPatterns?: RegExp[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Build provider metadata from PROVIDER_DEFINITIONS.
|
||||
* This is client-safe as it doesn't import any provider implementations.
|
||||
*/
|
||||
function buildProviderMetadata(providerId: ProviderId): ProviderMetadata {
|
||||
const def = PROVIDER_DEFINITIONS[providerId]
|
||||
return {
|
||||
id: providerId,
|
||||
name: def?.name || providerId,
|
||||
description: def?.description || '',
|
||||
version: '1.0.0',
|
||||
models: getProviderModelsFromDefinitions(providerId),
|
||||
defaultModel: getProviderDefaultModelFromDefinitions(providerId),
|
||||
modelPatterns: def?.modelPatterns,
|
||||
}
|
||||
> = {
|
||||
}
|
||||
|
||||
export const providers: Record<ProviderId, ProviderMetadata> = {
|
||||
openai: {
|
||||
...openaiProvider,
|
||||
models: getProviderModelsFromDefinitions('openai'),
|
||||
...buildProviderMetadata('openai'),
|
||||
computerUseModels: ['computer-use-preview'],
|
||||
modelPatterns: PROVIDER_DEFINITIONS.openai.modelPatterns,
|
||||
},
|
||||
anthropic: {
|
||||
...anthropicProvider,
|
||||
models: getProviderModelsFromDefinitions('anthropic'),
|
||||
...buildProviderMetadata('anthropic'),
|
||||
computerUseModels: getComputerUseModels().filter((model) =>
|
||||
getProviderModelsFromDefinitions('anthropic').includes(model)
|
||||
),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.anthropic.modelPatterns,
|
||||
},
|
||||
google: {
|
||||
...googleProvider,
|
||||
models: getProviderModelsFromDefinitions('google'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.google.modelPatterns,
|
||||
},
|
||||
vertex: {
|
||||
...vertexProvider,
|
||||
models: getProviderModelsFromDefinitions('vertex'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.vertex.modelPatterns,
|
||||
},
|
||||
deepseek: {
|
||||
...deepseekProvider,
|
||||
models: getProviderModelsFromDefinitions('deepseek'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.deepseek.modelPatterns,
|
||||
},
|
||||
xai: {
|
||||
...xAIProvider,
|
||||
models: getProviderModelsFromDefinitions('xai'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.xai.modelPatterns,
|
||||
},
|
||||
cerebras: {
|
||||
...cerebrasProvider,
|
||||
models: getProviderModelsFromDefinitions('cerebras'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.cerebras.modelPatterns,
|
||||
},
|
||||
groq: {
|
||||
...groqProvider,
|
||||
models: getProviderModelsFromDefinitions('groq'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.groq.modelPatterns,
|
||||
},
|
||||
vllm: {
|
||||
...vllmProvider,
|
||||
models: getProviderModelsFromDefinitions('vllm'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.vllm.modelPatterns,
|
||||
},
|
||||
mistral: {
|
||||
...mistralProvider,
|
||||
models: getProviderModelsFromDefinitions('mistral'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.mistral.modelPatterns,
|
||||
},
|
||||
'azure-openai': {
|
||||
...azureOpenAIProvider,
|
||||
models: getProviderModelsFromDefinitions('azure-openai'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS['azure-openai'].modelPatterns,
|
||||
},
|
||||
openrouter: {
|
||||
...openRouterProvider,
|
||||
models: getProviderModelsFromDefinitions('openrouter'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.openrouter.modelPatterns,
|
||||
},
|
||||
ollama: {
|
||||
...ollamaProvider,
|
||||
models: getProviderModelsFromDefinitions('ollama'),
|
||||
modelPatterns: PROVIDER_DEFINITIONS.ollama.modelPatterns,
|
||||
},
|
||||
google: buildProviderMetadata('google'),
|
||||
vertex: buildProviderMetadata('vertex'),
|
||||
deepseek: buildProviderMetadata('deepseek'),
|
||||
xai: buildProviderMetadata('xai'),
|
||||
cerebras: buildProviderMetadata('cerebras'),
|
||||
groq: buildProviderMetadata('groq'),
|
||||
vllm: buildProviderMetadata('vllm'),
|
||||
mistral: buildProviderMetadata('mistral'),
|
||||
'azure-openai': buildProviderMetadata('azure-openai'),
|
||||
openrouter: buildProviderMetadata('openrouter'),
|
||||
ollama: buildProviderMetadata('ollama'),
|
||||
}
|
||||
|
||||
Object.entries(providers).forEach(([id, provider]) => {
|
||||
if (provider.initialize) {
|
||||
provider.initialize().catch((error) => {
|
||||
logger.error(`Failed to initialize ${id} provider`, {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
export function updateOllamaProviderModels(models: string[]): void {
|
||||
updateOllamaModelsInDefinitions(models)
|
||||
providers.ollama.models = getProviderModelsFromDefinitions('ollama')
|
||||
@@ -211,12 +169,12 @@ export function getProviderFromModel(model: string): ProviderId {
|
||||
return 'ollama'
|
||||
}
|
||||
|
||||
export function getProvider(id: string): ProviderConfig | undefined {
|
||||
export function getProvider(id: string): ProviderMetadata | undefined {
|
||||
const providerId = id.split('/')[0] as ProviderId
|
||||
return providers[providerId]
|
||||
}
|
||||
|
||||
export function getProviderConfigFromModel(model: string): ProviderConfig | undefined {
|
||||
export function getProviderConfigFromModel(model: string): ProviderMetadata | undefined {
|
||||
const providerId = getProviderFromModel(model)
|
||||
return providers[providerId]
|
||||
}
|
||||
@@ -929,6 +887,7 @@ export const MODELS_TEMP_RANGE_0_1 = getModelsWithTempRange01()
|
||||
export const MODELS_WITH_TEMPERATURE_SUPPORT = getModelsWithTemperatureSupport()
|
||||
export const MODELS_WITH_REASONING_EFFORT = getModelsWithReasoningEffort()
|
||||
export const MODELS_WITH_VERBOSITY = getModelsWithVerbosity()
|
||||
export const MODELS_WITH_THINKING = getModelsWithThinking()
|
||||
export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = getProvidersWithToolUsageControl()
|
||||
|
||||
export function supportsTemperature(model: string): boolean {
|
||||
@@ -963,6 +922,14 @@ export function getVerbosityValuesForModel(model: string): string[] | null {
|
||||
return getVerbosityValuesForModelFromDefinitions(model)
|
||||
}
|
||||
|
||||
/**
 * Get thinking levels for a specific model
 * Returns the valid levels for that model, or null if the model doesn't support thinking
 */
export function getThinkingLevelsForModel(model: string): string[] | null {
  return getThinkingLevelsForModelFromDefinitions(model)
}

/**
|
||||
* Prepare tool execution parameters, separating tool parameters from system parameters
|
||||
*/
|
||||
|
||||
@@ -1,33 +1,23 @@
|
||||
import { GoogleGenAI } from '@google/genai'
|
||||
import { OAuth2Client } from 'google-auth-library'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
import {
|
||||
cleanSchemaForGemini,
|
||||
convertToGeminiFormat,
|
||||
extractFunctionCall,
|
||||
extractTextContent,
|
||||
} from '@/providers/google/utils'
|
||||
import { executeGeminiRequest } from '@/providers/gemini/core'
|
||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
ProviderRequest,
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
} from '@/providers/utils'
|
||||
import { buildVertexEndpoint, createReadableStreamFromVertexStream } from '@/providers/vertex/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types'
|
||||
|
||||
const logger = createLogger('VertexProvider')
|
||||
|
||||
/**
|
||||
* Vertex AI provider configuration
|
||||
* Vertex AI provider
|
||||
*
|
||||
* Uses the @google/genai SDK with Vertex AI backend and OAuth authentication.
|
||||
* Shares core execution logic with Google Gemini provider.
|
||||
*
|
||||
* Authentication:
|
||||
* - Uses OAuth access token passed via googleAuthOptions.authClient
|
||||
* - Token refresh is handled at the OAuth layer before calling this provider
|
||||
*/
|
||||
export const vertexProvider: ProviderConfig = {
|
||||
id: 'vertex',
|
||||
@@ -55,869 +45,35 @@ export const vertexProvider: ProviderConfig = {
|
||||
)
|
||||
}
|
||||
|
||||
logger.info('Preparing Vertex AI request', {
|
||||
model: request.model || 'vertex/gemini-2.5-pro',
|
||||
hasSystemPrompt: !!request.systemPrompt,
|
||||
hasMessages: !!request.messages?.length,
|
||||
hasTools: !!request.tools?.length,
|
||||
toolCount: request.tools?.length || 0,
|
||||
hasResponseFormat: !!request.responseFormat,
|
||||
streaming: !!request.stream,
|
||||
// Strip 'vertex/' prefix from model name if present
|
||||
const model = request.model.replace('vertex/', '')
|
||||
|
||||
logger.info('Creating Vertex AI client', {
|
||||
project: vertexProject,
|
||||
location: vertexLocation,
|
||||
model,
|
||||
})
|
||||
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
|
||||
|
||||
const requestedModel = (request.model || 'vertex/gemini-2.5-pro').replace('vertex/', '')
|
||||
|
||||
const payload: any = {
|
||||
contents,
|
||||
generationConfig: {},
|
||||
}
|
||||
|
||||
if (request.temperature !== undefined && request.temperature !== null) {
|
||||
payload.generationConfig.temperature = request.temperature
|
||||
}
|
||||
|
||||
if (request.maxTokens !== undefined) {
|
||||
payload.generationConfig.maxOutputTokens = request.maxTokens
|
||||
}
|
||||
|
||||
if (systemInstruction) {
|
||||
payload.systemInstruction = systemInstruction
|
||||
}
|
||||
|
||||
if (request.responseFormat && !tools?.length) {
|
||||
const responseFormatSchema = request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
payload.generationConfig.responseMimeType = 'application/json'
|
||||
payload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info('Using Vertex AI native structured output format', {
|
||||
hasSchema: !!cleanSchema,
|
||||
mimeType: 'application/json',
|
||||
})
|
||||
} else if (request.responseFormat && tools?.length) {
|
||||
logger.warn(
|
||||
'Vertex AI does not support structured output (responseFormat) with function calling (tools). Structured output will be ignored.'
|
||||
)
|
||||
}
|
||||
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
|
||||
if (tools?.length) {
|
||||
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'google')
|
||||
const { tools: filteredTools, toolConfig } = preparedTools
|
||||
|
||||
if (filteredTools?.length) {
|
||||
payload.tools = [
|
||||
{
|
||||
functionDeclarations: filteredTools,
|
||||
},
|
||||
]
|
||||
|
||||
if (toolConfig) {
|
||||
payload.toolConfig = toolConfig
|
||||
}
|
||||
|
||||
logger.info('Vertex AI request with tools:', {
|
||||
toolCount: filteredTools.length,
|
||||
model: requestedModel,
|
||||
tools: filteredTools.map((t) => t.name),
|
||||
hasToolConfig: !!toolConfig,
|
||||
toolConfig: toolConfig,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const initialCallTime = Date.now()
|
||||
const shouldStream = !!(request.stream && !tools?.length)
|
||||
|
||||
const endpoint = buildVertexEndpoint(
|
||||
vertexProject,
|
||||
vertexLocation,
|
||||
requestedModel,
|
||||
shouldStream
|
||||
)
|
||||
|
||||
if (request.stream && tools?.length) {
|
||||
logger.info('Streaming disabled for initial request due to tools presence', {
|
||||
toolCount: tools.length,
|
||||
willStreamAfterTools: true,
|
||||
})
|
||||
}
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${request.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const responseText = await response.text()
|
||||
logger.error('Vertex AI API error details:', {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
responseBody: responseText,
|
||||
})
|
||||
throw new Error(`Vertex AI API error: ${response.status} ${response.statusText}`)
|
||||
}
|
||||
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
if (shouldStream) {
|
||||
logger.info('Handling Vertex AI streaming response')
|
||||
|
||||
const streamingResult: StreamingExecution = {
|
||||
stream: null as any,
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
},
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: firstResponseTime,
|
||||
modelTime: firstResponseTime,
|
||||
toolsTime: 0,
|
||||
firstResponseTime,
|
||||
iterations: 1,
|
||||
timeSegments: [
|
||||
{
|
||||
type: 'model',
|
||||
name: 'Initial streaming response',
|
||||
startTime: initialCallTime,
|
||||
endTime: initialCallTime + firstResponseTime,
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
isStreaming: true,
|
||||
},
|
||||
}
|
||||
|
||||
streamingResult.stream = createReadableStreamFromVertexStream(
|
||||
response,
|
||||
(content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
if (streamingResult.execution.output.providerTiming) {
|
||||
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
|
||||
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
|
||||
streamEndTime
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
|
||||
const promptTokens = usage?.promptTokenCount || 0
|
||||
const completionTokens = usage?.candidatesTokenCount || 0
|
||||
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
|
||||
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: promptTokens,
|
||||
completion: completionTokens,
|
||||
total: totalTokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(request.model, promptTokens, completionTokens)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return streamingResult
|
||||
}
|
||||
|
||||
let geminiResponse = await response.json()
|
||||
|
||||
if (payload.generationConfig?.responseSchema) {
|
||||
const candidate = geminiResponse.candidates?.[0]
|
||||
if (candidate?.content?.parts?.[0]?.text) {
|
||||
const text = candidate.content.parts[0].text
|
||||
try {
|
||||
JSON.parse(text)
|
||||
logger.info('Successfully received structured JSON output')
|
||||
} catch (_e) {
|
||||
logger.warn('Failed to parse structured output as JSON')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let content = ''
|
||||
let tokens = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
let iterationCount = 0
|
||||
|
||||
const originalToolConfig = preparedTools?.toolConfig
|
||||
const forcedTools = preparedTools?.forcedTools || []
|
||||
let usedForcedTools: string[] = []
|
||||
let hasUsedForcedTool = false
|
||||
let currentToolConfig = originalToolConfig
|
||||
|
||||
const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => {
|
||||
if (currentToolConfig && forcedTools.length > 0) {
|
||||
const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }]
|
||||
const result = trackForcedToolUsage(
|
||||
toolCallsForTracking,
|
||||
currentToolConfig,
|
||||
logger,
|
||||
'google',
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = result.hasUsedForcedTool
|
||||
usedForcedTools = result.usedForcedTools
|
||||
|
||||
if (result.nextToolConfig) {
|
||||
currentToolConfig = result.nextToolConfig
|
||||
logger.info('Updated tool config for next iteration', {
|
||||
hasNextToolConfig: !!currentToolConfig,
|
||||
usedForcedTools: usedForcedTools,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
name: 'Initial response',
|
||||
startTime: initialCallTime,
|
||||
endTime: initialCallTime + firstResponseTime,
|
||||
duration: firstResponseTime,
|
||||
},
|
||||
]
|
||||
|
||||
try {
|
||||
const candidate = geminiResponse.candidates?.[0]
|
||||
|
||||
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
|
||||
logger.warn(
|
||||
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
|
||||
{
|
||||
finishReason: candidate.finishReason,
|
||||
hasContent: !!candidate?.content,
|
||||
hasParts: !!candidate?.content?.parts,
|
||||
}
|
||||
)
|
||||
content = extractTextContent(candidate)
|
||||
}
|
||||
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
|
||||
if (functionCall) {
|
||||
logger.info(`Received function call from Vertex AI: ${functionCall.name}`)
|
||||
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
const latestResponse = geminiResponse.candidates?.[0]
|
||||
const latestFunctionCall = extractFunctionCall(latestResponse)
|
||||
|
||||
if (!latestFunctionCall) {
|
||||
content = extractTextContent(latestResponse)
|
||||
break
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
try {
|
||||
const toolName = latestFunctionCall.name
|
||||
const toolArgs = latestFunctionCall.args || {}
|
||||
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) {
|
||||
logger.warn(`Tool ${toolName} not found in registry, skipping`)
|
||||
break
|
||||
}
|
||||
|
||||
const toolCallStartTime = Date.now()
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
const simplifiedMessages = [
|
||||
...(contents.filter((m) => m.role === 'user').length > 0
|
||||
? [contents.filter((m) => m.role === 'user')[0]]
|
||||
: [contents[0]]),
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionCall: {
|
||||
name: latestFunctionCall.name,
|
||||
args: latestFunctionCall.args,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: `Function ${latestFunctionCall.name} result: ${JSON.stringify(resultContent)}`,
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
checkForForcedToolUsage(latestFunctionCall)
|
||||
|
||||
const nextModelStartTime = Date.now()
|
||||
|
||||
try {
|
||||
if (request.stream) {
|
||||
const streamingPayload = {
|
||||
...payload,
|
||||
contents: simplifiedMessages,
|
||||
}
|
||||
|
||||
const allForcedToolsUsed =
|
||||
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
|
||||
|
||||
if (allForcedToolsUsed && request.responseFormat) {
|
||||
streamingPayload.tools = undefined
|
||||
streamingPayload.toolConfig = undefined
|
||||
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
if (!streamingPayload.generationConfig) {
|
||||
streamingPayload.generationConfig = {}
|
||||
}
|
||||
streamingPayload.generationConfig.responseMimeType = 'application/json'
|
||||
streamingPayload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info('Using structured output for final response after tool execution')
|
||||
} else {
|
||||
if (currentToolConfig) {
|
||||
streamingPayload.toolConfig = currentToolConfig
|
||||
} else {
|
||||
streamingPayload.toolConfig = { functionCallingConfig: { mode: 'AUTO' } }
|
||||
}
|
||||
}
|
||||
|
||||
const checkPayload = {
|
||||
...streamingPayload,
|
||||
}
|
||||
checkPayload.stream = undefined
|
||||
|
||||
const checkEndpoint = buildVertexEndpoint(
|
||||
vertexProject,
|
||||
vertexLocation,
|
||||
requestedModel,
|
||||
false
|
||||
)
|
||||
|
||||
const checkResponse = await fetch(checkEndpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${request.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(checkPayload),
|
||||
})
|
||||
|
||||
if (!checkResponse.ok) {
|
||||
const errorBody = await checkResponse.text()
|
||||
logger.error('Error in Vertex AI check request:', {
|
||||
status: checkResponse.status,
|
||||
statusText: checkResponse.statusText,
|
||||
responseBody: errorBody,
|
||||
})
|
||||
throw new Error(
|
||||
`Vertex AI API check error: ${checkResponse.status} ${checkResponse.statusText}`
|
||||
)
|
||||
}
|
||||
|
||||
const checkResult = await checkResponse.json()
|
||||
const checkCandidate = checkResult.candidates?.[0]
|
||||
const checkFunctionCall = extractFunctionCall(checkCandidate)
|
||||
|
||||
if (checkFunctionCall) {
|
||||
logger.info(
|
||||
'Function call detected in follow-up, handling in non-streaming mode',
|
||||
{
|
||||
functionName: checkFunctionCall.name,
|
||||
}
|
||||
)
|
||||
|
||||
geminiResponse = checkResult
|
||||
|
||||
if (checkResult.usageMetadata) {
|
||||
tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0
|
||||
tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0
|
||||
tokens.total +=
|
||||
(checkResult.usageMetadata.promptTokenCount || 0) +
|
||||
(checkResult.usageMetadata.candidatesTokenCount || 0)
|
||||
}
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
modelTime += thisModelTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
startTime: nextModelStartTime,
|
||||
endTime: nextModelEndTime,
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
iterationCount++
|
||||
continue
|
||||
}
|
||||
|
||||
logger.info('No function call detected, proceeding with streaming response')
|
||||
|
||||
if (request.responseFormat) {
|
||||
streamingPayload.tools = undefined
|
||||
streamingPayload.toolConfig = undefined
|
||||
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
if (!streamingPayload.generationConfig) {
|
||||
streamingPayload.generationConfig = {}
|
||||
}
|
||||
streamingPayload.generationConfig.responseMimeType = 'application/json'
|
||||
streamingPayload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info(
|
||||
'Using structured output for final streaming response after tool execution'
|
||||
)
|
||||
}
|
||||
|
||||
const streamEndpoint = buildVertexEndpoint(
|
||||
vertexProject,
|
||||
vertexLocation,
|
||||
requestedModel,
|
||||
true
|
||||
)
|
||||
|
||||
const streamingResponse = await fetch(streamEndpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${request.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(streamingPayload),
|
||||
})
|
||||
|
||||
if (!streamingResponse.ok) {
|
||||
const errorBody = await streamingResponse.text()
|
||||
logger.error('Error in Vertex AI streaming follow-up request:', {
|
||||
status: streamingResponse.status,
|
||||
statusText: streamingResponse.statusText,
|
||||
responseBody: errorBody,
|
||||
})
|
||||
throw new Error(
|
||||
`Vertex AI API streaming error: ${streamingResponse.status} ${streamingResponse.statusText}`
|
||||
)
|
||||
}
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
modelTime += thisModelTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: 'Final streaming response after tool calls',
|
||||
startTime: nextModelStartTime,
|
||||
endTime: nextModelEndTime,
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
const streamingExecution: StreamingExecution = {
|
||||
stream: null as any,
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens,
|
||||
toolCalls:
|
||||
toolCalls.length > 0
|
||||
? {
|
||||
list: toolCalls,
|
||||
count: toolCalls.length,
|
||||
}
|
||||
: undefined,
|
||||
toolResults,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
modelTime,
|
||||
toolsTime,
|
||||
firstResponseTime,
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments,
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
duration: Date.now() - providerStartTime,
|
||||
},
|
||||
isStreaming: true,
|
||||
},
|
||||
}
|
||||
|
||||
streamingExecution.stream = createReadableStreamFromVertexStream(
|
||||
streamingResponse,
|
||||
(content, usage) => {
|
||||
streamingExecution.execution.output.content = content
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
if (streamingExecution.execution.output.providerTiming) {
|
||||
streamingExecution.execution.output.providerTiming.endTime =
|
||||
streamEndTimeISO
|
||||
streamingExecution.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
|
||||
const promptTokens = usage?.promptTokenCount || 0
|
||||
const completionTokens = usage?.candidatesTokenCount || 0
|
||||
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
|
||||
|
||||
const existingTokens = streamingExecution.execution.output.tokens || {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
|
||||
const existingPrompt = existingTokens.prompt || 0
|
||||
const existingCompletion = existingTokens.completion || 0
|
||||
const existingTotal = existingTokens.total || 0
|
||||
|
||||
streamingExecution.execution.output.tokens = {
|
||||
prompt: existingPrompt + promptTokens,
|
||||
completion: existingCompletion + completionTokens,
|
||||
total: existingTotal + totalTokens,
|
||||
}
|
||||
|
||||
const accumulatedCost = calculateCost(
|
||||
request.model,
|
||||
existingPrompt,
|
||||
existingCompletion
|
||||
)
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
promptTokens,
|
||||
completionTokens
|
||||
)
|
||||
streamingExecution.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return streamingExecution
|
||||
}
|
||||
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
contents: simplifiedMessages,
|
||||
}
|
||||
|
||||
const allForcedToolsUsed =
|
||||
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
|
||||
|
||||
if (allForcedToolsUsed && request.responseFormat) {
|
||||
nextPayload.tools = undefined
|
||||
nextPayload.toolConfig = undefined
|
||||
|
||||
const responseFormatSchema =
|
||||
request.responseFormat.schema || request.responseFormat
|
||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
||||
|
||||
if (!nextPayload.generationConfig) {
|
||||
nextPayload.generationConfig = {}
|
||||
}
|
||||
nextPayload.generationConfig.responseMimeType = 'application/json'
|
||||
nextPayload.generationConfig.responseSchema = cleanSchema
|
||||
|
||||
logger.info(
|
||||
'Using structured output for final non-streaming response after tool execution'
|
||||
)
|
||||
} else {
|
||||
if (currentToolConfig) {
|
||||
nextPayload.toolConfig = currentToolConfig
|
||||
}
|
||||
}
|
||||
|
||||
const nextEndpoint = buildVertexEndpoint(
|
||||
vertexProject,
|
||||
vertexLocation,
|
||||
requestedModel,
|
||||
false
|
||||
)
|
||||
|
||||
const nextResponse = await fetch(nextEndpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${request.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(nextPayload),
|
||||
})
|
||||
|
||||
if (!nextResponse.ok) {
|
||||
const errorBody = await nextResponse.text()
|
||||
logger.error('Error in Vertex AI follow-up request:', {
|
||||
status: nextResponse.status,
|
||||
statusText: nextResponse.statusText,
|
||||
responseBody: errorBody,
|
||||
iterationCount,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
geminiResponse = await nextResponse.json()
|
||||
|
||||
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime

timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})

modelTime += thisModelTime

const nextCandidate = geminiResponse.candidates?.[0]
const nextFunctionCall = extractFunctionCall(nextCandidate)

if (!nextFunctionCall) {
if (request.responseFormat) {
const finalPayload = {
...payload,
contents: nextPayload.contents,
tools: undefined,
toolConfig: undefined,
}

const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)

if (!finalPayload.generationConfig) {
finalPayload.generationConfig = {}
}
finalPayload.generationConfig.responseMimeType = 'application/json'
finalPayload.generationConfig.responseSchema = cleanSchema

logger.info('Making final request with structured output after tool execution')

const finalEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
false
)

const finalResponse = await fetch(finalEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(finalPayload),
})

if (finalResponse.ok) {
const finalResult = await finalResponse.json()
const finalCandidate = finalResult.candidates?.[0]
content = extractTextContent(finalCandidate)

if (finalResult.usageMetadata) {
tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0
tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0
tokens.total +=
(finalResult.usageMetadata.promptTokenCount || 0) +
(finalResult.usageMetadata.candidatesTokenCount || 0)
}
} else {
logger.warn(
'Failed to get structured output, falling back to regular response'
)
content = extractTextContent(nextCandidate)
}
} else {
content = extractTextContent(nextCandidate)
}
break
}

iterationCount++
} catch (error) {
logger.error('Error in Vertex AI follow-up request:', {
error: error instanceof Error ? error.message : String(error),
iterationCount,
})
break
}
} catch (error) {
logger.error('Error processing function call:', {
error: error instanceof Error ? error.message : String(error),
functionName: latestFunctionCall?.name || 'unknown',
})
break
}
}
} else {
content = extractTextContent(candidate)
}
} catch (error) {
logger.error('Error processing Vertex AI response:', {
error: error instanceof Error ? error.message : String(error),
iterationCount,
})

if (!content && toolCalls.length > 0) {
content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.`
}
}

const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime

if (geminiResponse.usageMetadata) {
tokens = {
prompt: geminiResponse.usageMetadata.promptTokenCount || 0,
completion: geminiResponse.usageMetadata.candidatesTokenCount || 0,
total:
(geminiResponse.usageMetadata.promptTokenCount || 0) +
(geminiResponse.usageMetadata.candidatesTokenCount || 0),
}
}

return {
content,
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
}
} catch (error) {
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime

logger.error('Error in Vertex AI request:', {
error: error instanceof Error ? error.message : String(error),
duration: totalDuration,
})

const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
}

throw enhancedError
}
// Create an OAuth2Client and set the access token
// This allows us to use an OAuth access token with the SDK
const authClient = new OAuth2Client()
authClient.setCredentials({ access_token: request.apiKey })

// Create client with Vertex AI configuration
const ai = new GoogleGenAI({
vertexai: true,
project: vertexProject,
location: vertexLocation,
googleAuthOptions: {
authClient,
},
})

return executeGeminiRequest({
ai,
model,
request,
providerType: 'vertex',
})
},
}

@@ -1,231 +0,0 @@
import { createLogger } from '@/lib/logs/console/logger'
import { extractFunctionCall, extractTextContent } from '@/providers/google/utils'

const logger = createLogger('VertexUtils')

/**
* Creates a ReadableStream from Vertex AI's Gemini stream response
*/
export function createReadableStreamFromVertexStream(
response: Response,
onComplete?: (
content: string,
usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
) => void
): ReadableStream<Uint8Array> {
const reader = response.body?.getReader()
if (!reader) {
throw new Error('Failed to get reader from response body')
}

return new ReadableStream({
async start(controller) {
try {
let buffer = ''
let fullContent = ''
let usageData: {
promptTokenCount?: number
candidatesTokenCount?: number
totalTokenCount?: number
} | null = null

while (true) {
const { done, value } = await reader.read()
if (done) {
if (buffer.trim()) {
try {
const data = JSON.parse(buffer.trim())
if (data.usageMetadata) {
usageData = data.usageMetadata
}
const candidate = data.candidates?.[0]
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in final buffer, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
} catch (e) {
if (buffer.trim().startsWith('[')) {
try {
const dataArray = JSON.parse(buffer.trim())
if (Array.isArray(dataArray)) {
for (const item of dataArray) {
if (item.usageMetadata) {
usageData = item.usageMetadata
}
const candidate = item.candidates?.[0]
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in array item, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
}
}
} catch (arrayError) {}
}
}
}
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
break
}

const text = new TextDecoder().decode(value)
buffer += text

let searchIndex = 0
while (searchIndex < buffer.length) {
const openBrace = buffer.indexOf('{', searchIndex)
if (openBrace === -1) break

let braceCount = 0
let inString = false
let escaped = false
let closeBrace = -1

for (let i = openBrace; i < buffer.length; i++) {
const char = buffer[i]

if (!inString) {
if (char === '"' && !escaped) {
inString = true
} else if (char === '{') {
braceCount++
} else if (char === '}') {
braceCount--
if (braceCount === 0) {
closeBrace = i
break
}
}
} else {
if (char === '"' && !escaped) {
inString = false
}
}

escaped = char === '\\' && !escaped
}

if (closeBrace !== -1) {
const jsonStr = buffer.substring(openBrace, closeBrace + 1)

try {
const data = JSON.parse(jsonStr)

if (data.usageMetadata) {
usageData = data.usageMetadata
}

const candidate = data.candidates?.[0]

if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn(
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
{
finishReason: candidate.finishReason,
hasContent: !!candidate?.content,
hasParts: !!candidate?.content?.parts,
}
)
const textContent = extractTextContent(candidate)
if (textContent) {
fullContent += textContent
controller.enqueue(new TextEncoder().encode(textContent))
}
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}

if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in stream, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
} catch (e) {
logger.error('Error parsing JSON from stream', {
error: e instanceof Error ? e.message : String(e),
jsonPreview: jsonStr.substring(0, 200),
})
}

buffer = buffer.substring(closeBrace + 1)
searchIndex = 0
} else {
break
}
}
}
} catch (e) {
logger.error('Error reading Vertex AI stream', {
error: e instanceof Error ? e.message : String(e),
})
controller.error(e)
}
},
async cancel() {
await reader.cancel()
},
})
}
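
For context, a minimal consumption sketch (not part of the deleted file): the helper returns a standard web ReadableStream<Uint8Array>, so a caller can drain it with a reader and a TextDecoder. The `vertexResponse` name here is hypothetical and stands for a fetch Response from the streaming endpoint; the loop is assumed to run inside an async function.

// Illustrative sketch only, assuming `vertexResponse` is a streaming fetch Response.
const stream = createReadableStreamFromVertexStream(vertexResponse, (fullText, usage) => {
  logger.debug('Vertex stream finished', { length: fullText.length, usage })
})
const reader = stream.getReader()
const decoder = new TextDecoder()
let streamedText = ''
while (true) {
  const { done, value } = await reader.read()
  if (done) break
  streamedText += decoder.decode(value) // accumulate each text chunk as it arrives
}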

/**
* Build Vertex AI endpoint URL
*/
export function buildVertexEndpoint(
project: string,
location: string,
model: string,
isStreaming: boolean
): string {
const action = isStreaming ? 'streamGenerateContent' : 'generateContent'

if (location === 'global') {
return `https://aiplatform.googleapis.com/v1/projects/${project}/locations/global/publishers/google/models/${model}:${action}`
}

return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/google/models/${model}:${action}`
}
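
For reference, a quick sketch of the two URL shapes the helper above produces; the project, location, and model values are hypothetical and chosen only for illustration.

// Hypothetical arguments, shown only to illustrate the two URL shapes:
buildVertexEndpoint('my-project', 'us-central1', 'gemini-2.5-pro', false)
// -> 'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.5-pro:generateContent'
buildVertexEndpoint('my-project', 'global', 'gemini-2.5-pro', true)
// -> 'https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-2.5-pro:streamGenerateContent'
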
@@ -130,7 +130,7 @@ export const vllmProvider: ProviderConfig = {
: undefined

const payload: any = {
model: (request.model || getProviderDefaultModel('vllm')).replace(/^vllm\//, ''),
model: request.model.replace(/^vllm\//, ''),
messages: allMessages,
}
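
The hunk above drops the silent fallback to a default model; the only remaining transformation is stripping the `vllm/` prefix. A sketch of that replace, using a hypothetical model id:

// Hypothetical model id, for illustration only:
'vllm/meta-llama/Llama-3.1-8B-Instruct'.replace(/^vllm\//, '')
// -> 'meta-llama/Llama-3.1-8B-Instruct'
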
@@ -48,7 +48,7 @@ export const xAIProvider: ProviderConfig = {
hasTools: !!request.tools?.length,
toolCount: request.tools?.length || 0,
hasResponseFormat: !!request.responseFormat,
model: request.model || 'grok-3-latest',
model: request.model,
streaming: !!request.stream,
})

@@ -87,7 +87,7 @@ export const xAIProvider: ProviderConfig = {
)
}
const basePayload: any = {
model: request.model || 'grok-3-latest',
model: request.model,
messages: allMessages,
}

@@ -139,7 +139,7 @@ export const xAIProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'grok-3-latest',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
@@ -505,7 +505,7 @@ export const xAIProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'grok-3-latest',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,