mirror of https://github.com/simstudioai/sim.git
synced 2026-01-08 22:48:14 -05:00
fix(search): removed full text param from built-in search, anthropic provider streaming fix (#2542)
* fix(search): removed full text param from built-in search, anthropic provider streaming fix
* rewrite gemini provider with official sdk + add thinking capability
* vertex gemini consolidation
* never silently use different model
* pass oauth client through the googleAuthOptions param directly
* make server side provider registry
* remove comments
* take oauth selector below model selector

---------

Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
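For context on the "vertex gemini consolidation" bullet: both the Google Gemini and Vertex AI providers now go through one shared client factory, and the Vertex path authenticates with an OAuth access token (resolved from the selected Google Cloud credential) instead of an API key. Below is a minimal TypeScript sketch of that selection logic, condensed from the new apps/sim/providers/gemini/client.ts added in this commit; the GeminiClientConfig shape is inferred from how the factory uses it rather than copied from types.ts.

    import { GoogleGenAI } from '@google/genai'

    // Inferred config shape (assumption): the real interface lives in apps/sim/providers/gemini/types.ts
    interface GeminiClientConfig {
      vertexai?: boolean
      apiKey?: string
      accessToken?: string // OAuth token resolved from the user's Google Cloud credential
      project?: string
      location?: string
    }

    export function createGeminiClient(config: GeminiClientConfig): GoogleGenAI {
      if (config.vertexai) {
        // Vertex AI: the OAuth access token is passed as an Authorization header
        if (!config.project) throw new Error('Vertex AI requires a project ID')
        if (!config.accessToken) throw new Error('Vertex AI requires an access token')
        return new GoogleGenAI({
          vertexai: true,
          project: config.project,
          location: config.location ?? 'us-central1',
          httpOptions: {
            headers: { Authorization: `Bearer ${config.accessToken}` },
          },
        })
      }

      // Google Gemini API: plain API key authentication
      if (!config.apiKey) throw new Error('Google Gemini API requires an API key')
      return new GoogleGenAI({ apiKey: config.apiKey })
    }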
@@ -56,7 +56,7 @@ export async function POST(request: NextRequest) {
       query: validated.query,
       type: 'auto',
       useAutoprompt: true,
-      text: true,
+      highlights: true,
       apiKey: env.EXA_API_KEY,
     })
 
@@ -77,7 +77,7 @@ export async function POST(request: NextRequest) {
     const results = (result.output.results || []).map((r: any, index: number) => ({
       title: r.title || '',
       link: r.url || '',
-      snippet: r.text || '',
+      snippet: Array.isArray(r.highlights) ? r.highlights.join(' ... ') : '',
       date: r.publishedDate || undefined,
       position: index + 1,
     }))
@@ -43,6 +43,8 @@ const SCOPE_DESCRIPTIONS: Record<string, string> = {
   'https://www.googleapis.com/auth/admin.directory.group.readonly': 'View Google Workspace groups',
   'https://www.googleapis.com/auth/admin.directory.group.member.readonly':
     'View Google Workspace group memberships',
+  'https://www.googleapis.com/auth/cloud-platform':
+    'Full access to Google Cloud resources for Vertex AI',
   'read:confluence-content.all': 'Read all Confluence content',
   'read:confluence-space.summary': 'Read Confluence space information',
   'read:space:confluence': 'View Confluence spaces',
@@ -9,8 +9,10 @@ import {
   getMaxTemperature,
   getProviderIcon,
   getReasoningEffortValuesForModel,
+  getThinkingLevelsForModel,
   getVerbosityValuesForModel,
   MODELS_WITH_REASONING_EFFORT,
+  MODELS_WITH_THINKING,
   MODELS_WITH_VERBOSITY,
   providers,
   supportsTemperature,
@@ -108,7 +110,19 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
         })
       },
     },
+    {
+      id: 'vertexCredential',
+      title: 'Google Cloud Account',
+      type: 'oauth-input',
+      serviceId: 'vertex-ai',
+      requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'],
+      placeholder: 'Select Google Cloud account',
+      required: true,
+      condition: {
+        field: 'model',
+        value: providers.vertex.models,
+      },
+    },
     {
       id: 'reasoningEffort',
       title: 'Reasoning Effort',
@@ -215,6 +229,57 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
         value: MODELS_WITH_VERBOSITY,
       },
     },
+    {
+      id: 'thinkingLevel',
+      title: 'Thinking Level',
+      type: 'dropdown',
+      placeholder: 'Select thinking level...',
+      options: [
+        { label: 'minimal', id: 'minimal' },
+        { label: 'low', id: 'low' },
+        { label: 'medium', id: 'medium' },
+        { label: 'high', id: 'high' },
+      ],
+      dependsOn: ['model'],
+      fetchOptions: async (blockId: string) => {
+        const { useSubBlockStore } = await import('@/stores/workflows/subblock/store')
+        const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store')
+
+        const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId
+        if (!activeWorkflowId) {
+          return [
+            { label: 'low', id: 'low' },
+            { label: 'high', id: 'high' },
+          ]
+        }
+
+        const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId]
+        const blockValues = workflowValues?.[blockId]
+        const modelValue = blockValues?.model as string
+
+        if (!modelValue) {
+          return [
+            { label: 'low', id: 'low' },
+            { label: 'high', id: 'high' },
+          ]
+        }
+
+        const validOptions = getThinkingLevelsForModel(modelValue)
+        if (!validOptions) {
+          return [
+            { label: 'low', id: 'low' },
+            { label: 'high', id: 'high' },
+          ]
+        }
+
+        return validOptions.map((opt) => ({ label: opt, id: opt }))
+      },
+      value: () => 'high',
+      condition: {
+        field: 'model',
+        value: MODELS_WITH_THINKING,
+      },
+    },
 
     {
       id: 'azureEndpoint',
@@ -275,17 +340,21 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
       password: true,
       connectionDroppable: false,
       required: true,
-      // Hide API key for hosted models, Ollama models, and vLLM models
+      // Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth)
       condition: isHosted
         ? {
             field: 'model',
-            value: getHostedModels(),
+            value: [...getHostedModels(), ...providers.vertex.models],
             not: true, // Show for all models EXCEPT those listed
           }
         : () => ({
             field: 'model',
-            value: [...getCurrentOllamaModels(), ...getCurrentVLLMModels()],
-            not: true, // Show for all models EXCEPT Ollama and vLLM models
+            value: [
+              ...getCurrentOllamaModels(),
+              ...getCurrentVLLMModels(),
+              ...providers.vertex.models,
+            ],
+            not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models
           }),
     },
     {
@@ -609,6 +678,7 @@ Example 3 (Array Input):
     temperature: { type: 'number', description: 'Response randomness level' },
     reasoningEffort: { type: 'string', description: 'Reasoning effort level for GPT-5 models' },
     verbosity: { type: 'string', description: 'Verbosity level for GPT-5 models' },
+    thinkingLevel: { type: 'string', description: 'Thinking level for Gemini 3 models' },
     tools: { type: 'json', description: 'Available tools configuration' },
   },
   outputs: {
@@ -1,8 +1,9 @@
 import { db } from '@sim/db'
-import { mcpServers } from '@sim/db/schema'
+import { account, mcpServers } from '@sim/db/schema'
 import { and, eq, inArray, isNull } from 'drizzle-orm'
 import { createLogger } from '@/lib/logs/console/logger'
 import { createMcpToolId } from '@/lib/mcp/utils'
+import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
 import { getAllBlocks } from '@/blocks'
 import type { BlockOutput } from '@/blocks/types'
 import { AGENT, BlockType, DEFAULTS, HTTP } from '@/executor/constants'
@@ -919,6 +920,7 @@ export class AgentBlockHandler implements BlockHandler {
       azureApiVersion: inputs.azureApiVersion,
       vertexProject: inputs.vertexProject,
       vertexLocation: inputs.vertexLocation,
+      vertexCredential: inputs.vertexCredential,
       responseFormat,
       workflowId: ctx.workflowId,
       workspaceId: ctx.workspaceId,
@@ -997,7 +999,17 @@ export class AgentBlockHandler implements BlockHandler {
     responseFormat: any,
     providerStartTime: number
   ) {
-    const finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey)
+    let finalApiKey: string
+
+    // For Vertex AI, resolve OAuth credential to access token
+    if (providerId === 'vertex' && providerRequest.vertexCredential) {
+      finalApiKey = await this.resolveVertexCredential(
+        providerRequest.vertexCredential,
+        ctx.workflowId
+      )
+    } else {
+      finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey)
+    }
 
     const { blockData, blockNameMapping } = collectBlockData(ctx)
 
@@ -1024,7 +1036,6 @@ export class AgentBlockHandler implements BlockHandler {
       blockNameMapping,
     })
 
-    this.logExecutionSuccess(providerId, model, ctx, block, providerStartTime, response)
     return this.processProviderResponse(response, block, responseFormat)
   }
 
@@ -1049,15 +1060,6 @@ export class AgentBlockHandler implements BlockHandler {
       throw new Error(errorMessage)
     }
 
-    this.logExecutionSuccess(
-      providerRequest.provider,
-      providerRequest.model,
-      ctx,
-      block,
-      providerStartTime,
-      'HTTP response'
-    )
-
     const contentType = response.headers.get('Content-Type')
     if (contentType?.includes(HTTP.CONTENT_TYPE.EVENT_STREAM)) {
       return this.handleStreamingResponse(response, block, ctx, inputs)
@@ -1117,21 +1119,33 @@ export class AgentBlockHandler implements BlockHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private logExecutionSuccess(
|
/**
|
||||||
provider: string,
|
* Resolves a Vertex AI OAuth credential to an access token
|
||||||
model: string,
|
*/
|
||||||
ctx: ExecutionContext,
|
private async resolveVertexCredential(credentialId: string, workflowId: string): Promise<string> {
|
||||||
block: SerializedBlock,
|
const requestId = `vertex-${Date.now()}`
|
||||||
startTime: number,
|
|
||||||
response: any
|
logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`)
|
||||||
) {
|
|
||||||
const executionTime = Date.now() - startTime
|
// Get the credential - we need to find the owner
|
||||||
const responseType =
|
// Since we're in a workflow context, we can query the credential directly
|
||||||
response instanceof ReadableStream
|
const credential = await db.query.account.findFirst({
|
||||||
? 'stream'
|
where: eq(account.id, credentialId),
|
||||||
: response && typeof response === 'object' && 'stream' in response
|
})
|
||||||
? 'streaming-execution'
|
|
||||||
: 'json'
|
if (!credential) {
|
||||||
|
throw new Error(`Vertex AI credential not found: ${credentialId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh the token if needed
|
||||||
|
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
|
||||||
|
|
||||||
|
if (!accessToken) {
|
||||||
|
throw new Error('Failed to get Vertex AI access token')
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(`[${requestId}] Successfully resolved Vertex AI credential`)
|
||||||
|
return accessToken
|
||||||
}
|
}
|
||||||
|
|
||||||
private handleExecutionError(
|
private handleExecutionError(
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ export interface AgentInputs {
   azureApiVersion?: string
   vertexProject?: string
   vertexLocation?: string
+  vertexCredential?: string
   reasoningEffort?: string
   verbosity?: string
 }
@@ -579,6 +579,21 @@ export const auth = betterAuth({
       redirectURI: `${getBaseUrl()}/api/auth/oauth2/callback/google-groups`,
     },
+
+    {
+      providerId: 'vertex-ai',
+      clientId: env.GOOGLE_CLIENT_ID as string,
+      clientSecret: env.GOOGLE_CLIENT_SECRET as string,
+      discoveryUrl: 'https://accounts.google.com/.well-known/openid-configuration',
+      accessType: 'offline',
+      scopes: [
+        'https://www.googleapis.com/auth/userinfo.email',
+        'https://www.googleapis.com/auth/userinfo.profile',
+        'https://www.googleapis.com/auth/cloud-platform',
+      ],
+      prompt: 'consent',
+      redirectURI: `${getBaseUrl()}/api/auth/oauth2/callback/vertex-ai`,
+    },
 
     {
       providerId: 'microsoft-teams',
       clientId: env.MICROSOFT_CLIENT_ID as string,
@@ -41,7 +41,7 @@ function filterUserFile(data: any): any {
 const DISPLAY_FILTERS = [filterUserFile]
 
 export function filterForDisplay(data: any): any {
-  const seen = new WeakSet()
+  const seen = new Set<object>()
   return filterForDisplayInternal(data, seen, 0)
 }
 
@@ -49,7 +49,7 @@ function getObjectType(data: unknown): string {
   return Object.prototype.toString.call(data).slice(8, -1)
 }
 
-function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: number): any {
+function filterForDisplayInternal(data: any, seen: Set<object>, depth: number): any {
   try {
     if (data === null || data === undefined) {
       return data
@@ -93,6 +93,7 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: number): any {
       return '[Unknown Type]'
     }
 
+    // True circular reference: object is an ancestor in the current path
     if (seen.has(data)) {
       return '[Circular Reference]'
     }
@@ -131,18 +132,24 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
|
|||||||
return `[ArrayBuffer: ${(data as ArrayBuffer).byteLength} bytes]`
|
return `[ArrayBuffer: ${(data as ArrayBuffer).byteLength} bytes]`
|
||||||
|
|
||||||
case 'Map': {
|
case 'Map': {
|
||||||
|
seen.add(data)
|
||||||
const obj: Record<string, any> = {}
|
const obj: Record<string, any> = {}
|
||||||
for (const [key, value] of (data as Map<any, any>).entries()) {
|
for (const [key, value] of (data as Map<any, any>).entries()) {
|
||||||
const keyStr = typeof key === 'string' ? key : String(key)
|
const keyStr = typeof key === 'string' ? key : String(key)
|
||||||
obj[keyStr] = filterForDisplayInternal(value, seen, depth + 1)
|
obj[keyStr] = filterForDisplayInternal(value, seen, depth + 1)
|
||||||
}
|
}
|
||||||
|
seen.delete(data)
|
||||||
return obj
|
return obj
|
||||||
}
|
}
|
||||||
|
|
||||||
case 'Set':
|
case 'Set': {
|
||||||
return Array.from(data as Set<any>).map((item) =>
|
seen.add(data)
|
||||||
|
const result = Array.from(data as Set<any>).map((item) =>
|
||||||
filterForDisplayInternal(item, seen, depth + 1)
|
filterForDisplayInternal(item, seen, depth + 1)
|
||||||
)
|
)
|
||||||
|
seen.delete(data)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
case 'WeakMap':
|
case 'WeakMap':
|
||||||
return '[WeakMap]'
|
return '[WeakMap]'
|
||||||
@@ -161,17 +168,22 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
|
|||||||
return `[${objectType}: ${(data as ArrayBufferView).byteLength} bytes]`
|
return `[${objectType}: ${(data as ArrayBufferView).byteLength} bytes]`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add to current path before processing children
|
||||||
seen.add(data)
|
seen.add(data)
|
||||||
|
|
||||||
for (const filterFn of DISPLAY_FILTERS) {
|
for (const filterFn of DISPLAY_FILTERS) {
|
||||||
const result = filterFn(data)
|
const filtered = filterFn(data)
|
||||||
if (result !== data) {
|
if (filtered !== data) {
|
||||||
return filterForDisplayInternal(result, seen, depth + 1)
|
const result = filterForDisplayInternal(filtered, seen, depth + 1)
|
||||||
|
seen.delete(data)
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Array.isArray(data)) {
|
if (Array.isArray(data)) {
|
||||||
return data.map((item) => filterForDisplayInternal(item, seen, depth + 1))
|
const result = data.map((item) => filterForDisplayInternal(item, seen, depth + 1))
|
||||||
|
seen.delete(data)
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
const result: Record<string, any> = {}
|
const result: Record<string, any> = {}
|
||||||
@@ -182,6 +194,8 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: number): any {
         result[key] = '[Error accessing property]'
       }
     }
+    // Remove from current path after processing children
+    seen.delete(data)
     return result
   } catch {
     return '[Unserializable]'
@@ -32,6 +32,7 @@ import {
   SlackIcon,
   SpotifyIcon,
   TrelloIcon,
+  VertexIcon,
   WealthboxIcon,
   WebflowIcon,
   WordpressIcon,
@@ -80,6 +81,7 @@ export type OAuthService =
   | 'google-vault'
   | 'google-forms'
   | 'google-groups'
+  | 'vertex-ai'
   | 'github'
   | 'x'
   | 'confluence'
@@ -237,6 +239,16 @@ export const OAUTH_PROVIDERS: Record<string, OAuthProviderConfig> = {
       ],
       scopeHints: ['admin.directory.group'],
     },
+    'vertex-ai': {
+      id: 'vertex-ai',
+      name: 'Vertex AI',
+      description: 'Access Google Cloud Vertex AI for Gemini models with OAuth.',
+      providerId: 'vertex-ai',
+      icon: (props) => VertexIcon(props),
+      baseProviderIcon: (props) => VertexIcon(props),
+      scopes: ['https://www.googleapis.com/auth/cloud-platform'],
+      scopeHints: ['cloud-platform', 'vertex', 'aiplatform'],
+    },
   },
   defaultService: 'gmail',
 },
@@ -1099,6 +1111,12 @@ export function parseProvider(provider: OAuthProvider): ProviderConfig {
       featureType: 'microsoft-planner',
     }
   }
+  if (provider === 'vertex-ai') {
+    return {
+      baseProvider: 'google',
+      featureType: 'vertex-ai',
+    }
+  }
 
   // Handle compound providers (e.g., 'google-email' -> { baseProvider: 'google', featureType: 'email' })
   const [base, feature] = provider.split('-')
@@ -58,7 +58,7 @@ export const anthropicProvider: ProviderConfig = {
       throw new Error('API key is required for Anthropic')
     }
 
-    const modelId = request.model || 'claude-3-7-sonnet-20250219'
+    const modelId = request.model
     const useNativeStructuredOutputs = !!(
       request.responseFormat && supportsNativeStructuredOutputs(modelId)
     )
@@ -174,7 +174,7 @@ export const anthropicProvider: ProviderConfig = {
     }
 
     const payload: any = {
-      model: request.model || 'claude-3-7-sonnet-20250219',
+      model: request.model,
       messages,
       system: systemPrompt,
       max_tokens: Number.parseInt(String(request.maxTokens)) || 1024,
@@ -561,37 +561,93 @@ export const anthropicProvider: ProviderConfig = {
|
|||||||
throw error
|
throw error
|
||||||
}
|
}
|
||||||
|
|
||||||
const providerEndTime = Date.now()
|
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
|
||||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
|
||||||
const totalDuration = providerEndTime - providerStartTime
|
|
||||||
|
|
||||||
return {
|
const streamingPayload = {
|
||||||
content,
|
...payload,
|
||||||
model: request.model || 'claude-3-7-sonnet-20250219',
|
messages: currentMessages,
|
||||||
tokens,
|
stream: true,
|
||||||
toolCalls:
|
tool_choice: undefined,
|
||||||
toolCalls.length > 0
|
}
|
||||||
? toolCalls.map((tc) => ({
|
|
||||||
name: tc.name,
|
const streamResponse: any = await anthropic.messages.create(streamingPayload)
|
||||||
arguments: tc.arguments as Record<string, any>,
|
|
||||||
startTime: tc.startTime,
|
const streamingResult = {
|
||||||
endTime: tc.endTime,
|
stream: createReadableStreamFromAnthropicStream(
|
||||||
duration: tc.duration,
|
streamResponse,
|
||||||
result: tc.result,
|
(streamContent, usage) => {
|
||||||
}))
|
streamingResult.execution.output.content = streamContent
|
||||||
: undefined,
|
streamingResult.execution.output.tokens = {
|
||||||
toolResults: toolResults.length > 0 ? toolResults : undefined,
|
prompt: tokens.prompt + usage.input_tokens,
|
||||||
timing: {
|
completion: tokens.completion + usage.output_tokens,
|
||||||
startTime: providerStartTimeISO,
|
total: tokens.total + usage.input_tokens + usage.output_tokens,
|
||||||
endTime: providerEndTimeISO,
|
}
|
||||||
duration: totalDuration,
|
|
||||||
modelTime: modelTime,
|
const streamCost = calculateCost(
|
||||||
toolsTime: toolsTime,
|
request.model,
|
||||||
firstResponseTime: firstResponseTime,
|
usage.input_tokens,
|
||||||
iterations: iterationCount + 1,
|
usage.output_tokens
|
||||||
timeSegments: timeSegments,
|
)
|
||||||
|
streamingResult.execution.output.cost = {
|
||||||
|
input: accumulatedCost.input + streamCost.input,
|
||||||
|
output: accumulatedCost.output + streamCost.output,
|
||||||
|
total: accumulatedCost.total + streamCost.total,
|
||||||
|
}
|
||||||
|
|
||||||
|
const streamEndTime = Date.now()
|
||||||
|
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||||
|
|
||||||
|
if (streamingResult.execution.output.providerTiming) {
|
||||||
|
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
|
||||||
|
streamingResult.execution.output.providerTiming.duration =
|
||||||
|
streamEndTime - providerStartTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
execution: {
|
||||||
|
success: true,
|
||||||
|
output: {
|
||||||
|
content: '',
|
||||||
|
model: request.model,
|
||||||
|
tokens: {
|
||||||
|
prompt: tokens.prompt,
|
||||||
|
completion: tokens.completion,
|
||||||
|
total: tokens.total,
|
||||||
|
},
|
||||||
|
toolCalls:
|
||||||
|
toolCalls.length > 0
|
||||||
|
? {
|
||||||
|
list: toolCalls,
|
||||||
|
count: toolCalls.length,
|
||||||
|
}
|
||||||
|
: undefined,
|
||||||
|
providerTiming: {
|
||||||
|
startTime: providerStartTimeISO,
|
||||||
|
endTime: new Date().toISOString(),
|
||||||
|
duration: Date.now() - providerStartTime,
|
||||||
|
modelTime: modelTime,
|
||||||
|
toolsTime: toolsTime,
|
||||||
|
firstResponseTime: firstResponseTime,
|
||||||
|
iterations: iterationCount + 1,
|
||||||
|
timeSegments: timeSegments,
|
||||||
|
},
|
||||||
|
cost: {
|
||||||
|
input: accumulatedCost.input,
|
||||||
|
output: accumulatedCost.output,
|
||||||
|
total: accumulatedCost.total,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
logs: [],
|
||||||
|
metadata: {
|
||||||
|
startTime: providerStartTimeISO,
|
||||||
|
endTime: new Date().toISOString(),
|
||||||
|
duration: Date.now() - providerStartTime,
|
||||||
|
},
|
||||||
|
isStreaming: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return streamingResult as StreamingExecution
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const providerEndTime = Date.now()
|
const providerEndTime = Date.now()
|
||||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||||
@@ -934,7 +990,7 @@ export const anthropicProvider: ProviderConfig = {
         success: true,
         output: {
           content: '',
-          model: request.model || 'claude-3-7-sonnet-20250219',
+          model: request.model,
           tokens: {
             prompt: tokens.prompt,
             completion: tokens.completion,
@@ -978,7 +1034,7 @@ export const anthropicProvider: ProviderConfig = {
 
     return {
       content,
-      model: request.model || 'claude-3-7-sonnet-20250219',
+      model: request.model,
       tokens,
       toolCalls:
         toolCalls.length > 0
@@ -39,7 +39,7 @@ export const azureOpenAIProvider: ProviderConfig = {
     request: ProviderRequest
   ): Promise<ProviderResponse | StreamingExecution> => {
     logger.info('Preparing Azure OpenAI request', {
-      model: request.model || 'azure/gpt-4o',
+      model: request.model,
       hasSystemPrompt: !!request.systemPrompt,
       hasMessages: !!request.messages?.length,
       hasTools: !!request.tools?.length,
@@ -95,7 +95,7 @@ export const azureOpenAIProvider: ProviderConfig = {
         }))
       : undefined
 
-    const deploymentName = (request.model || 'azure/gpt-4o').replace('azure/', '')
+    const deploymentName = request.model.replace('azure/', '')
     const payload: any = {
       model: deploymentName,
       messages: allMessages,
@@ -73,7 +73,7 @@ export const cerebrasProvider: ProviderConfig = {
       : undefined
 
     const payload: any = {
-      model: (request.model || 'cerebras/llama-3.3-70b').replace('cerebras/', ''),
+      model: request.model.replace('cerebras/', ''),
       messages: allMessages,
     }
     if (request.temperature !== undefined) payload.temperature = request.temperature
@@ -145,7 +145,7 @@ export const cerebrasProvider: ProviderConfig = {
       success: true,
       output: {
         content: '',
-        model: request.model || 'cerebras/llama-3.3-70b',
+        model: request.model,
         tokens: { prompt: 0, completion: 0, total: 0 },
         toolCalls: undefined,
         providerTiming: {
@@ -470,7 +470,7 @@ export const cerebrasProvider: ProviderConfig = {
       success: true,
       output: {
         content: '',
-        model: request.model || 'cerebras/llama-3.3-70b',
+        model: request.model,
         tokens: {
           prompt: tokens.prompt,
           completion: tokens.completion,
@@ -105,7 +105,7 @@ export const deepseekProvider: ProviderConfig = {
           : toolChoice.type === 'any'
             ? `force:${toolChoice.any?.name || 'unknown'}`
             : 'unknown',
-        model: request.model || 'deepseek-v3',
+        model: request.model,
       })
     }
   }
@@ -145,7 +145,7 @@ export const deepseekProvider: ProviderConfig = {
       success: true,
       output: {
         content: '',
-        model: request.model || 'deepseek-chat',
+        model: request.model,
         tokens: { prompt: 0, completion: 0, total: 0 },
         toolCalls: undefined,
         providerTiming: {
@@ -469,7 +469,7 @@ export const deepseekProvider: ProviderConfig = {
       success: true,
       output: {
         content: '',
-        model: request.model || 'deepseek-chat',
+        model: request.model,
         tokens: {
           prompt: tokens.prompt,
           completion: tokens.completion,
apps/sim/providers/gemini/client.ts (new file, 58 lines)
@@ -0,0 +1,58 @@
|
|||||||
|
import { GoogleGenAI } from '@google/genai'
|
||||||
|
import { createLogger } from '@/lib/logs/console/logger'
|
||||||
|
import type { GeminiClientConfig } from './types'
|
||||||
|
|
||||||
|
const logger = createLogger('GeminiClient')
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a GoogleGenAI client configured for either Google Gemini API or Vertex AI
|
||||||
|
*
|
||||||
|
* For Google Gemini API:
|
||||||
|
* - Uses API key authentication
|
||||||
|
*
|
||||||
|
* For Vertex AI:
|
||||||
|
* - Uses OAuth access token via HTTP Authorization header
|
||||||
|
* - Requires project and location
|
||||||
|
*/
|
||||||
|
export function createGeminiClient(config: GeminiClientConfig): GoogleGenAI {
|
||||||
|
if (config.vertexai) {
|
||||||
|
if (!config.project) {
|
||||||
|
throw new Error('Vertex AI requires a project ID')
|
||||||
|
}
|
||||||
|
if (!config.accessToken) {
|
||||||
|
throw new Error('Vertex AI requires an access token')
|
||||||
|
}
|
||||||
|
|
||||||
|
const location = config.location ?? 'us-central1'
|
||||||
|
|
||||||
|
logger.info('Creating Vertex AI client', {
|
||||||
|
project: config.project,
|
||||||
|
location,
|
||||||
|
hasAccessToken: !!config.accessToken,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create client with Vertex AI configuration
|
||||||
|
// Use httpOptions.headers to pass the access token directly
|
||||||
|
return new GoogleGenAI({
|
||||||
|
vertexai: true,
|
||||||
|
project: config.project,
|
||||||
|
location,
|
||||||
|
httpOptions: {
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${config.accessToken}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Google Gemini API with API key
|
||||||
|
if (!config.apiKey) {
|
||||||
|
throw new Error('Google Gemini API requires an API key')
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info('Creating Google Gemini client')
|
||||||
|
|
||||||
|
return new GoogleGenAI({
|
||||||
|
apiKey: config.apiKey,
|
||||||
|
})
|
||||||
|
}
|
||||||
apps/sim/providers/gemini/core.ts (new file, 680 lines)
@@ -0,0 +1,680 @@
|
|||||||
|
import {
|
||||||
|
type Content,
|
||||||
|
FunctionCallingConfigMode,
|
||||||
|
type FunctionDeclaration,
|
||||||
|
type GenerateContentConfig,
|
||||||
|
type GenerateContentResponse,
|
||||||
|
type GoogleGenAI,
|
||||||
|
type Part,
|
||||||
|
type Schema,
|
||||||
|
type ThinkingConfig,
|
||||||
|
type ToolConfig,
|
||||||
|
} from '@google/genai'
|
||||||
|
import { createLogger } from '@/lib/logs/console/logger'
|
||||||
|
import type { StreamingExecution } from '@/executor/types'
|
||||||
|
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||||
|
import {
|
||||||
|
checkForForcedToolUsage,
|
||||||
|
cleanSchemaForGemini,
|
||||||
|
convertToGeminiFormat,
|
||||||
|
convertUsageMetadata,
|
||||||
|
createReadableStreamFromGeminiStream,
|
||||||
|
extractFunctionCallPart,
|
||||||
|
extractTextContent,
|
||||||
|
mapToThinkingLevel,
|
||||||
|
} from '@/providers/google/utils'
|
||||||
|
import { getThinkingCapability } from '@/providers/models'
|
||||||
|
import type { FunctionCallResponse, ProviderRequest, ProviderResponse } from '@/providers/types'
|
||||||
|
import {
|
||||||
|
calculateCost,
|
||||||
|
prepareToolExecution,
|
||||||
|
prepareToolsWithUsageControl,
|
||||||
|
} from '@/providers/utils'
|
||||||
|
import { executeTool } from '@/tools'
|
||||||
|
import type { ExecutionState, GeminiProviderType, GeminiUsage, ParsedFunctionCall } from './types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates initial execution state
|
||||||
|
*/
|
||||||
|
function createInitialState(
|
||||||
|
contents: Content[],
|
||||||
|
initialUsage: GeminiUsage,
|
||||||
|
firstResponseTime: number,
|
||||||
|
initialCallTime: number,
|
||||||
|
model: string,
|
||||||
|
toolConfig: ToolConfig | undefined
|
||||||
|
): ExecutionState {
|
||||||
|
const initialCost = calculateCost(
|
||||||
|
model,
|
||||||
|
initialUsage.promptTokenCount,
|
||||||
|
initialUsage.candidatesTokenCount
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
contents,
|
||||||
|
tokens: {
|
||||||
|
prompt: initialUsage.promptTokenCount,
|
||||||
|
completion: initialUsage.candidatesTokenCount,
|
||||||
|
total: initialUsage.totalTokenCount,
|
||||||
|
},
|
||||||
|
cost: initialCost,
|
||||||
|
toolCalls: [],
|
||||||
|
toolResults: [],
|
||||||
|
iterationCount: 0,
|
||||||
|
modelTime: firstResponseTime,
|
||||||
|
toolsTime: 0,
|
||||||
|
timeSegments: [
|
||||||
|
{
|
||||||
|
type: 'model',
|
||||||
|
name: 'Initial response',
|
||||||
|
startTime: initialCallTime,
|
||||||
|
endTime: initialCallTime + firstResponseTime,
|
||||||
|
duration: firstResponseTime,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usedForcedTools: [],
|
||||||
|
currentToolConfig: toolConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Executes a tool call and updates state
|
||||||
|
*/
|
||||||
|
async function executeToolCall(
|
||||||
|
functionCallPart: Part,
|
||||||
|
functionCall: ParsedFunctionCall,
|
||||||
|
request: ProviderRequest,
|
||||||
|
state: ExecutionState,
|
||||||
|
forcedTools: string[],
|
||||||
|
logger: ReturnType<typeof createLogger>
|
||||||
|
): Promise<{ success: boolean; state: ExecutionState }> {
|
||||||
|
const toolCallStartTime = Date.now()
|
||||||
|
const toolName = functionCall.name
|
||||||
|
|
||||||
|
const tool = request.tools?.find((t) => t.id === toolName)
|
||||||
|
if (!tool) {
|
||||||
|
logger.warn(`Tool ${toolName} not found in registry, skipping`)
|
||||||
|
return { success: false, state }
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const { toolParams, executionParams } = prepareToolExecution(tool, functionCall.args, request)
|
||||||
|
const result = await executeTool(toolName, executionParams, true)
|
||||||
|
const toolCallEndTime = Date.now()
|
||||||
|
const duration = toolCallEndTime - toolCallStartTime
|
||||||
|
|
||||||
|
const resultContent: Record<string, unknown> = result.success
|
||||||
|
? (result.output as Record<string, unknown>)
|
||||||
|
: { error: true, message: result.error || 'Tool execution failed', tool: toolName }
|
||||||
|
|
||||||
|
const toolCall: FunctionCallResponse = {
|
||||||
|
name: toolName,
|
||||||
|
arguments: toolParams,
|
||||||
|
startTime: new Date(toolCallStartTime).toISOString(),
|
||||||
|
endTime: new Date(toolCallEndTime).toISOString(),
|
||||||
|
duration,
|
||||||
|
result: resultContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
const updatedContents: Content[] = [
|
||||||
|
...state.contents,
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [functionCallPart],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
name: functionCall.name,
|
||||||
|
response: resultContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
const forcedToolCheck = checkForForcedToolUsage(
|
||||||
|
[{ name: functionCall.name, args: functionCall.args }],
|
||||||
|
state.currentToolConfig,
|
||||||
|
forcedTools,
|
||||||
|
state.usedForcedTools
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
success: true,
|
||||||
|
state: {
|
||||||
|
...state,
|
||||||
|
contents: updatedContents,
|
||||||
|
toolCalls: [...state.toolCalls, toolCall],
|
||||||
|
toolResults: result.success
|
||||||
|
? [...state.toolResults, result.output as Record<string, unknown>]
|
||||||
|
: state.toolResults,
|
||||||
|
toolsTime: state.toolsTime + duration,
|
||||||
|
timeSegments: [
|
||||||
|
...state.timeSegments,
|
||||||
|
{
|
||||||
|
type: 'tool',
|
||||||
|
name: toolName,
|
||||||
|
startTime: toolCallStartTime,
|
||||||
|
endTime: toolCallEndTime,
|
||||||
|
duration,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usedForcedTools: forcedToolCheck?.usedForcedTools ?? state.usedForcedTools,
|
||||||
|
currentToolConfig: forcedToolCheck?.nextToolConfig ?? state.currentToolConfig,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error processing function call:', {
|
||||||
|
error: error instanceof Error ? error.message : String(error),
|
||||||
|
functionName: toolName,
|
||||||
|
})
|
||||||
|
return { success: false, state }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Updates state with model response metadata
|
||||||
|
*/
|
||||||
|
function updateStateWithResponse(
|
||||||
|
state: ExecutionState,
|
||||||
|
response: GenerateContentResponse,
|
||||||
|
model: string,
|
||||||
|
startTime: number,
|
||||||
|
endTime: number
|
||||||
|
): ExecutionState {
|
||||||
|
const usage = convertUsageMetadata(response.usageMetadata)
|
||||||
|
const cost = calculateCost(model, usage.promptTokenCount, usage.candidatesTokenCount)
|
||||||
|
const duration = endTime - startTime
|
||||||
|
|
||||||
|
return {
|
||||||
|
...state,
|
||||||
|
tokens: {
|
||||||
|
prompt: state.tokens.prompt + usage.promptTokenCount,
|
||||||
|
completion: state.tokens.completion + usage.candidatesTokenCount,
|
||||||
|
total: state.tokens.total + usage.totalTokenCount,
|
||||||
|
},
|
||||||
|
cost: {
|
||||||
|
input: state.cost.input + cost.input,
|
||||||
|
output: state.cost.output + cost.output,
|
||||||
|
total: state.cost.total + cost.total,
|
||||||
|
pricing: cost.pricing, // Use latest pricing
|
||||||
|
},
|
||||||
|
modelTime: state.modelTime + duration,
|
||||||
|
timeSegments: [
|
||||||
|
...state.timeSegments,
|
||||||
|
{
|
||||||
|
type: 'model',
|
||||||
|
name: `Model response (iteration ${state.iterationCount + 1})`,
|
||||||
|
startTime,
|
||||||
|
endTime,
|
||||||
|
duration,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
iterationCount: state.iterationCount + 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds config for next iteration
|
||||||
|
*/
|
||||||
|
function buildNextConfig(
|
||||||
|
baseConfig: GenerateContentConfig,
|
||||||
|
state: ExecutionState,
|
||||||
|
forcedTools: string[],
|
||||||
|
request: ProviderRequest,
|
||||||
|
logger: ReturnType<typeof createLogger>
|
||||||
|
): GenerateContentConfig {
|
||||||
|
const nextConfig = { ...baseConfig }
|
||||||
|
const allForcedToolsUsed =
|
||||||
|
forcedTools.length > 0 && state.usedForcedTools.length === forcedTools.length
|
||||||
|
|
||||||
|
if (allForcedToolsUsed && request.responseFormat) {
|
||||||
|
nextConfig.tools = undefined
|
||||||
|
nextConfig.toolConfig = undefined
|
||||||
|
nextConfig.responseMimeType = 'application/json'
|
||||||
|
nextConfig.responseSchema = cleanSchemaForGemini(request.responseFormat.schema) as Schema
|
||||||
|
logger.info('Using structured output for final response after tool execution')
|
||||||
|
} else if (state.currentToolConfig) {
|
||||||
|
nextConfig.toolConfig = state.currentToolConfig
|
||||||
|
} else {
|
||||||
|
nextConfig.toolConfig = { functionCallingConfig: { mode: FunctionCallingConfigMode.AUTO } }
|
||||||
|
}
|
||||||
|
|
||||||
|
return nextConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates streaming execution result template
|
||||||
|
*/
|
||||||
|
function createStreamingResult(
|
||||||
|
providerStartTime: number,
|
||||||
|
providerStartTimeISO: string,
|
||||||
|
firstResponseTime: number,
|
||||||
|
initialCallTime: number,
|
||||||
|
state?: ExecutionState
|
||||||
|
): StreamingExecution {
|
||||||
|
return {
|
||||||
|
stream: undefined as unknown as ReadableStream<Uint8Array>,
|
||||||
|
execution: {
|
||||||
|
success: true,
|
||||||
|
output: {
|
||||||
|
content: '',
|
||||||
|
model: '',
|
||||||
|
tokens: state?.tokens ?? { prompt: 0, completion: 0, total: 0 },
|
||||||
|
toolCalls: state?.toolCalls.length
|
||||||
|
? { list: state.toolCalls, count: state.toolCalls.length }
|
||||||
|
: undefined,
|
||||||
|
toolResults: state?.toolResults,
|
||||||
|
providerTiming: {
|
||||||
|
startTime: providerStartTimeISO,
|
||||||
|
endTime: new Date().toISOString(),
|
||||||
|
duration: Date.now() - providerStartTime,
|
||||||
|
modelTime: state?.modelTime ?? firstResponseTime,
|
||||||
|
toolsTime: state?.toolsTime ?? 0,
|
||||||
|
firstResponseTime,
|
||||||
|
iterations: (state?.iterationCount ?? 0) + 1,
|
||||||
|
timeSegments: state?.timeSegments ?? [
|
||||||
|
{
|
||||||
|
type: 'model',
|
||||||
|
name: 'Initial streaming response',
|
||||||
|
startTime: initialCallTime,
|
||||||
|
endTime: initialCallTime + firstResponseTime,
|
||||||
|
duration: firstResponseTime,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
cost: state?.cost ?? {
|
||||||
|
input: 0,
|
||||||
|
output: 0,
|
||||||
|
total: 0,
|
||||||
|
pricing: { input: 0, output: 0, updatedAt: new Date().toISOString() },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
logs: [],
|
||||||
|
metadata: {
|
||||||
|
startTime: providerStartTimeISO,
|
||||||
|
endTime: new Date().toISOString(),
|
||||||
|
duration: Date.now() - providerStartTime,
|
||||||
|
},
|
||||||
|
isStreaming: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configuration for executing a Gemini request
|
||||||
|
*/
|
||||||
|
export interface GeminiExecutionConfig {
|
||||||
|
ai: GoogleGenAI
|
||||||
|
model: string
|
||||||
|
request: ProviderRequest
|
||||||
|
providerType: GeminiProviderType
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Executes a request using the Gemini API
|
||||||
|
*
|
||||||
|
* This is the shared core logic for both Google and Vertex AI providers.
|
||||||
|
* The only difference is how the GoogleGenAI client is configured.
|
||||||
|
*/
|
||||||
|
export async function executeGeminiRequest(
|
||||||
|
config: GeminiExecutionConfig
|
||||||
|
): Promise<ProviderResponse | StreamingExecution> {
|
||||||
|
const { ai, model, request, providerType } = config
|
||||||
|
const logger = createLogger(providerType === 'google' ? 'GoogleProvider' : 'VertexProvider')
|
||||||
|
|
||||||
|
logger.info(`Preparing ${providerType} Gemini request`, {
|
||||||
|
model,
|
||||||
|
hasSystemPrompt: !!request.systemPrompt,
|
||||||
|
hasMessages: !!request.messages?.length,
|
||||||
|
hasTools: !!request.tools?.length,
|
||||||
|
toolCount: request.tools?.length ?? 0,
|
||||||
|
hasResponseFormat: !!request.responseFormat,
|
||||||
|
streaming: !!request.stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
const providerStartTime = Date.now()
|
||||||
|
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||||
|
|
||||||
|
try {
|
||||||
|
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
|
||||||
|
|
||||||
|
// Build configuration
|
||||||
|
const geminiConfig: GenerateContentConfig = {}
|
||||||
|
|
||||||
|
if (request.temperature !== undefined) {
|
||||||
|
geminiConfig.temperature = request.temperature
|
||||||
|
}
|
||||||
|
if (request.maxTokens !== undefined) {
|
||||||
|
geminiConfig.maxOutputTokens = request.maxTokens
|
||||||
|
}
|
||||||
|
if (systemInstruction) {
|
||||||
|
geminiConfig.systemInstruction = systemInstruction
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle response format (only when no tools)
|
||||||
|
if (request.responseFormat && !tools?.length) {
|
||||||
|
geminiConfig.responseMimeType = 'application/json'
|
||||||
|
geminiConfig.responseSchema = cleanSchemaForGemini(request.responseFormat.schema) as Schema
|
||||||
|
logger.info('Using Gemini native structured output format')
|
||||||
|
} else if (request.responseFormat && tools?.length) {
|
||||||
|
logger.warn('Gemini does not support responseFormat with tools. Structured output ignored.')
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure thinking for models that support it
|
||||||
|
const thinkingCapability = getThinkingCapability(model)
|
||||||
|
if (thinkingCapability) {
|
||||||
|
const level = request.thinkingLevel ?? thinkingCapability.default ?? 'high'
|
||||||
|
const thinkingConfig: ThinkingConfig = {
|
||||||
|
includeThoughts: false,
|
||||||
|
thinkingLevel: mapToThinkingLevel(level),
|
||||||
|
}
|
||||||
|
geminiConfig.thinkingConfig = thinkingConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare tools
|
||||||
|
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||||
|
let toolConfig: ToolConfig | undefined
|
||||||
|
|
||||||
|
if (tools?.length) {
|
||||||
|
const functionDeclarations: FunctionDeclaration[] = tools.map((t) => ({
|
||||||
|
name: t.name,
|
||||||
|
description: t.description,
|
||||||
|
parameters: t.parameters,
|
||||||
|
}))
|
||||||
|
|
||||||
|
preparedTools = prepareToolsWithUsageControl(
|
||||||
|
functionDeclarations,
|
||||||
|
request.tools,
|
||||||
|
logger,
|
||||||
|
'google'
|
||||||
|
)
|
||||||
|
const { tools: filteredTools, toolConfig: preparedToolConfig } = preparedTools
|
||||||
|
|
||||||
|
if (filteredTools?.length) {
|
||||||
|
geminiConfig.tools = [{ functionDeclarations: filteredTools as FunctionDeclaration[] }]
|
||||||
|
|
||||||
|
if (preparedToolConfig) {
|
||||||
|
toolConfig = {
|
||||||
|
functionCallingConfig: {
|
||||||
|
mode:
|
||||||
|
{
|
||||||
|
AUTO: FunctionCallingConfigMode.AUTO,
|
||||||
|
ANY: FunctionCallingConfigMode.ANY,
|
||||||
|
NONE: FunctionCallingConfigMode.NONE,
|
||||||
|
}[preparedToolConfig.functionCallingConfig.mode] ?? FunctionCallingConfigMode.AUTO,
|
||||||
|
allowedFunctionNames: preparedToolConfig.functionCallingConfig.allowedFunctionNames,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
geminiConfig.toolConfig = toolConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info('Gemini request with tools:', {
|
||||||
|
toolCount: filteredTools.length,
|
||||||
|
model,
|
||||||
|
tools: filteredTools.map((t) => (t as FunctionDeclaration).name),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const initialCallTime = Date.now()
|
||||||
|
const shouldStream = request.stream && !tools?.length
|
||||||
|
|
||||||
|
// Streaming without tools
|
||||||
|
if (shouldStream) {
|
||||||
|
logger.info('Handling Gemini streaming response')
|
||||||
|
|
||||||
|
const streamGenerator = await ai.models.generateContentStream({
|
||||||
|
model,
|
||||||
|
contents,
|
||||||
|
config: geminiConfig,
|
||||||
|
})
|
||||||
|
const firstResponseTime = Date.now() - initialCallTime
|
||||||
|
|
||||||
|
const streamingResult = createStreamingResult(
|
||||||
|
providerStartTime,
|
||||||
|
providerStartTimeISO,
|
||||||
|
firstResponseTime,
|
||||||
|
initialCallTime
|
||||||
|
)
|
||||||
|
streamingResult.execution.output.model = model
|
||||||
|
|
||||||
|
streamingResult.stream = createReadableStreamFromGeminiStream(
|
||||||
|
streamGenerator,
|
||||||
|
(content: string, usage: GeminiUsage) => {
|
||||||
|
streamingResult.execution.output.content = content
|
||||||
|
streamingResult.execution.output.tokens = {
|
||||||
|
prompt: usage.promptTokenCount,
|
||||||
|
completion: usage.candidatesTokenCount,
|
||||||
|
total: usage.totalTokenCount,
|
||||||
|
}
|
||||||
|
|
||||||
|
const costResult = calculateCost(
|
||||||
|
model,
|
||||||
|
usage.promptTokenCount,
|
||||||
|
usage.candidatesTokenCount
|
||||||
|
)
|
||||||
|
streamingResult.execution.output.cost = costResult
|
||||||
|
|
||||||
|
const streamEndTime = Date.now()
|
||||||
|
if (streamingResult.execution.output.providerTiming) {
|
||||||
|
streamingResult.execution.output.providerTiming.endTime = new Date(
|
||||||
|
streamEndTime
|
||||||
|
).toISOString()
|
||||||
|
streamingResult.execution.output.providerTiming.duration =
|
||||||
|
streamEndTime - providerStartTime
|
||||||
|
const segments = streamingResult.execution.output.providerTiming.timeSegments
|
||||||
|
if (segments?.[0]) {
|
||||||
|
segments[0].endTime = streamEndTime
|
||||||
|
segments[0].duration = streamEndTime - providerStartTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return streamingResult
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-streaming request
|
||||||
|
const response = await ai.models.generateContent({ model, contents, config: geminiConfig })
|
||||||
|
const firstResponseTime = Date.now() - initialCallTime
|
||||||
|
|
||||||
|
// Check for UNEXPECTED_TOOL_CALL
|
||||||
|
const candidate = response.candidates?.[0]
|
||||||
|
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
|
||||||
|
logger.warn('Gemini returned UNEXPECTED_TOOL_CALL - model attempted to call unknown tool')
|
||||||
|
}
|
||||||
|
|
||||||
|
const initialUsage = convertUsageMetadata(response.usageMetadata)
|
||||||
|
let state = createInitialState(
|
||||||
|
contents,
|
||||||
|
initialUsage,
|
||||||
|
firstResponseTime,
|
||||||
|
initialCallTime,
|
||||||
|
model,
|
||||||
|
toolConfig
|
||||||
|
)
|
||||||
|
const forcedTools = preparedTools?.forcedTools ?? []
|
||||||
|
|
||||||
|
let currentResponse = response
|
||||||
|
let content = ''
|
||||||
|
|
||||||
|
// Tool execution loop
|
||||||
|
const functionCalls = response.functionCalls
|
||||||
|
if (functionCalls?.length) {
|
||||||
|
logger.info(`Received function call from Gemini: ${functionCalls[0].name}`)
|
||||||
|
|
||||||
|
while (state.iterationCount < MAX_TOOL_ITERATIONS) {
|
||||||
|
const functionCallPart = extractFunctionCallPart(currentResponse.candidates?.[0])
|
||||||
|
if (!functionCallPart?.functionCall) {
|
||||||
|
content = extractTextContent(currentResponse.candidates?.[0])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
const functionCall: ParsedFunctionCall = {
|
||||||
|
name: functionCallPart.functionCall.name ?? '',
|
||||||
|
args: (functionCallPart.functionCall.args ?? {}) as Record<string, unknown>,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
`Processing function call: ${functionCall.name} (iteration ${state.iterationCount + 1})`
|
||||||
|
)
|
||||||
|
|
||||||
|
const { success, state: updatedState } = await executeToolCall(
|
||||||
|
functionCallPart,
|
||||||
|
functionCall,
|
||||||
|
request,
|
||||||
|
state,
|
||||||
|
forcedTools,
|
||||||
|
logger
|
||||||
|
)
|
||||||
|
if (!success) {
|
||||||
|
content = extractTextContent(currentResponse.candidates?.[0])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
state = { ...updatedState, iterationCount: updatedState.iterationCount + 1 }
|
||||||
|
const nextConfig = buildNextConfig(geminiConfig, state, forcedTools, request, logger)
|
||||||
|
|
||||||
|
// Stream final response if requested
|
||||||
|
if (request.stream) {
|
||||||
|
const checkResponse = await ai.models.generateContent({
|
||||||
|
model,
|
||||||
|
contents: state.contents,
|
||||||
|
config: nextConfig,
|
||||||
|
})
|
||||||
|
state = updateStateWithResponse(state, checkResponse, model, Date.now() - 100, Date.now())
|
||||||
|
|
||||||
|
if (checkResponse.functionCalls?.length) {
|
||||||
|
currentResponse = checkResponse
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info('No more function calls, streaming final response')
|
||||||
|
|
||||||
|
if (request.responseFormat) {
|
||||||
|
nextConfig.tools = undefined
|
||||||
|
nextConfig.toolConfig = undefined
|
||||||
|
nextConfig.responseMimeType = 'application/json'
|
||||||
|
nextConfig.responseSchema = cleanSchemaForGemini(
|
||||||
|
request.responseFormat.schema
|
||||||
|
) as Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture accumulated cost before streaming
|
||||||
|
const accumulatedCost = {
|
||||||
|
input: state.cost.input,
|
||||||
|
output: state.cost.output,
|
||||||
|
total: state.cost.total,
|
||||||
|
}
|
||||||
|
const accumulatedTokens = { ...state.tokens }
|
||||||
|
|
||||||
|
const streamGenerator = await ai.models.generateContentStream({
|
||||||
|
model,
|
||||||
|
contents: state.contents,
|
||||||
|
config: nextConfig,
|
||||||
|
})
|
||||||
|
|
||||||
|
const streamingResult = createStreamingResult(
|
||||||
|
providerStartTime,
|
||||||
|
providerStartTimeISO,
|
||||||
|
firstResponseTime,
|
||||||
|
initialCallTime,
|
||||||
|
state
|
||||||
|
)
|
||||||
|
streamingResult.execution.output.model = model
|
||||||
|
|
||||||
|
streamingResult.stream = createReadableStreamFromGeminiStream(
|
||||||
|
streamGenerator,
|
||||||
|
(streamContent: string, usage: GeminiUsage) => {
|
||||||
|
streamingResult.execution.output.content = streamContent
|
||||||
|
streamingResult.execution.output.tokens = {
|
||||||
|
prompt: accumulatedTokens.prompt + usage.promptTokenCount,
|
||||||
|
completion: accumulatedTokens.completion + usage.candidatesTokenCount,
|
||||||
|
total: accumulatedTokens.total + usage.totalTokenCount,
|
||||||
|
}
|
||||||
|
|
||||||
|
const streamCost = calculateCost(
|
||||||
|
model,
|
||||||
|
usage.promptTokenCount,
|
||||||
|
usage.candidatesTokenCount
|
||||||
|
)
|
||||||
|
streamingResult.execution.output.cost = {
|
||||||
|
input: accumulatedCost.input + streamCost.input,
|
||||||
|
output: accumulatedCost.output + streamCost.output,
|
||||||
|
total: accumulatedCost.total + streamCost.total,
|
||||||
|
pricing: streamCost.pricing,
|
||||||
|
}
|
||||||
|
|
||||||
|
if (streamingResult.execution.output.providerTiming) {
|
||||||
|
streamingResult.execution.output.providerTiming.endTime = new Date().toISOString()
|
||||||
|
streamingResult.execution.output.providerTiming.duration =
|
||||||
|
Date.now() - providerStartTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return streamingResult
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-streaming: get next response
|
||||||
|
const nextModelStartTime = Date.now()
|
||||||
|
const nextResponse = await ai.models.generateContent({
|
||||||
|
model,
|
||||||
|
contents: state.contents,
|
||||||
|
config: nextConfig,
|
||||||
|
})
|
||||||
|
state = updateStateWithResponse(state, nextResponse, model, nextModelStartTime, Date.now())
|
||||||
|
currentResponse = nextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!content) {
|
||||||
|
content = extractTextContent(currentResponse.candidates?.[0])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content = extractTextContent(candidate)
|
||||||
|
}
|
||||||
|
|
||||||
|
const providerEndTime = Date.now()
|
||||||
|
|
||||||
|
return {
|
||||||
|
content,
|
||||||
|
model,
|
||||||
|
tokens: state.tokens,
|
||||||
|
toolCalls: state.toolCalls.length ? state.toolCalls : undefined,
|
||||||
|
toolResults: state.toolResults.length ? state.toolResults : undefined,
|
||||||
|
timing: {
|
||||||
|
startTime: providerStartTimeISO,
|
||||||
|
endTime: new Date(providerEndTime).toISOString(),
|
||||||
|
duration: providerEndTime - providerStartTime,
|
||||||
|
modelTime: state.modelTime,
|
||||||
|
toolsTime: state.toolsTime,
|
||||||
|
firstResponseTime,
|
||||||
|
iterations: state.iterationCount + 1,
|
||||||
|
timeSegments: state.timeSegments,
|
||||||
|
},
|
||||||
|
cost: state.cost,
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
const providerEndTime = Date.now()
|
||||||
|
const duration = providerEndTime - providerStartTime
|
||||||
|
|
||||||
|
logger.error('Error in Gemini request:', {
|
||||||
|
error: error instanceof Error ? error.message : String(error),
|
||||||
|
stack: error instanceof Error ? error.stack : undefined,
|
||||||
|
})
|
||||||
|
|
||||||
|
const enhancedError = error instanceof Error ? error : new Error(String(error))
|
||||||
|
Object.assign(enhancedError, {
|
||||||
|
timing: {
|
||||||
|
startTime: providerStartTimeISO,
|
||||||
|
endTime: new Date(providerEndTime).toISOString(),
|
||||||
|
duration,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
throw enhancedError
|
||||||
|
}
|
||||||
|
}
|
||||||
18
apps/sim/providers/gemini/index.ts
Normal file
18
apps/sim/providers/gemini/index.ts
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
/**
|
||||||
|
* Shared Gemini execution core
|
||||||
|
*
|
||||||
|
* This module provides the shared execution logic for both Google Gemini API
|
||||||
|
* and Vertex AI providers. The only difference between providers is how the
|
||||||
|
* GoogleGenAI client is configured (API key vs OAuth).
|
||||||
|
*/
|
||||||
|
|
||||||
|
export { createGeminiClient } from './client'
|
||||||
|
export { executeGeminiRequest, type GeminiExecutionConfig } from './core'
|
||||||
|
export type {
|
||||||
|
ExecutionState,
|
||||||
|
ForcedToolResult,
|
||||||
|
GeminiClientConfig,
|
||||||
|
GeminiProviderType,
|
||||||
|
GeminiUsage,
|
||||||
|
ParsedFunctionCall,
|
||||||
|
} from './types'
|
||||||
64
apps/sim/providers/gemini/types.ts
Normal file
64
apps/sim/providers/gemini/types.ts
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import type { Content, ToolConfig } from '@google/genai'
|
||||||
|
import type { FunctionCallResponse, ModelPricing, TimeSegment } from '@/providers/types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Usage metadata from Gemini responses
|
||||||
|
*/
|
||||||
|
export interface GeminiUsage {
|
||||||
|
promptTokenCount: number
|
||||||
|
candidatesTokenCount: number
|
||||||
|
totalTokenCount: number
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parsed function call from Gemini response
|
||||||
|
*/
|
||||||
|
export interface ParsedFunctionCall {
|
||||||
|
name: string
|
||||||
|
args: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Accumulated state during tool execution loop
|
||||||
|
*/
|
||||||
|
export interface ExecutionState {
|
||||||
|
contents: Content[]
|
||||||
|
tokens: { prompt: number; completion: number; total: number }
|
||||||
|
cost: { input: number; output: number; total: number; pricing: ModelPricing }
|
||||||
|
toolCalls: FunctionCallResponse[]
|
||||||
|
toolResults: Record<string, unknown>[]
|
||||||
|
iterationCount: number
|
||||||
|
modelTime: number
|
||||||
|
toolsTime: number
|
||||||
|
timeSegments: TimeSegment[]
|
||||||
|
usedForcedTools: string[]
|
||||||
|
currentToolConfig: ToolConfig | undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Result from forced tool usage check
|
||||||
|
*/
|
||||||
|
export interface ForcedToolResult {
|
||||||
|
hasUsedForcedTool: boolean
|
||||||
|
usedForcedTools: string[]
|
||||||
|
nextToolConfig: ToolConfig | undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configuration for creating a Gemini client
|
||||||
|
*/
|
||||||
|
export interface GeminiClientConfig {
|
||||||
|
/** For Google Gemini API */
|
||||||
|
apiKey?: string
|
||||||
|
/** For Vertex AI */
|
||||||
|
vertexai?: boolean
|
||||||
|
project?: string
|
||||||
|
location?: string
|
||||||
|
/** OAuth access token for Vertex AI */
|
||||||
|
accessToken?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider type for logging and model lookup
|
||||||
|
*/
|
||||||
|
export type GeminiProviderType = 'google' | 'vertex'
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,61 +1,89 @@
|
|||||||
import type { Candidate } from '@google/genai'
|
import {
|
||||||
|
type Candidate,
|
||||||
|
type Content,
|
||||||
|
type FunctionCall,
|
||||||
|
FunctionCallingConfigMode,
|
||||||
|
type GenerateContentResponse,
|
||||||
|
type GenerateContentResponseUsageMetadata,
|
||||||
|
type Part,
|
||||||
|
type Schema,
|
||||||
|
type SchemaUnion,
|
||||||
|
ThinkingLevel,
|
||||||
|
type ToolConfig,
|
||||||
|
Type,
|
||||||
|
} from '@google/genai'
|
||||||
|
import { createLogger } from '@/lib/logs/console/logger'
|
||||||
import type { ProviderRequest } from '@/providers/types'
|
import type { ProviderRequest } from '@/providers/types'
|
||||||
|
import { trackForcedToolUsage } from '@/providers/utils'
|
||||||
|
|
||||||
|
const logger = createLogger('GoogleUtils')
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Usage metadata for Google Gemini responses
|
||||||
|
*/
|
||||||
|
export interface GeminiUsage {
|
||||||
|
promptTokenCount: number
|
||||||
|
candidatesTokenCount: number
|
||||||
|
totalTokenCount: number
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parsed function call from Gemini response
|
||||||
|
*/
|
||||||
|
export interface ParsedFunctionCall {
|
||||||
|
name: string
|
||||||
|
args: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Removes additionalProperties from a schema object (not supported by Gemini)
|
* Removes additionalProperties from a schema object (not supported by Gemini)
|
||||||
*/
|
*/
|
||||||
export function cleanSchemaForGemini(schema: any): any {
|
export function cleanSchemaForGemini(schema: SchemaUnion): SchemaUnion {
|
||||||
if (schema === null || schema === undefined) return schema
|
if (schema === null || schema === undefined) return schema
|
||||||
if (typeof schema !== 'object') return schema
|
if (typeof schema !== 'object') return schema
|
||||||
if (Array.isArray(schema)) {
|
if (Array.isArray(schema)) {
|
||||||
return schema.map((item) => cleanSchemaForGemini(item))
|
return schema.map((item) => cleanSchemaForGemini(item))
|
||||||
}
|
}
|
||||||
|
|
||||||
const cleanedSchema: any = {}
|
const cleanedSchema: Record<string, unknown> = {}
|
||||||
|
const schemaObj = schema as Record<string, unknown>
|
||||||
|
|
||||||
for (const key in schema) {
|
for (const key in schemaObj) {
|
||||||
if (key === 'additionalProperties') continue
|
if (key === 'additionalProperties') continue
|
||||||
cleanedSchema[key] = cleanSchemaForGemini(schema[key])
|
cleanedSchema[key] = cleanSchemaForGemini(schemaObj[key] as SchemaUnion)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cleanedSchema
|
return cleanedSchema
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extracts text content from a Gemini response candidate, handling structured output
|
* Extracts text content from a Gemini response candidate.
|
||||||
|
* Filters out thought parts (model reasoning) from the output.
|
||||||
*/
|
*/
|
||||||
export function extractTextContent(candidate: Candidate | undefined): string {
|
export function extractTextContent(candidate: Candidate | undefined): string {
|
||||||
if (!candidate?.content?.parts) return ''
|
if (!candidate?.content?.parts) return ''
|
||||||
|
|
||||||
if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) {
|
const textParts = candidate.content.parts.filter(
|
||||||
const text = candidate.content.parts[0].text
|
(part): part is Part & { text: string } => Boolean(part.text) && part.thought !== true
|
||||||
if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) {
|
)
|
||||||
try {
|
|
||||||
JSON.parse(text)
|
|
||||||
return text
|
|
||||||
} catch (_e) {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return candidate.content.parts
|
if (textParts.length === 0) return ''
|
||||||
.filter((part: any) => part.text)
|
if (textParts.length === 1) return textParts[0].text
|
||||||
.map((part: any) => part.text)
|
|
||||||
.join('\n')
|
return textParts.map((part) => part.text).join('\n')
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extracts a function call from a Gemini response candidate
|
* Extracts the first function call from a Gemini response candidate
|
||||||
*/
|
*/
|
||||||
export function extractFunctionCall(
|
export function extractFunctionCall(candidate: Candidate | undefined): ParsedFunctionCall | null {
|
||||||
candidate: Candidate | undefined
|
|
||||||
): { name: string; args: any } | null {
|
|
||||||
if (!candidate?.content?.parts) return null
|
if (!candidate?.content?.parts) return null
|
||||||
|
|
||||||
for (const part of candidate.content.parts) {
|
for (const part of candidate.content.parts) {
|
||||||
if (part.functionCall) {
|
if (part.functionCall) {
|
||||||
return {
|
return {
|
||||||
name: part.functionCall.name ?? '',
|
name: part.functionCall.name ?? '',
|
||||||
args: part.functionCall.args ?? {},
|
args: (part.functionCall.args ?? {}) as Record<string, unknown>,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -63,16 +91,55 @@ export function extractFunctionCall(
|
|||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts the full Part containing the function call (preserves thoughtSignature)
|
||||||
|
*/
|
||||||
|
export function extractFunctionCallPart(candidate: Candidate | undefined): Part | null {
|
||||||
|
if (!candidate?.content?.parts) return null
|
||||||
|
|
||||||
|
for (const part of candidate.content.parts) {
|
||||||
|
if (part.functionCall) {
|
||||||
|
return part
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts usage metadata from SDK response to our format
|
||||||
|
*/
|
||||||
|
export function convertUsageMetadata(
|
||||||
|
usageMetadata: GenerateContentResponseUsageMetadata | undefined
|
||||||
|
): GeminiUsage {
|
||||||
|
const promptTokenCount = usageMetadata?.promptTokenCount ?? 0
|
||||||
|
const candidatesTokenCount = usageMetadata?.candidatesTokenCount ?? 0
|
||||||
|
return {
|
||||||
|
promptTokenCount,
|
||||||
|
candidatesTokenCount,
|
||||||
|
totalTokenCount: usageMetadata?.totalTokenCount ?? promptTokenCount + candidatesTokenCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tool definition for Gemini format
|
||||||
|
*/
|
||||||
|
export interface GeminiToolDef {
|
||||||
|
name: string
|
||||||
|
description: string
|
||||||
|
parameters: Schema
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Converts OpenAI-style request format to Gemini format
|
* Converts OpenAI-style request format to Gemini format
|
||||||
*/
|
*/
|
||||||
export function convertToGeminiFormat(request: ProviderRequest): {
|
export function convertToGeminiFormat(request: ProviderRequest): {
|
||||||
contents: any[]
|
contents: Content[]
|
||||||
tools: any[] | undefined
|
tools: GeminiToolDef[] | undefined
|
||||||
systemInstruction: any | undefined
|
systemInstruction: Content | undefined
|
||||||
} {
|
} {
|
||||||
const contents: any[] = []
|
const contents: Content[] = []
|
||||||
let systemInstruction
|
let systemInstruction: Content | undefined
|
||||||
|
|
||||||
if (request.systemPrompt) {
|
if (request.systemPrompt) {
|
||||||
systemInstruction = { parts: [{ text: request.systemPrompt }] }
|
systemInstruction = { parts: [{ text: request.systemPrompt }] }
|
||||||
@@ -82,13 +149,13 @@ export function convertToGeminiFormat(request: ProviderRequest): {
|
|||||||
contents.push({ role: 'user', parts: [{ text: request.context }] })
|
contents.push({ role: 'user', parts: [{ text: request.context }] })
|
||||||
}
|
}
|
||||||
|
|
||||||
if (request.messages && request.messages.length > 0) {
|
if (request.messages?.length) {
|
||||||
for (const message of request.messages) {
|
for (const message of request.messages) {
|
||||||
if (message.role === 'system') {
|
if (message.role === 'system') {
|
||||||
if (!systemInstruction) {
|
if (!systemInstruction) {
|
||||||
systemInstruction = { parts: [{ text: message.content }] }
|
systemInstruction = { parts: [{ text: message.content ?? '' }] }
|
||||||
} else {
|
} else if (systemInstruction.parts?.[0] && 'text' in systemInstruction.parts[0]) {
|
||||||
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}`
|
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text}\n${message.content}`
|
||||||
}
|
}
|
||||||
} else if (message.role === 'user' || message.role === 'assistant') {
|
} else if (message.role === 'user' || message.role === 'assistant') {
|
||||||
const geminiRole = message.role === 'user' ? 'user' : 'model'
|
const geminiRole = message.role === 'user' ? 'user' : 'model'
|
||||||
@@ -97,60 +164,200 @@ export function convertToGeminiFormat(request: ProviderRequest): {
|
|||||||
contents.push({ role: geminiRole, parts: [{ text: message.content }] })
|
contents.push({ role: geminiRole, parts: [{ text: message.content }] })
|
||||||
}
|
}
|
||||||
|
|
||||||
if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) {
|
if (message.role === 'assistant' && message.tool_calls?.length) {
|
||||||
const functionCalls = message.tool_calls.map((toolCall) => ({
|
const functionCalls = message.tool_calls.map((toolCall) => ({
|
||||||
functionCall: {
|
functionCall: {
|
||||||
name: toolCall.function?.name,
|
name: toolCall.function?.name,
|
||||||
args: JSON.parse(toolCall.function?.arguments || '{}'),
|
args: JSON.parse(toolCall.function?.arguments || '{}') as Record<string, unknown>,
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
contents.push({ role: 'model', parts: functionCalls })
|
contents.push({ role: 'model', parts: functionCalls })
|
||||||
}
|
}
|
||||||
} else if (message.role === 'tool') {
|
} else if (message.role === 'tool') {
|
||||||
|
if (!message.name) {
|
||||||
|
logger.warn('Tool message missing function name, skipping')
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
let responseData: Record<string, unknown>
|
||||||
|
try {
|
||||||
|
responseData = JSON.parse(message.content ?? '{}')
|
||||||
|
} catch {
|
||||||
|
responseData = { output: message.content }
|
||||||
|
}
|
||||||
contents.push({
|
contents.push({
|
||||||
role: 'user',
|
role: 'user',
|
||||||
parts: [{ text: `Function result: ${message.content}` }],
|
parts: [
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
id: message.tool_call_id,
|
||||||
|
name: message.name,
|
||||||
|
response: responseData,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const tools = request.tools?.map((tool) => {
|
const tools = request.tools?.map((tool): GeminiToolDef => {
|
||||||
const toolParameters = { ...(tool.parameters || {}) }
|
const toolParameters = { ...(tool.parameters || {}) }
|
||||||
|
|
||||||
if (toolParameters.properties) {
|
if (toolParameters.properties) {
|
||||||
const properties = { ...toolParameters.properties }
|
const properties = { ...toolParameters.properties }
|
||||||
const required = toolParameters.required ? [...toolParameters.required] : []
|
const required = toolParameters.required ? [...toolParameters.required] : []
|
||||||
|
|
||||||
|
// Remove default values from properties (not supported by Gemini)
|
||||||
for (const key in properties) {
|
for (const key in properties) {
|
||||||
const prop = properties[key] as any
|
const prop = properties[key] as Record<string, unknown>
|
||||||
|
|
||||||
if (prop.default !== undefined) {
|
if (prop.default !== undefined) {
|
||||||
const { default: _, ...cleanProp } = prop
|
const { default: _, ...cleanProp } = prop
|
||||||
properties[key] = cleanProp
|
properties[key] = cleanProp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const parameters = {
|
const parameters: Schema = {
|
||||||
type: toolParameters.type || 'object',
|
type: (toolParameters.type as Schema['type']) || Type.OBJECT,
|
||||||
properties,
|
properties: properties as Record<string, Schema>,
|
||||||
...(required.length > 0 ? { required } : {}),
|
...(required.length > 0 ? { required } : {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
name: tool.id,
|
name: tool.id,
|
||||||
description: tool.description || `Execute the ${tool.id} function`,
|
description: tool.description || `Execute the ${tool.id} function`,
|
||||||
parameters: cleanSchemaForGemini(parameters),
|
parameters: cleanSchemaForGemini(parameters) as Schema,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
name: tool.id,
|
name: tool.id,
|
||||||
description: tool.description || `Execute the ${tool.id} function`,
|
description: tool.description || `Execute the ${tool.id} function`,
|
||||||
parameters: cleanSchemaForGemini(toolParameters),
|
parameters: cleanSchemaForGemini(toolParameters) as Schema,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
return { contents, tools, systemInstruction }
|
return { contents, tools, systemInstruction }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a ReadableStream from a Google Gemini streaming response
|
||||||
|
*/
|
||||||
|
export function createReadableStreamFromGeminiStream(
|
||||||
|
stream: AsyncGenerator<GenerateContentResponse>,
|
||||||
|
onComplete?: (content: string, usage: GeminiUsage) => void
|
||||||
|
): ReadableStream<Uint8Array> {
|
||||||
|
let fullContent = ''
|
||||||
|
let usage: GeminiUsage = { promptTokenCount: 0, candidatesTokenCount: 0, totalTokenCount: 0 }
|
||||||
|
|
||||||
|
return new ReadableStream({
|
||||||
|
async start(controller) {
|
||||||
|
try {
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
if (chunk.usageMetadata) {
|
||||||
|
usage = convertUsageMetadata(chunk.usageMetadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
const text = chunk.text
|
||||||
|
if (text) {
|
||||||
|
fullContent += text
|
||||||
|
controller.enqueue(new TextEncoder().encode(text))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onComplete?.(fullContent, usage)
|
||||||
|
controller.close()
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error reading Google Gemini stream', {
|
||||||
|
error: error instanceof Error ? error.message : String(error),
|
||||||
|
})
|
||||||
|
controller.error(error)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps string mode to FunctionCallingConfigMode enum
|
||||||
|
*/
|
||||||
|
function mapToFunctionCallingMode(mode: string): FunctionCallingConfigMode {
|
||||||
|
switch (mode) {
|
||||||
|
case 'AUTO':
|
||||||
|
return FunctionCallingConfigMode.AUTO
|
||||||
|
case 'ANY':
|
||||||
|
return FunctionCallingConfigMode.ANY
|
||||||
|
case 'NONE':
|
||||||
|
return FunctionCallingConfigMode.NONE
|
||||||
|
default:
|
||||||
|
return FunctionCallingConfigMode.AUTO
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps string level to ThinkingLevel enum
|
||||||
|
*/
|
||||||
|
export function mapToThinkingLevel(level: string): ThinkingLevel {
|
||||||
|
switch (level.toLowerCase()) {
|
||||||
|
case 'minimal':
|
||||||
|
return ThinkingLevel.MINIMAL
|
||||||
|
case 'low':
|
||||||
|
return ThinkingLevel.LOW
|
||||||
|
case 'medium':
|
||||||
|
return ThinkingLevel.MEDIUM
|
||||||
|
case 'high':
|
||||||
|
return ThinkingLevel.HIGH
|
||||||
|
default:
|
||||||
|
return ThinkingLevel.HIGH
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Result of checking forced tool usage
|
||||||
|
*/
|
||||||
|
export interface ForcedToolResult {
|
||||||
|
hasUsedForcedTool: boolean
|
||||||
|
usedForcedTools: string[]
|
||||||
|
nextToolConfig: ToolConfig | undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks for forced tool usage in Google Gemini responses
|
||||||
|
*/
|
||||||
|
export function checkForForcedToolUsage(
|
||||||
|
functionCalls: FunctionCall[] | undefined,
|
||||||
|
toolConfig: ToolConfig | undefined,
|
||||||
|
forcedTools: string[],
|
||||||
|
usedForcedTools: string[]
|
||||||
|
): ForcedToolResult | null {
|
||||||
|
if (!functionCalls?.length) return null
|
||||||
|
|
||||||
|
const adaptedToolCalls = functionCalls.map((fc) => ({
|
||||||
|
name: fc.name ?? '',
|
||||||
|
arguments: (fc.args ?? {}) as Record<string, unknown>,
|
||||||
|
}))
|
||||||
|
|
||||||
|
const result = trackForcedToolUsage(
|
||||||
|
adaptedToolCalls,
|
||||||
|
toolConfig,
|
||||||
|
logger,
|
||||||
|
'google',
|
||||||
|
forcedTools,
|
||||||
|
usedForcedTools
|
||||||
|
)
|
||||||
|
|
||||||
|
if (!result) return null
|
||||||
|
|
||||||
|
const nextToolConfig: ToolConfig | undefined = result.nextToolConfig?.functionCallingConfig?.mode
|
||||||
|
? {
|
||||||
|
functionCallingConfig: {
|
||||||
|
mode: mapToFunctionCallingMode(result.nextToolConfig.functionCallingConfig.mode),
|
||||||
|
allowedFunctionNames: result.nextToolConfig.functionCallingConfig.allowedFunctionNames,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
return {
|
||||||
|
hasUsedForcedTool: result.hasUsedForcedTool,
|
||||||
|
usedForcedTools: result.usedForcedTools,
|
||||||
|
nextToolConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -69,10 +69,7 @@ export const groqProvider: ProviderConfig = {
|
|||||||
: undefined
|
: undefined
|
||||||
|
|
||||||
const payload: any = {
|
const payload: any = {
|
||||||
model: (request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct').replace(
|
model: request.model.replace('groq/', ''),
|
||||||
'groq/',
|
|
||||||
''
|
|
||||||
),
|
|
||||||
messages: allMessages,
|
messages: allMessages,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,7 +106,7 @@ export const groqProvider: ProviderConfig = {
|
|||||||
toolChoice: payload.tool_choice,
|
toolChoice: payload.tool_choice,
|
||||||
forcedToolsCount: forcedTools.length,
|
forcedToolsCount: forcedTools.length,
|
||||||
hasFilteredTools,
|
hasFilteredTools,
|
||||||
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
|
model: request.model,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -149,7 +146,7 @@ export const groqProvider: ProviderConfig = {
|
|||||||
success: true,
|
success: true,
|
||||||
output: {
|
output: {
|
||||||
content: '',
|
content: '',
|
||||||
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
|
model: request.model,
|
||||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||||
toolCalls: undefined,
|
toolCalls: undefined,
|
||||||
providerTiming: {
|
providerTiming: {
|
||||||
@@ -393,7 +390,7 @@ export const groqProvider: ProviderConfig = {
|
|||||||
const streamingPayload = {
|
const streamingPayload = {
|
||||||
...payload,
|
...payload,
|
||||||
messages: currentMessages,
|
messages: currentMessages,
|
||||||
tool_choice: 'auto',
|
tool_choice: originalToolChoice || 'auto',
|
||||||
stream: true,
|
stream: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -425,7 +422,7 @@ export const groqProvider: ProviderConfig = {
|
|||||||
success: true,
|
success: true,
|
||||||
output: {
|
output: {
|
||||||
content: '',
|
content: '',
|
||||||
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
|
model: request.model,
|
||||||
tokens: {
|
tokens: {
|
||||||
prompt: tokens.prompt,
|
prompt: tokens.prompt,
|
||||||
completion: tokens.completion,
|
completion: tokens.completion,
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import { getCostMultiplier } from '@/lib/core/config/feature-flags'
|
import { getCostMultiplier } from '@/lib/core/config/feature-flags'
|
||||||
import { createLogger } from '@/lib/logs/console/logger'
|
import { createLogger } from '@/lib/logs/console/logger'
|
||||||
import type { StreamingExecution } from '@/executor/types'
|
import type { StreamingExecution } from '@/executor/types'
|
||||||
import type { ProviderRequest, ProviderResponse } from '@/providers/types'
|
import { getProviderExecutor } from '@/providers/registry'
|
||||||
|
import type { ProviderId, ProviderRequest, ProviderResponse } from '@/providers/types'
|
||||||
import {
|
import {
|
||||||
calculateCost,
|
calculateCost,
|
||||||
generateStructuredOutputInstructions,
|
generateStructuredOutputInstructions,
|
||||||
getProvider,
|
|
||||||
shouldBillModelUsage,
|
shouldBillModelUsage,
|
||||||
supportsTemperature,
|
supportsTemperature,
|
||||||
} from '@/providers/utils'
|
} from '@/providers/utils'
|
||||||
@@ -40,7 +40,7 @@ export async function executeProviderRequest(
|
|||||||
providerId: string,
|
providerId: string,
|
||||||
request: ProviderRequest
|
request: ProviderRequest
|
||||||
): Promise<ProviderResponse | ReadableStream | StreamingExecution> {
|
): Promise<ProviderResponse | ReadableStream | StreamingExecution> {
|
||||||
const provider = getProvider(providerId)
|
const provider = await getProviderExecutor(providerId as ProviderId)
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
throw new Error(`Provider not found: ${providerId}`)
|
throw new Error(`Provider not found: ${providerId}`)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ export const mistralProvider: ProviderConfig = {
|
|||||||
request: ProviderRequest
|
request: ProviderRequest
|
||||||
): Promise<ProviderResponse | StreamingExecution> => {
|
): Promise<ProviderResponse | StreamingExecution> => {
|
||||||
logger.info('Preparing Mistral request', {
|
logger.info('Preparing Mistral request', {
|
||||||
model: request.model || 'mistral-large-latest',
|
model: request.model,
|
||||||
hasSystemPrompt: !!request.systemPrompt,
|
hasSystemPrompt: !!request.systemPrompt,
|
||||||
hasMessages: !!request.messages?.length,
|
hasMessages: !!request.messages?.length,
|
||||||
hasTools: !!request.tools?.length,
|
hasTools: !!request.tools?.length,
|
||||||
@@ -86,7 +86,7 @@ export const mistralProvider: ProviderConfig = {
|
|||||||
: undefined
|
: undefined
|
||||||
|
|
||||||
const payload: any = {
|
const payload: any = {
|
||||||
model: request.model || 'mistral-large-latest',
|
model: request.model,
|
||||||
messages: allMessages,
|
messages: allMessages,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,7 +126,7 @@ export const mistralProvider: ProviderConfig = {
|
|||||||
: toolChoice.type === 'any'
|
: toolChoice.type === 'any'
|
||||||
? `force:${toolChoice.any?.name || 'unknown'}`
|
? `force:${toolChoice.any?.name || 'unknown'}`
|
||||||
: 'unknown',
|
: 'unknown',
|
||||||
model: request.model || 'mistral-large-latest',
|
model: request.model,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,6 +39,10 @@ export interface ModelCapabilities {
|
|||||||
verbosity?: {
|
verbosity?: {
|
||||||
values: string[]
|
values: string[]
|
||||||
}
|
}
|
||||||
|
thinking?: {
|
||||||
|
levels: string[]
|
||||||
|
default?: string
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ModelDefinition {
|
export interface ModelDefinition {
|
||||||
@@ -730,6 +734,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
|||||||
},
|
},
|
||||||
capabilities: {
|
capabilities: {
|
||||||
temperature: { min: 0, max: 2 },
|
temperature: { min: 0, max: 2 },
|
||||||
|
thinking: {
|
||||||
|
levels: ['low', 'high'],
|
||||||
|
default: 'high',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
contextWindow: 1000000,
|
contextWindow: 1000000,
|
||||||
},
|
},
|
||||||
@@ -743,6 +751,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
|||||||
},
|
},
|
||||||
capabilities: {
|
capabilities: {
|
||||||
temperature: { min: 0, max: 2 },
|
temperature: { min: 0, max: 2 },
|
||||||
|
thinking: {
|
||||||
|
levels: ['minimal', 'low', 'medium', 'high'],
|
||||||
|
default: 'high',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
contextWindow: 1000000,
|
contextWindow: 1000000,
|
||||||
},
|
},
|
||||||
@@ -832,6 +844,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
|||||||
},
|
},
|
||||||
capabilities: {
|
capabilities: {
|
||||||
temperature: { min: 0, max: 2 },
|
temperature: { min: 0, max: 2 },
|
||||||
|
thinking: {
|
||||||
|
levels: ['low', 'high'],
|
||||||
|
default: 'high',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
contextWindow: 1000000,
|
contextWindow: 1000000,
|
||||||
},
|
},
|
||||||
@@ -845,6 +861,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
|||||||
},
|
},
|
||||||
capabilities: {
|
capabilities: {
|
||||||
temperature: { min: 0, max: 2 },
|
temperature: { min: 0, max: 2 },
|
||||||
|
thinking: {
|
||||||
|
levels: ['minimal', 'low', 'medium', 'high'],
|
||||||
|
default: 'high',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
contextWindow: 1000000,
|
contextWindow: 1000000,
|
||||||
},
|
},
|
||||||
@@ -1864,3 +1884,49 @@ export function supportsNativeStructuredOutputs(modelId: string): boolean {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a model supports thinking/reasoning features.
|
||||||
|
* Returns the thinking capability config if supported, null otherwise.
|
||||||
|
*/
|
||||||
|
export function getThinkingCapability(
|
||||||
|
modelId: string
|
||||||
|
): { levels: string[]; default?: string } | null {
|
||||||
|
const normalizedModelId = modelId.toLowerCase()
|
||||||
|
|
||||||
|
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||||
|
for (const model of provider.models) {
|
||||||
|
if (model.capabilities.thinking) {
|
||||||
|
const baseModelId = model.id.toLowerCase()
|
||||||
|
if (normalizedModelId === baseModelId || normalizedModelId.startsWith(`${baseModelId}-`)) {
|
||||||
|
return model.capabilities.thinking
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all models that support thinking capability
|
||||||
|
*/
|
||||||
|
export function getModelsWithThinking(): string[] {
|
||||||
|
const models: string[] = []
|
||||||
|
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||||
|
for (const model of provider.models) {
|
||||||
|
if (model.capabilities.thinking) {
|
||||||
|
models.push(model.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the thinking levels for a specific model
|
||||||
|
* Returns the valid levels for that model, or null if the model doesn't support thinking
|
||||||
|
*/
|
||||||
|
export function getThinkingLevelsForModel(modelId: string): string[] | null {
|
||||||
|
const capability = getThinkingCapability(modelId)
|
||||||
|
return capability?.levels ?? null
|
||||||
|
}
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ export const openaiProvider: ProviderConfig = {
|
|||||||
request: ProviderRequest
|
request: ProviderRequest
|
||||||
): Promise<ProviderResponse | StreamingExecution> => {
|
): Promise<ProviderResponse | StreamingExecution> => {
|
||||||
logger.info('Preparing OpenAI request', {
|
logger.info('Preparing OpenAI request', {
|
||||||
model: request.model || 'gpt-4o',
|
model: request.model,
|
||||||
hasSystemPrompt: !!request.systemPrompt,
|
hasSystemPrompt: !!request.systemPrompt,
|
||||||
hasMessages: !!request.messages?.length,
|
hasMessages: !!request.messages?.length,
|
||||||
hasTools: !!request.tools?.length,
|
hasTools: !!request.tools?.length,
|
||||||
@@ -76,7 +76,7 @@ export const openaiProvider: ProviderConfig = {
|
|||||||
: undefined
|
: undefined
|
||||||
|
|
||||||
const payload: any = {
|
const payload: any = {
|
||||||
model: request.model || 'gpt-4o',
|
model: request.model,
|
||||||
messages: allMessages,
|
messages: allMessages,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +121,7 @@ export const openaiProvider: ProviderConfig = {
|
|||||||
: toolChoice.type === 'any'
|
: toolChoice.type === 'any'
|
||||||
? `force:${toolChoice.any?.name || 'unknown'}`
|
? `force:${toolChoice.any?.name || 'unknown'}`
|
||||||
: 'unknown',
|
: 'unknown',
|
||||||
model: request.model || 'gpt-4o',
|
model: request.model,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ export const openRouterProvider: ProviderConfig = {
|
|||||||
baseURL: 'https://openrouter.ai/api/v1',
|
baseURL: 'https://openrouter.ai/api/v1',
|
||||||
})
|
})
|
||||||
|
|
||||||
const requestedModel = (request.model || '').replace(/^openrouter\//, '')
|
const requestedModel = request.model.replace(/^openrouter\//, '')
|
||||||
|
|
||||||
logger.info('Preparing OpenRouter request', {
|
logger.info('Preparing OpenRouter request', {
|
||||||
model: requestedModel,
|
model: requestedModel,
|
||||||
|
|||||||
59
apps/sim/providers/registry.ts
Normal file
59
apps/sim/providers/registry.ts
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import { createLogger } from '@/lib/logs/console/logger'
|
||||||
|
import { anthropicProvider } from '@/providers/anthropic'
|
||||||
|
import { azureOpenAIProvider } from '@/providers/azure-openai'
|
||||||
|
import { cerebrasProvider } from '@/providers/cerebras'
|
||||||
|
import { deepseekProvider } from '@/providers/deepseek'
|
||||||
|
import { googleProvider } from '@/providers/google'
|
||||||
|
import { groqProvider } from '@/providers/groq'
|
||||||
|
import { mistralProvider } from '@/providers/mistral'
|
||||||
|
import { ollamaProvider } from '@/providers/ollama'
|
||||||
|
import { openaiProvider } from '@/providers/openai'
|
||||||
|
import { openRouterProvider } from '@/providers/openrouter'
|
||||||
|
import type { ProviderConfig, ProviderId } from '@/providers/types'
|
||||||
|
import { vertexProvider } from '@/providers/vertex'
|
||||||
|
import { vllmProvider } from '@/providers/vllm'
|
||||||
|
import { xAIProvider } from '@/providers/xai'
|
||||||
|
|
||||||
|
const logger = createLogger('ProviderRegistry')
|
||||||
|
|
||||||
|
const providerRegistry: Record<ProviderId, ProviderConfig> = {
|
||||||
|
openai: openaiProvider,
|
||||||
|
anthropic: anthropicProvider,
|
||||||
|
google: googleProvider,
|
||||||
|
vertex: vertexProvider,
|
||||||
|
deepseek: deepseekProvider,
|
||||||
|
xai: xAIProvider,
|
||||||
|
cerebras: cerebrasProvider,
|
||||||
|
groq: groqProvider,
|
||||||
|
vllm: vllmProvider,
|
||||||
|
mistral: mistralProvider,
|
||||||
|
'azure-openai': azureOpenAIProvider,
|
||||||
|
openrouter: openRouterProvider,
|
||||||
|
ollama: ollamaProvider,
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getProviderExecutor(
|
||||||
|
providerId: ProviderId
|
||||||
|
): Promise<ProviderConfig | undefined> {
|
||||||
|
const provider = providerRegistry[providerId]
|
||||||
|
if (!provider) {
|
||||||
|
logger.error(`Provider not found: ${providerId}`)
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
return provider
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function initializeProviders(): Promise<void> {
|
||||||
|
for (const [id, provider] of Object.entries(providerRegistry)) {
|
||||||
|
if (provider.initialize) {
|
||||||
|
try {
|
||||||
|
await provider.initialize()
|
||||||
|
logger.info(`Initialized provider: ${id}`)
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`Failed to initialize ${id} provider`, {
|
||||||
|
error: error instanceof Error ? error.message : 'Unknown error',
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -164,6 +164,7 @@ export interface ProviderRequest {
|
|||||||
vertexLocation?: string
|
vertexLocation?: string
|
||||||
reasoningEffort?: string
|
reasoningEffort?: string
|
||||||
verbosity?: string
|
verbosity?: string
|
||||||
|
thinkingLevel?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export const providers: Record<string, ProviderConfig> = {}
|
export const providers: Record<string, ProviderConfig> = {}
|
||||||
|
|||||||
@@ -3,13 +3,6 @@ import type { CompletionUsage } from 'openai/resources/completions'
|
|||||||
import { getEnv, isTruthy } from '@/lib/core/config/env'
|
import { getEnv, isTruthy } from '@/lib/core/config/env'
|
||||||
import { isHosted } from '@/lib/core/config/feature-flags'
|
import { isHosted } from '@/lib/core/config/feature-flags'
|
||||||
import { createLogger, type Logger } from '@/lib/logs/console/logger'
|
import { createLogger, type Logger } from '@/lib/logs/console/logger'
|
||||||
import { anthropicProvider } from '@/providers/anthropic'
|
|
||||||
import { azureOpenAIProvider } from '@/providers/azure-openai'
|
|
||||||
import { cerebrasProvider } from '@/providers/cerebras'
|
|
||||||
import { deepseekProvider } from '@/providers/deepseek'
|
|
||||||
import { googleProvider } from '@/providers/google'
|
|
||||||
import { groqProvider } from '@/providers/groq'
|
|
||||||
import { mistralProvider } from '@/providers/mistral'
|
|
||||||
import {
|
import {
|
||||||
getComputerUseModels,
|
getComputerUseModels,
|
||||||
getEmbeddingModelPricing,
|
getEmbeddingModelPricing,
|
||||||
@@ -20,117 +13,82 @@ import {
|
|||||||
getModelsWithTemperatureSupport,
|
getModelsWithTemperatureSupport,
|
||||||
getModelsWithTempRange01,
|
getModelsWithTempRange01,
|
||||||
getModelsWithTempRange02,
|
getModelsWithTempRange02,
|
||||||
|
getModelsWithThinking,
|
||||||
getModelsWithVerbosity,
|
getModelsWithVerbosity,
|
||||||
|
getProviderDefaultModel as getProviderDefaultModelFromDefinitions,
|
||||||
getProviderModels as getProviderModelsFromDefinitions,
|
getProviderModels as getProviderModelsFromDefinitions,
|
||||||
getProvidersWithToolUsageControl,
|
getProvidersWithToolUsageControl,
|
||||||
getReasoningEffortValuesForModel as getReasoningEffortValuesForModelFromDefinitions,
|
getReasoningEffortValuesForModel as getReasoningEffortValuesForModelFromDefinitions,
|
||||||
|
getThinkingLevelsForModel as getThinkingLevelsForModelFromDefinitions,
|
||||||
getVerbosityValuesForModel as getVerbosityValuesForModelFromDefinitions,
|
getVerbosityValuesForModel as getVerbosityValuesForModelFromDefinitions,
|
||||||
PROVIDER_DEFINITIONS,
|
PROVIDER_DEFINITIONS,
|
||||||
supportsTemperature as supportsTemperatureFromDefinitions,
|
supportsTemperature as supportsTemperatureFromDefinitions,
|
||||||
supportsToolUsageControl as supportsToolUsageControlFromDefinitions,
|
supportsToolUsageControl as supportsToolUsageControlFromDefinitions,
|
||||||
updateOllamaModels as updateOllamaModelsInDefinitions,
|
updateOllamaModels as updateOllamaModelsInDefinitions,
|
||||||
} from '@/providers/models'
|
} from '@/providers/models'
|
||||||
import { ollamaProvider } from '@/providers/ollama'
|
import type { ProviderId, ProviderToolConfig } from '@/providers/types'
|
||||||
import { openaiProvider } from '@/providers/openai'
|
|
||||||
import { openRouterProvider } from '@/providers/openrouter'
|
|
||||||
import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types'
|
|
||||||
import { vertexProvider } from '@/providers/vertex'
|
|
||||||
import { vllmProvider } from '@/providers/vllm'
|
|
||||||
import { xAIProvider } from '@/providers/xai'
|
|
||||||
import { useCustomToolsStore } from '@/stores/custom-tools/store'
|
import { useCustomToolsStore } from '@/stores/custom-tools/store'
|
||||||
import { useProvidersStore } from '@/stores/providers/store'
|
import { useProvidersStore } from '@/stores/providers/store'
|
||||||
|
|
||||||
const logger = createLogger('ProviderUtils')
|
const logger = createLogger('ProviderUtils')
|
||||||
|
|
||||||
export const providers: Record<
|
/**
|
||||||
ProviderId,
|
* Client-safe provider metadata.
|
||||||
ProviderConfig & {
|
* This object contains only model lists and patterns - no executeRequest implementations.
|
||||||
models: string[]
|
* For server-side execution, use @/providers/registry.
|
||||||
computerUseModels?: string[]
|
*/
|
||||||
modelPatterns?: RegExp[]
|
export interface ProviderMetadata {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
description: string
|
||||||
|
version: string
|
||||||
|
models: string[]
|
||||||
|
defaultModel: string
|
||||||
|
computerUseModels?: string[]
|
||||||
|
modelPatterns?: RegExp[]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build provider metadata from PROVIDER_DEFINITIONS.
|
||||||
|
* This is client-safe as it doesn't import any provider implementations.
|
||||||
|
*/
|
||||||
|
function buildProviderMetadata(providerId: ProviderId): ProviderMetadata {
|
||||||
|
const def = PROVIDER_DEFINITIONS[providerId]
|
||||||
|
return {
|
||||||
|
id: providerId,
|
||||||
|
name: def?.name || providerId,
|
||||||
|
description: def?.description || '',
|
||||||
|
version: '1.0.0',
|
||||||
|
models: getProviderModelsFromDefinitions(providerId),
|
||||||
|
defaultModel: getProviderDefaultModelFromDefinitions(providerId),
|
||||||
|
modelPatterns: def?.modelPatterns,
|
||||||
}
|
}
|
||||||
> = {
|
}
|
||||||
|
|
||||||
|
export const providers: Record<ProviderId, ProviderMetadata> = {
|
||||||
openai: {
|
openai: {
|
||||||
...openaiProvider,
|
...buildProviderMetadata('openai'),
|
||||||
models: getProviderModelsFromDefinitions('openai'),
|
|
||||||
computerUseModels: ['computer-use-preview'],
|
computerUseModels: ['computer-use-preview'],
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.openai.modelPatterns,
|
|
||||||
},
|
},
|
||||||
anthropic: {
|
anthropic: {
|
||||||
...anthropicProvider,
|
...buildProviderMetadata('anthropic'),
|
||||||
models: getProviderModelsFromDefinitions('anthropic'),
|
|
||||||
computerUseModels: getComputerUseModels().filter((model) =>
|
computerUseModels: getComputerUseModels().filter((model) =>
|
||||||
getProviderModelsFromDefinitions('anthropic').includes(model)
|
getProviderModelsFromDefinitions('anthropic').includes(model)
|
||||||
),
|
),
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.anthropic.modelPatterns,
|
|
||||||
},
|
|
||||||
google: {
|
|
||||||
...googleProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('google'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.google.modelPatterns,
|
|
||||||
},
|
|
||||||
vertex: {
|
|
||||||
...vertexProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('vertex'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.vertex.modelPatterns,
|
|
||||||
},
|
|
||||||
deepseek: {
|
|
||||||
...deepseekProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('deepseek'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.deepseek.modelPatterns,
|
|
||||||
},
|
|
||||||
xai: {
|
|
||||||
...xAIProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('xai'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.xai.modelPatterns,
|
|
||||||
},
|
|
||||||
cerebras: {
|
|
||||||
...cerebrasProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('cerebras'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.cerebras.modelPatterns,
|
|
||||||
},
|
|
||||||
groq: {
|
|
||||||
...groqProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('groq'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.groq.modelPatterns,
|
|
||||||
},
|
|
||||||
vllm: {
|
|
||||||
...vllmProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('vllm'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.vllm.modelPatterns,
|
|
||||||
},
|
|
||||||
mistral: {
|
|
||||||
...mistralProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('mistral'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.mistral.modelPatterns,
|
|
||||||
},
|
|
||||||
'azure-openai': {
|
|
||||||
...azureOpenAIProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('azure-openai'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS['azure-openai'].modelPatterns,
|
|
||||||
},
|
|
||||||
openrouter: {
|
|
||||||
...openRouterProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('openrouter'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.openrouter.modelPatterns,
|
|
||||||
},
|
|
||||||
ollama: {
|
|
||||||
...ollamaProvider,
|
|
||||||
models: getProviderModelsFromDefinitions('ollama'),
|
|
||||||
modelPatterns: PROVIDER_DEFINITIONS.ollama.modelPatterns,
|
|
||||||
},
|
},
|
||||||
|
google: buildProviderMetadata('google'),
|
||||||
|
vertex: buildProviderMetadata('vertex'),
|
||||||
|
deepseek: buildProviderMetadata('deepseek'),
|
||||||
|
xai: buildProviderMetadata('xai'),
|
||||||
|
cerebras: buildProviderMetadata('cerebras'),
|
||||||
|
groq: buildProviderMetadata('groq'),
|
||||||
|
vllm: buildProviderMetadata('vllm'),
|
||||||
|
mistral: buildProviderMetadata('mistral'),
|
||||||
|
'azure-openai': buildProviderMetadata('azure-openai'),
|
||||||
|
openrouter: buildProviderMetadata('openrouter'),
|
||||||
|
ollama: buildProviderMetadata('ollama'),
|
||||||
}
|
}
|
||||||
|
|
||||||
Object.entries(providers).forEach(([id, provider]) => {
|
|
||||||
if (provider.initialize) {
|
|
||||||
provider.initialize().catch((error) => {
|
|
||||||
logger.error(`Failed to initialize ${id} provider`, {
|
|
||||||
error: error instanceof Error ? error.message : 'Unknown error',
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
export function updateOllamaProviderModels(models: string[]): void {
|
export function updateOllamaProviderModels(models: string[]): void {
|
||||||
updateOllamaModelsInDefinitions(models)
|
updateOllamaModelsInDefinitions(models)
|
||||||
providers.ollama.models = getProviderModelsFromDefinitions('ollama')
|
providers.ollama.models = getProviderModelsFromDefinitions('ollama')
|
||||||
@@ -211,12 +169,12 @@ export function getProviderFromModel(model: string): ProviderId {
|
|||||||
return 'ollama'
|
return 'ollama'
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getProvider(id: string): ProviderConfig | undefined {
|
export function getProvider(id: string): ProviderMetadata | undefined {
|
||||||
const providerId = id.split('/')[0] as ProviderId
|
const providerId = id.split('/')[0] as ProviderId
|
||||||
return providers[providerId]
|
return providers[providerId]
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getProviderConfigFromModel(model: string): ProviderConfig | undefined {
|
export function getProviderConfigFromModel(model: string): ProviderMetadata | undefined {
|
||||||
const providerId = getProviderFromModel(model)
|
const providerId = getProviderFromModel(model)
|
||||||
return providers[providerId]
|
return providers[providerId]
|
||||||
}
|
}
|
||||||
@@ -929,6 +887,7 @@ export const MODELS_TEMP_RANGE_0_1 = getModelsWithTempRange01()
|
|||||||
export const MODELS_WITH_TEMPERATURE_SUPPORT = getModelsWithTemperatureSupport()
|
export const MODELS_WITH_TEMPERATURE_SUPPORT = getModelsWithTemperatureSupport()
|
||||||
export const MODELS_WITH_REASONING_EFFORT = getModelsWithReasoningEffort()
|
export const MODELS_WITH_REASONING_EFFORT = getModelsWithReasoningEffort()
|
||||||
export const MODELS_WITH_VERBOSITY = getModelsWithVerbosity()
|
export const MODELS_WITH_VERBOSITY = getModelsWithVerbosity()
|
||||||
|
export const MODELS_WITH_THINKING = getModelsWithThinking()
|
||||||
export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = getProvidersWithToolUsageControl()
|
export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = getProvidersWithToolUsageControl()
|
||||||
|
|
||||||
export function supportsTemperature(model: string): boolean {
|
export function supportsTemperature(model: string): boolean {
|
||||||
@@ -963,6 +922,14 @@ export function getVerbosityValuesForModel(model: string): string[] | null {
|
|||||||
return getVerbosityValuesForModelFromDefinitions(model)
|
return getVerbosityValuesForModelFromDefinitions(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get thinking levels for a specific model
|
||||||
|
* Returns the valid levels for that model, or null if the model doesn't support thinking
|
||||||
|
*/
|
||||||
|
export function getThinkingLevelsForModel(model: string): string[] | null {
|
||||||
|
return getThinkingLevelsForModelFromDefinitions(model)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Prepare tool execution parameters, separating tool parameters from system parameters
|
* Prepare tool execution parameters, separating tool parameters from system parameters
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -1,33 +1,23 @@
|
|||||||
|
import { GoogleGenAI } from '@google/genai'
|
||||||
|
import { OAuth2Client } from 'google-auth-library'
|
||||||
import { env } from '@/lib/core/config/env'
|
import { env } from '@/lib/core/config/env'
|
||||||
import { createLogger } from '@/lib/logs/console/logger'
|
import { createLogger } from '@/lib/logs/console/logger'
|
||||||
import type { StreamingExecution } from '@/executor/types'
|
import type { StreamingExecution } from '@/executor/types'
|
||||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
import { executeGeminiRequest } from '@/providers/gemini/core'
|
||||||
import {
|
|
||||||
cleanSchemaForGemini,
|
|
||||||
convertToGeminiFormat,
|
|
||||||
extractFunctionCall,
|
|
||||||
extractTextContent,
|
|
||||||
} from '@/providers/google/utils'
|
|
||||||
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||||
import type {
|
import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types'
|
||||||
ProviderConfig,
|
|
||||||
ProviderRequest,
|
|
||||||
ProviderResponse,
|
|
||||||
TimeSegment,
|
|
||||||
} from '@/providers/types'
|
|
||||||
import {
|
|
||||||
calculateCost,
|
|
||||||
prepareToolExecution,
|
|
||||||
prepareToolsWithUsageControl,
|
|
||||||
trackForcedToolUsage,
|
|
||||||
} from '@/providers/utils'
|
|
||||||
import { buildVertexEndpoint, createReadableStreamFromVertexStream } from '@/providers/vertex/utils'
|
|
||||||
import { executeTool } from '@/tools'
|
|
||||||
|
|
||||||
const logger = createLogger('VertexProvider')
|
const logger = createLogger('VertexProvider')
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Vertex AI provider configuration
|
* Vertex AI provider
|
||||||
|
*
|
||||||
|
* Uses the @google/genai SDK with Vertex AI backend and OAuth authentication.
|
||||||
|
* Shares core execution logic with Google Gemini provider.
|
||||||
|
*
|
||||||
|
* Authentication:
|
||||||
|
* - Uses OAuth access token passed via googleAuthOptions.authClient
|
||||||
|
* - Token refresh is handled at the OAuth layer before calling this provider
|
||||||
*/
|
*/
|
||||||
export const vertexProvider: ProviderConfig = {
|
export const vertexProvider: ProviderConfig = {
|
||||||
id: 'vertex',
|
id: 'vertex',
|
||||||
@@ -55,869 +45,35 @@ export const vertexProvider: ProviderConfig = {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info('Preparing Vertex AI request', {
|
// Strip 'vertex/' prefix from model name if present
|
||||||
model: request.model || 'vertex/gemini-2.5-pro',
|
const model = request.model.replace('vertex/', '')
|
||||||
hasSystemPrompt: !!request.systemPrompt,
|
|
||||||
hasMessages: !!request.messages?.length,
|
logger.info('Creating Vertex AI client', {
|
||||||
hasTools: !!request.tools?.length,
|
|
||||||
toolCount: request.tools?.length || 0,
|
|
||||||
hasResponseFormat: !!request.responseFormat,
|
|
||||||
streaming: !!request.stream,
|
|
||||||
project: vertexProject,
|
project: vertexProject,
|
||||||
location: vertexLocation,
|
location: vertexLocation,
|
||||||
|
model,
|
||||||
})
|
})
|
||||||
|
|
||||||
const providerStartTime = Date.now()
|
// Create an OAuth2Client and set the access token
|
||||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
// This allows us to use an OAuth access token with the SDK
|
||||||
|
const authClient = new OAuth2Client()
|
||||||
try {
|
authClient.setCredentials({ access_token: request.apiKey })
|
||||||
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
|
|
||||||
|
// Create client with Vertex AI configuration
|
||||||
const requestedModel = (request.model || 'vertex/gemini-2.5-pro').replace('vertex/', '')
|
const ai = new GoogleGenAI({
|
||||||
|
vertexai: true,
|
||||||
const payload: any = {
|
project: vertexProject,
|
||||||
contents,
|
location: vertexLocation,
|
||||||
generationConfig: {},
|
googleAuthOptions: {
|
||||||
}
|
authClient,
|
||||||
|
},
|
||||||
if (request.temperature !== undefined && request.temperature !== null) {
|
})
|
||||||
payload.generationConfig.temperature = request.temperature
|
|
||||||
}
|
return executeGeminiRequest({
|
||||||
|
ai,
|
||||||
if (request.maxTokens !== undefined) {
|
model,
|
||||||
payload.generationConfig.maxOutputTokens = request.maxTokens
|
request,
|
||||||
}
|
providerType: 'vertex',
|
||||||
|
})
|
||||||
if (systemInstruction) {
|
|
||||||
payload.systemInstruction = systemInstruction
|
|
||||||
}
|
|
||||||
|
|
||||||
if (request.responseFormat && !tools?.length) {
|
|
||||||
const responseFormatSchema = request.responseFormat.schema || request.responseFormat
|
|
||||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
|
||||||
|
|
||||||
payload.generationConfig.responseMimeType = 'application/json'
|
|
||||||
payload.generationConfig.responseSchema = cleanSchema
|
|
||||||
|
|
||||||
logger.info('Using Vertex AI native structured output format', {
|
|
||||||
hasSchema: !!cleanSchema,
|
|
||||||
mimeType: 'application/json',
|
|
||||||
})
|
|
||||||
} else if (request.responseFormat && tools?.length) {
|
|
||||||
logger.warn(
|
|
||||||
'Vertex AI does not support structured output (responseFormat) with function calling (tools). Structured output will be ignored.'
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
|
||||||
|
|
||||||
if (tools?.length) {
|
|
||||||
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'google')
|
|
||||||
const { tools: filteredTools, toolConfig } = preparedTools
|
|
||||||
|
|
||||||
if (filteredTools?.length) {
|
|
||||||
payload.tools = [
|
|
||||||
{
|
|
||||||
functionDeclarations: filteredTools,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
if (toolConfig) {
|
|
||||||
payload.toolConfig = toolConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info('Vertex AI request with tools:', {
|
|
||||||
toolCount: filteredTools.length,
|
|
||||||
model: requestedModel,
|
|
||||||
tools: filteredTools.map((t) => t.name),
|
|
||||||
hasToolConfig: !!toolConfig,
|
|
||||||
toolConfig: toolConfig,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const initialCallTime = Date.now()
|
|
||||||
const shouldStream = !!(request.stream && !tools?.length)
|
|
||||||
|
|
||||||
const endpoint = buildVertexEndpoint(
|
|
||||||
vertexProject,
|
|
||||||
vertexLocation,
|
|
||||||
requestedModel,
|
|
||||||
shouldStream
|
|
||||||
)
|
|
||||||
|
|
||||||
if (request.stream && tools?.length) {
|
|
||||||
logger.info('Streaming disabled for initial request due to tools presence', {
|
|
||||||
toolCount: tools.length,
|
|
||||||
willStreamAfterTools: true,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const response = await fetch(endpoint, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
Authorization: `Bearer ${request.apiKey}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify(payload),
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
const responseText = await response.text()
|
|
||||||
logger.error('Vertex AI API error details:', {
|
|
||||||
status: response.status,
|
|
||||||
statusText: response.statusText,
|
|
||||||
responseBody: responseText,
|
|
||||||
})
|
|
||||||
throw new Error(`Vertex AI API error: ${response.status} ${response.statusText}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
const firstResponseTime = Date.now() - initialCallTime
|
|
||||||
|
|
||||||
if (shouldStream) {
|
|
||||||
logger.info('Handling Vertex AI streaming response')
|
|
||||||
|
|
||||||
const streamingResult: StreamingExecution = {
|
|
||||||
stream: null as any,
|
|
||||||
execution: {
|
|
||||||
success: true,
|
|
||||||
output: {
|
|
||||||
content: '',
|
|
||||||
model: request.model,
|
|
||||||
tokens: {
|
|
||||||
prompt: 0,
|
|
||||||
completion: 0,
|
|
||||||
total: 0,
|
|
||||||
},
|
|
||||||
providerTiming: {
|
|
||||||
startTime: providerStartTimeISO,
|
|
||||||
endTime: new Date().toISOString(),
|
|
||||||
duration: firstResponseTime,
|
|
||||||
modelTime: firstResponseTime,
|
|
||||||
toolsTime: 0,
|
|
||||||
firstResponseTime,
|
|
||||||
iterations: 1,
|
|
||||||
timeSegments: [
|
|
||||||
{
|
|
||||||
type: 'model',
|
|
||||||
name: 'Initial streaming response',
|
|
||||||
startTime: initialCallTime,
|
|
||||||
endTime: initialCallTime + firstResponseTime,
|
|
||||||
duration: firstResponseTime,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
logs: [],
|
|
||||||
metadata: {
|
|
||||||
startTime: providerStartTimeISO,
|
|
||||||
endTime: new Date().toISOString(),
|
|
||||||
duration: firstResponseTime,
|
|
||||||
},
|
|
||||||
isStreaming: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
streamingResult.stream = createReadableStreamFromVertexStream(
|
|
||||||
response,
|
|
||||||
(content, usage) => {
|
|
||||||
streamingResult.execution.output.content = content
|
|
||||||
|
|
||||||
const streamEndTime = Date.now()
|
|
||||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
|
||||||
|
|
||||||
if (streamingResult.execution.output.providerTiming) {
|
|
||||||
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
|
|
||||||
streamingResult.execution.output.providerTiming.duration =
|
|
||||||
streamEndTime - providerStartTime
|
|
||||||
|
|
||||||
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
|
|
||||||
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
|
|
||||||
streamEndTime
|
|
||||||
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
|
|
||||||
streamEndTime - providerStartTime
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const promptTokens = usage?.promptTokenCount || 0
|
|
||||||
const completionTokens = usage?.candidatesTokenCount || 0
|
|
||||||
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
|
|
||||||
|
|
||||||
streamingResult.execution.output.tokens = {
|
|
||||||
prompt: promptTokens,
|
|
||||||
completion: completionTokens,
|
|
||||||
total: totalTokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
const costResult = calculateCost(request.model, promptTokens, completionTokens)
|
|
||||||
streamingResult.execution.output.cost = {
|
|
||||||
input: costResult.input,
|
|
||||||
output: costResult.output,
|
|
||||||
total: costResult.total,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return streamingResult
|
|
||||||
}
|
|
||||||
|
|
||||||
let geminiResponse = await response.json()
|
|
||||||
|
|
||||||
if (payload.generationConfig?.responseSchema) {
|
|
||||||
const candidate = geminiResponse.candidates?.[0]
|
|
||||||
if (candidate?.content?.parts?.[0]?.text) {
|
|
||||||
const text = candidate.content.parts[0].text
|
|
||||||
try {
|
|
||||||
JSON.parse(text)
|
|
||||||
logger.info('Successfully received structured JSON output')
|
|
||||||
} catch (_e) {
|
|
||||||
logger.warn('Failed to parse structured output as JSON')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let content = ''
|
|
||||||
let tokens = {
|
|
||||||
prompt: 0,
|
|
||||||
completion: 0,
|
|
||||||
total: 0,
|
|
||||||
}
|
|
||||||
const toolCalls = []
|
|
||||||
const toolResults = []
|
|
||||||
let iterationCount = 0
|
|
||||||
|
|
||||||
const originalToolConfig = preparedTools?.toolConfig
|
|
||||||
const forcedTools = preparedTools?.forcedTools || []
|
|
||||||
let usedForcedTools: string[] = []
|
|
||||||
let hasUsedForcedTool = false
|
|
||||||
let currentToolConfig = originalToolConfig
|
|
||||||
|
|
||||||
const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => {
|
|
||||||
if (currentToolConfig && forcedTools.length > 0) {
|
|
||||||
const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }]
|
|
||||||
const result = trackForcedToolUsage(
|
|
||||||
toolCallsForTracking,
|
|
||||||
currentToolConfig,
|
|
||||||
logger,
|
|
||||||
'google',
|
|
||||||
forcedTools,
|
|
||||||
usedForcedTools
|
|
||||||
)
|
|
||||||
hasUsedForcedTool = result.hasUsedForcedTool
|
|
||||||
usedForcedTools = result.usedForcedTools
|
|
||||||
|
|
||||||
if (result.nextToolConfig) {
|
|
||||||
currentToolConfig = result.nextToolConfig
|
|
||||||
logger.info('Updated tool config for next iteration', {
|
|
||||||
hasNextToolConfig: !!currentToolConfig,
|
|
||||||
usedForcedTools: usedForcedTools,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let modelTime = firstResponseTime
|
|
||||||
let toolsTime = 0
|
|
||||||
|
|
||||||
const timeSegments: TimeSegment[] = [
|
|
||||||
{
|
|
||||||
type: 'model',
|
|
||||||
name: 'Initial response',
|
|
||||||
startTime: initialCallTime,
|
|
||||||
endTime: initialCallTime + firstResponseTime,
|
|
||||||
duration: firstResponseTime,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
try {
|
|
||||||
const candidate = geminiResponse.candidates?.[0]
|
|
||||||
|
|
||||||
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
|
|
||||||
logger.warn(
|
|
||||||
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
|
|
||||||
{
|
|
||||||
finishReason: candidate.finishReason,
|
|
||||||
hasContent: !!candidate?.content,
|
|
||||||
hasParts: !!candidate?.content?.parts,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
content = extractTextContent(candidate)
|
|
||||||
}
|
|
||||||
|
|
||||||
const functionCall = extractFunctionCall(candidate)
|
|
||||||
|
|
||||||
if (functionCall) {
|
|
||||||
logger.info(`Received function call from Vertex AI: ${functionCall.name}`)
|
|
||||||
|
|
||||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
|
||||||
const latestResponse = geminiResponse.candidates?.[0]
|
|
||||||
const latestFunctionCall = extractFunctionCall(latestResponse)
|
|
||||||
|
|
||||||
if (!latestFunctionCall) {
|
|
||||||
content = extractTextContent(latestResponse)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
`Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
|
||||||
)
|
|
||||||
|
|
||||||
const toolsStartTime = Date.now()
|
|
||||||
|
|
||||||
try {
|
|
||||||
const toolName = latestFunctionCall.name
|
|
||||||
const toolArgs = latestFunctionCall.args || {}
|
|
||||||
|
|
||||||
const tool = request.tools?.find((t) => t.id === toolName)
|
|
||||||
if (!tool) {
|
|
||||||
logger.warn(`Tool ${toolName} not found in registry, skipping`)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
const toolCallStartTime = Date.now()
|
|
||||||
|
|
||||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
|
||||||
const result = await executeTool(toolName, executionParams, true)
|
|
||||||
const toolCallEndTime = Date.now()
|
|
||||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
|
||||||
|
|
||||||
timeSegments.push({
|
|
||||||
type: 'tool',
|
|
||||||
name: toolName,
|
|
||||||
startTime: toolCallStartTime,
|
|
||||||
endTime: toolCallEndTime,
|
|
||||||
duration: toolCallDuration,
|
|
||||||
})
|
|
||||||
|
|
||||||
let resultContent: any
|
|
||||||
if (result.success) {
|
|
||||||
toolResults.push(result.output)
|
|
||||||
resultContent = result.output
|
|
||||||
} else {
|
|
||||||
resultContent = {
|
|
||||||
error: true,
|
|
||||||
message: result.error || 'Tool execution failed',
|
|
||||||
tool: toolName,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
toolCalls.push({
|
|
||||||
name: toolName,
|
|
||||||
arguments: toolParams,
|
|
||||||
startTime: new Date(toolCallStartTime).toISOString(),
|
|
||||||
endTime: new Date(toolCallEndTime).toISOString(),
|
|
||||||
duration: toolCallDuration,
|
|
||||||
result: resultContent,
|
|
||||||
success: result.success,
|
|
||||||
})
|
|
||||||
|
|
||||||
const simplifiedMessages = [
|
|
||||||
...(contents.filter((m) => m.role === 'user').length > 0
|
|
||||||
? [contents.filter((m) => m.role === 'user')[0]]
|
|
||||||
: [contents[0]]),
|
|
||||||
{
|
|
||||||
role: 'model',
|
|
||||||
parts: [
|
|
||||||
{
|
|
||||||
functionCall: {
|
|
||||||
name: latestFunctionCall.name,
|
|
||||||
args: latestFunctionCall.args,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: 'user',
|
|
||||||
parts: [
|
|
||||||
{
|
|
||||||
text: `Function ${latestFunctionCall.name} result: ${JSON.stringify(resultContent)}`,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
const thisToolsTime = Date.now() - toolsStartTime
|
|
||||||
toolsTime += thisToolsTime
|
|
||||||
|
|
||||||
checkForForcedToolUsage(latestFunctionCall)
|
|
||||||
|
|
||||||
const nextModelStartTime = Date.now()
|
|
||||||
|
|
||||||
try {
|
|
||||||
if (request.stream) {
|
|
||||||
const streamingPayload = {
|
|
||||||
...payload,
|
|
||||||
contents: simplifiedMessages,
|
|
||||||
}
|
|
||||||
|
|
||||||
const allForcedToolsUsed =
|
|
||||||
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
|
|
||||||
|
|
||||||
if (allForcedToolsUsed && request.responseFormat) {
|
|
||||||
streamingPayload.tools = undefined
|
|
||||||
streamingPayload.toolConfig = undefined
|
|
||||||
|
|
||||||
const responseFormatSchema =
|
|
||||||
request.responseFormat.schema || request.responseFormat
|
|
||||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
|
||||||
|
|
||||||
if (!streamingPayload.generationConfig) {
|
|
||||||
streamingPayload.generationConfig = {}
|
|
||||||
}
|
|
||||||
streamingPayload.generationConfig.responseMimeType = 'application/json'
|
|
||||||
streamingPayload.generationConfig.responseSchema = cleanSchema
|
|
||||||
|
|
||||||
logger.info('Using structured output for final response after tool execution')
|
|
||||||
} else {
|
|
||||||
if (currentToolConfig) {
|
|
||||||
streamingPayload.toolConfig = currentToolConfig
|
|
||||||
} else {
|
|
||||||
streamingPayload.toolConfig = { functionCallingConfig: { mode: 'AUTO' } }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const checkPayload = {
|
|
||||||
...streamingPayload,
|
|
||||||
}
|
|
||||||
checkPayload.stream = undefined
|
|
||||||
|
|
||||||
const checkEndpoint = buildVertexEndpoint(
|
|
||||||
vertexProject,
|
|
||||||
vertexLocation,
|
|
||||||
requestedModel,
|
|
||||||
false
|
|
||||||
)
|
|
||||||
|
|
||||||
const checkResponse = await fetch(checkEndpoint, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
Authorization: `Bearer ${request.apiKey}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify(checkPayload),
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!checkResponse.ok) {
|
|
||||||
const errorBody = await checkResponse.text()
|
|
||||||
logger.error('Error in Vertex AI check request:', {
|
|
||||||
status: checkResponse.status,
|
|
||||||
statusText: checkResponse.statusText,
|
|
||||||
responseBody: errorBody,
|
|
||||||
})
|
|
||||||
throw new Error(
|
|
||||||
`Vertex AI API check error: ${checkResponse.status} ${checkResponse.statusText}`
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const checkResult = await checkResponse.json()
|
|
||||||
const checkCandidate = checkResult.candidates?.[0]
|
|
||||||
const checkFunctionCall = extractFunctionCall(checkCandidate)
|
|
||||||
|
|
||||||
if (checkFunctionCall) {
|
|
||||||
logger.info(
|
|
||||||
'Function call detected in follow-up, handling in non-streaming mode',
|
|
||||||
{
|
|
||||||
functionName: checkFunctionCall.name,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
geminiResponse = checkResult
|
|
||||||
|
|
||||||
if (checkResult.usageMetadata) {
|
|
||||||
tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0
|
|
||||||
tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0
|
|
||||||
tokens.total +=
|
|
||||||
(checkResult.usageMetadata.promptTokenCount || 0) +
|
|
||||||
(checkResult.usageMetadata.candidatesTokenCount || 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
const nextModelEndTime = Date.now()
|
|
||||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
|
||||||
modelTime += thisModelTime
|
|
||||||
|
|
||||||
timeSegments.push({
|
|
||||||
type: 'model',
|
|
||||||
name: `Model response (iteration ${iterationCount + 1})`,
|
|
||||||
startTime: nextModelStartTime,
|
|
||||||
endTime: nextModelEndTime,
|
|
||||||
duration: thisModelTime,
|
|
||||||
})
|
|
||||||
|
|
||||||
iterationCount++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info('No function call detected, proceeding with streaming response')
|
|
||||||
|
|
||||||
if (request.responseFormat) {
|
|
||||||
streamingPayload.tools = undefined
|
|
||||||
streamingPayload.toolConfig = undefined
|
|
||||||
|
|
||||||
const responseFormatSchema =
|
|
||||||
request.responseFormat.schema || request.responseFormat
|
|
||||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
|
||||||
|
|
||||||
if (!streamingPayload.generationConfig) {
|
|
||||||
streamingPayload.generationConfig = {}
|
|
||||||
}
|
|
||||||
streamingPayload.generationConfig.responseMimeType = 'application/json'
|
|
||||||
streamingPayload.generationConfig.responseSchema = cleanSchema
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
'Using structured output for final streaming response after tool execution'
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const streamEndpoint = buildVertexEndpoint(
|
|
||||||
vertexProject,
|
|
||||||
vertexLocation,
|
|
||||||
requestedModel,
|
|
||||||
true
|
|
||||||
)
|
|
||||||
|
|
||||||
const streamingResponse = await fetch(streamEndpoint, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
Authorization: `Bearer ${request.apiKey}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify(streamingPayload),
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!streamingResponse.ok) {
|
|
||||||
const errorBody = await streamingResponse.text()
|
|
||||||
logger.error('Error in Vertex AI streaming follow-up request:', {
|
|
||||||
status: streamingResponse.status,
|
|
||||||
statusText: streamingResponse.statusText,
|
|
||||||
responseBody: errorBody,
|
|
||||||
})
|
|
||||||
throw new Error(
|
|
||||||
`Vertex AI API streaming error: ${streamingResponse.status} ${streamingResponse.statusText}`
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const nextModelEndTime = Date.now()
|
|
||||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
|
||||||
modelTime += thisModelTime
|
|
||||||
|
|
||||||
timeSegments.push({
|
|
||||||
type: 'model',
|
|
||||||
name: 'Final streaming response after tool calls',
|
|
||||||
startTime: nextModelStartTime,
|
|
||||||
endTime: nextModelEndTime,
|
|
||||||
duration: thisModelTime,
|
|
||||||
})
|
|
||||||
|
|
||||||
const streamingExecution: StreamingExecution = {
|
|
||||||
stream: null as any,
|
|
||||||
execution: {
|
|
||||||
success: true,
|
|
||||||
output: {
|
|
||||||
content: '',
|
|
||||||
model: request.model,
|
|
||||||
tokens,
|
|
||||||
toolCalls:
|
|
||||||
toolCalls.length > 0
|
|
||||||
? {
|
|
||||||
list: toolCalls,
|
|
||||||
count: toolCalls.length,
|
|
||||||
}
|
|
||||||
: undefined,
|
|
||||||
toolResults,
|
|
||||||
providerTiming: {
|
|
||||||
startTime: providerStartTimeISO,
|
|
||||||
endTime: new Date().toISOString(),
|
|
||||||
duration: Date.now() - providerStartTime,
|
|
||||||
modelTime,
|
|
||||||
toolsTime,
|
|
||||||
firstResponseTime,
|
|
||||||
iterations: iterationCount + 1,
|
|
||||||
timeSegments,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
logs: [],
|
|
||||||
metadata: {
|
|
||||||
startTime: providerStartTimeISO,
|
|
||||||
endTime: new Date().toISOString(),
|
|
||||||
duration: Date.now() - providerStartTime,
|
|
||||||
},
|
|
||||||
isStreaming: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
streamingExecution.stream = createReadableStreamFromVertexStream(
|
|
||||||
streamingResponse,
|
|
||||||
(content, usage) => {
|
|
||||||
streamingExecution.execution.output.content = content
|
|
||||||
|
|
||||||
const streamEndTime = Date.now()
|
|
||||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
|
||||||
|
|
||||||
if (streamingExecution.execution.output.providerTiming) {
|
|
||||||
streamingExecution.execution.output.providerTiming.endTime =
|
|
||||||
streamEndTimeISO
|
|
||||||
streamingExecution.execution.output.providerTiming.duration =
|
|
||||||
streamEndTime - providerStartTime
|
|
||||||
}
|
|
||||||
|
|
||||||
const promptTokens = usage?.promptTokenCount || 0
|
|
||||||
const completionTokens = usage?.candidatesTokenCount || 0
|
|
||||||
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
|
|
||||||
|
|
||||||
const existingTokens = streamingExecution.execution.output.tokens || {
|
|
||||||
prompt: 0,
|
|
||||||
completion: 0,
|
|
||||||
total: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
const existingPrompt = existingTokens.prompt || 0
|
|
||||||
const existingCompletion = existingTokens.completion || 0
|
|
||||||
const existingTotal = existingTokens.total || 0
|
|
||||||
|
|
||||||
streamingExecution.execution.output.tokens = {
|
|
||||||
prompt: existingPrompt + promptTokens,
|
|
||||||
completion: existingCompletion + completionTokens,
|
|
||||||
total: existingTotal + totalTokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
const accumulatedCost = calculateCost(
|
|
||||||
request.model,
|
|
||||||
existingPrompt,
|
|
||||||
existingCompletion
|
|
||||||
)
|
|
||||||
const streamCost = calculateCost(
|
|
||||||
request.model,
|
|
||||||
promptTokens,
|
|
||||||
completionTokens
|
|
||||||
)
|
|
||||||
streamingExecution.execution.output.cost = {
|
|
||||||
input: accumulatedCost.input + streamCost.input,
|
|
||||||
output: accumulatedCost.output + streamCost.output,
|
|
||||||
total: accumulatedCost.total + streamCost.total,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return streamingExecution
|
|
||||||
}
|
|
||||||
|
|
||||||
const nextPayload = {
|
|
||||||
...payload,
|
|
||||||
contents: simplifiedMessages,
|
|
||||||
}
|
|
||||||
|
|
||||||
const allForcedToolsUsed =
|
|
||||||
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
|
|
||||||
|
|
||||||
if (allForcedToolsUsed && request.responseFormat) {
|
|
||||||
nextPayload.tools = undefined
|
|
||||||
nextPayload.toolConfig = undefined
|
|
||||||
|
|
||||||
const responseFormatSchema =
|
|
||||||
request.responseFormat.schema || request.responseFormat
|
|
||||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
|
||||||
|
|
||||||
if (!nextPayload.generationConfig) {
|
|
||||||
nextPayload.generationConfig = {}
|
|
||||||
}
|
|
||||||
nextPayload.generationConfig.responseMimeType = 'application/json'
|
|
||||||
nextPayload.generationConfig.responseSchema = cleanSchema
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
'Using structured output for final non-streaming response after tool execution'
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
if (currentToolConfig) {
|
|
||||||
nextPayload.toolConfig = currentToolConfig
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const nextEndpoint = buildVertexEndpoint(
|
|
||||||
vertexProject,
|
|
||||||
vertexLocation,
|
|
||||||
requestedModel,
|
|
||||||
false
|
|
||||||
)
|
|
||||||
|
|
||||||
const nextResponse = await fetch(nextEndpoint, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
Authorization: `Bearer ${request.apiKey}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify(nextPayload),
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!nextResponse.ok) {
|
|
||||||
const errorBody = await nextResponse.text()
|
|
||||||
logger.error('Error in Vertex AI follow-up request:', {
|
|
||||||
status: nextResponse.status,
|
|
||||||
statusText: nextResponse.statusText,
|
|
||||||
responseBody: errorBody,
|
|
||||||
iterationCount,
|
|
||||||
})
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
geminiResponse = await nextResponse.json()
|
|
||||||
|
|
||||||
const nextModelEndTime = Date.now()
|
|
||||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
|
||||||
|
|
||||||
timeSegments.push({
|
|
||||||
type: 'model',
|
|
||||||
name: `Model response (iteration ${iterationCount + 1})`,
|
|
||||||
startTime: nextModelStartTime,
|
|
||||||
endTime: nextModelEndTime,
|
|
||||||
duration: thisModelTime,
|
|
||||||
})
|
|
||||||
|
|
||||||
modelTime += thisModelTime
|
|
||||||
|
|
||||||
const nextCandidate = geminiResponse.candidates?.[0]
|
|
||||||
const nextFunctionCall = extractFunctionCall(nextCandidate)
|
|
||||||
|
|
||||||
if (!nextFunctionCall) {
|
|
||||||
if (request.responseFormat) {
|
|
||||||
const finalPayload = {
|
|
||||||
...payload,
|
|
||||||
contents: nextPayload.contents,
|
|
||||||
tools: undefined,
|
|
||||||
toolConfig: undefined,
|
|
||||||
}
|
|
||||||
|
|
||||||
const responseFormatSchema =
|
|
||||||
request.responseFormat.schema || request.responseFormat
|
|
||||||
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
|
|
||||||
|
|
||||||
if (!finalPayload.generationConfig) {
|
|
||||||
finalPayload.generationConfig = {}
|
|
||||||
}
|
|
||||||
finalPayload.generationConfig.responseMimeType = 'application/json'
|
|
||||||
finalPayload.generationConfig.responseSchema = cleanSchema
|
|
||||||
|
|
||||||
logger.info('Making final request with structured output after tool execution')
|
|
||||||
|
|
||||||
const finalEndpoint = buildVertexEndpoint(
|
|
||||||
vertexProject,
|
|
||||||
vertexLocation,
|
|
||||||
requestedModel,
|
|
||||||
false
|
|
||||||
)
|
|
||||||
|
|
||||||
const finalResponse = await fetch(finalEndpoint, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
Authorization: `Bearer ${request.apiKey}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify(finalPayload),
|
|
||||||
})
|
|
||||||
|
|
||||||
if (finalResponse.ok) {
|
|
||||||
const finalResult = await finalResponse.json()
|
|
||||||
const finalCandidate = finalResult.candidates?.[0]
|
|
||||||
content = extractTextContent(finalCandidate)
|
|
||||||
|
|
||||||
if (finalResult.usageMetadata) {
|
|
||||||
tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0
|
|
||||||
tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0
|
|
||||||
tokens.total +=
|
|
||||||
(finalResult.usageMetadata.promptTokenCount || 0) +
|
|
||||||
(finalResult.usageMetadata.candidatesTokenCount || 0)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.warn(
|
|
||||||
'Failed to get structured output, falling back to regular response'
|
|
||||||
)
|
|
||||||
content = extractTextContent(nextCandidate)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
content = extractTextContent(nextCandidate)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
iterationCount++
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('Error in Vertex AI follow-up request:', {
|
|
||||||
error: error instanceof Error ? error.message : String(error),
|
|
||||||
iterationCount,
|
|
||||||
})
|
|
||||||
break
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('Error processing function call:', {
|
|
||||||
error: error instanceof Error ? error.message : String(error),
|
|
||||||
functionName: latestFunctionCall?.name || 'unknown',
|
|
||||||
})
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
content = extractTextContent(candidate)
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('Error processing Vertex AI response:', {
|
|
||||||
error: error instanceof Error ? error.message : String(error),
|
|
||||||
iterationCount,
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!content && toolCalls.length > 0) {
|
|
||||||
content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.`
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const providerEndTime = Date.now()
|
|
||||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
|
||||||
const totalDuration = providerEndTime - providerStartTime
|
|
||||||
|
|
||||||
if (geminiResponse.usageMetadata) {
|
|
||||||
tokens = {
|
|
||||||
prompt: geminiResponse.usageMetadata.promptTokenCount || 0,
|
|
||||||
completion: geminiResponse.usageMetadata.candidatesTokenCount || 0,
|
|
||||||
total:
|
|
||||||
(geminiResponse.usageMetadata.promptTokenCount || 0) +
|
|
||||||
(geminiResponse.usageMetadata.candidatesTokenCount || 0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
content,
|
|
||||||
model: request.model,
|
|
||||||
tokens,
|
|
||||||
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
|
|
||||||
toolResults: toolResults.length > 0 ? toolResults : undefined,
|
|
||||||
timing: {
|
|
||||||
startTime: providerStartTimeISO,
|
|
||||||
endTime: providerEndTimeISO,
|
|
||||||
duration: totalDuration,
|
|
||||||
modelTime: modelTime,
|
|
||||||
toolsTime: toolsTime,
|
|
||||||
firstResponseTime: firstResponseTime,
|
|
||||||
iterations: iterationCount + 1,
|
|
||||||
timeSegments: timeSegments,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
const providerEndTime = Date.now()
|
|
||||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
|
||||||
const totalDuration = providerEndTime - providerStartTime
|
|
||||||
|
|
||||||
logger.error('Error in Vertex AI request:', {
|
|
||||||
error: error instanceof Error ? error.message : String(error),
|
|
||||||
duration: totalDuration,
|
|
||||||
})
|
|
||||||
|
|
||||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
|
||||||
// @ts-ignore
|
|
||||||
enhancedError.timing = {
|
|
||||||
startTime: providerStartTimeISO,
|
|
||||||
endTime: providerEndTimeISO,
|
|
||||||
duration: totalDuration,
|
|
||||||
}
|
|
||||||
|
|
||||||
throw enhancedError
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
}
|
}
|
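The fetch-based implementation above is removed by this commit in favor of the short SDK path shown at the top of the hunk: an OAuth2 client seeded with the caller's access token is handed to the official @google/genai client through googleAuthOptions, so the SDK never falls back to application-default credentials. The sketch below illustrates that pattern; the import paths and the generateContent / response.text calls are assumptions based on the public SDK, not code from this repository.

// Minimal sketch (not the provider's exact code): build a Vertex-backed
// GoogleGenAI client from an existing OAuth access token and issue one call.
// Treat the package names and option shapes as assumptions if your SDK
// version differs from the one this commit targets.
import { GoogleGenAI } from '@google/genai'
import { OAuth2Client } from 'google-auth-library'

async function generateWithVertex(
  accessToken: string,
  project: string,
  location: string,
  model: string,
  prompt: string
): Promise<string | undefined> {
  // Reuse an OAuth2 client carrying the user's token, as the provider does,
  // instead of letting the SDK resolve application-default credentials.
  const authClient = new OAuth2Client()
  authClient.setCredentials({ access_token: accessToken })

  const ai = new GoogleGenAI({
    vertexai: true,
    project,
    location,
    googleAuthOptions: { authClient },
  })

  // models.generateContent is the SDK's non-streaming entry point.
  const response = await ai.models.generateContent({
    model,
    contents: prompt,
  })
  return response.text
}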
@@ -1,231 +0,0 @@
import { createLogger } from '@/lib/logs/console/logger'
|
|
||||||
import { extractFunctionCall, extractTextContent } from '@/providers/google/utils'
|
|
||||||
|
|
||||||
const logger = createLogger('VertexUtils')
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a ReadableStream from Vertex AI's Gemini stream response
|
|
||||||
*/
|
|
||||||
export function createReadableStreamFromVertexStream(
|
|
||||||
response: Response,
|
|
||||||
onComplete?: (
|
|
||||||
content: string,
|
|
||||||
usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
|
|
||||||
) => void
|
|
||||||
): ReadableStream<Uint8Array> {
|
|
||||||
const reader = response.body?.getReader()
|
|
||||||
if (!reader) {
|
|
||||||
throw new Error('Failed to get reader from response body')
|
|
||||||
}
|
|
||||||
|
|
||||||
return new ReadableStream({
|
|
||||||
async start(controller) {
|
|
||||||
try {
|
|
||||||
let buffer = ''
|
|
||||||
let fullContent = ''
|
|
||||||
let usageData: {
|
|
||||||
promptTokenCount?: number
|
|
||||||
candidatesTokenCount?: number
|
|
||||||
totalTokenCount?: number
|
|
||||||
} | null = null
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
const { done, value } = await reader.read()
|
|
||||||
if (done) {
|
|
||||||
if (buffer.trim()) {
|
|
||||||
try {
|
|
||||||
const data = JSON.parse(buffer.trim())
|
|
||||||
if (data.usageMetadata) {
|
|
||||||
usageData = data.usageMetadata
|
|
||||||
}
|
|
||||||
const candidate = data.candidates?.[0]
|
|
||||||
if (candidate?.content?.parts) {
|
|
||||||
const functionCall = extractFunctionCall(candidate)
|
|
||||||
if (functionCall) {
|
|
||||||
logger.debug(
|
|
||||||
'Function call detected in final buffer, ending stream to execute tool',
|
|
||||||
{
|
|
||||||
functionName: functionCall.name,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
|
||||||
controller.close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
const content = extractTextContent(candidate)
|
|
||||||
if (content) {
|
|
||||||
fullContent += content
|
|
||||||
controller.enqueue(new TextEncoder().encode(content))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
if (buffer.trim().startsWith('[')) {
|
|
||||||
try {
|
|
||||||
const dataArray = JSON.parse(buffer.trim())
|
|
||||||
if (Array.isArray(dataArray)) {
|
|
||||||
for (const item of dataArray) {
|
|
||||||
if (item.usageMetadata) {
|
|
||||||
usageData = item.usageMetadata
|
|
||||||
}
|
|
||||||
const candidate = item.candidates?.[0]
|
|
||||||
if (candidate?.content?.parts) {
|
|
||||||
const functionCall = extractFunctionCall(candidate)
|
|
||||||
if (functionCall) {
|
|
||||||
logger.debug(
|
|
||||||
'Function call detected in array item, ending stream to execute tool',
|
|
||||||
{
|
|
||||||
functionName: functionCall.name,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
|
||||||
controller.close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
const content = extractTextContent(candidate)
|
|
||||||
if (content) {
|
|
||||||
fullContent += content
|
|
||||||
controller.enqueue(new TextEncoder().encode(content))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (arrayError) {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
|
||||||
controller.close()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
const text = new TextDecoder().decode(value)
|
|
||||||
buffer += text
|
|
||||||
|
|
||||||
let searchIndex = 0
|
|
||||||
while (searchIndex < buffer.length) {
|
|
||||||
const openBrace = buffer.indexOf('{', searchIndex)
|
|
||||||
if (openBrace === -1) break
|
|
||||||
|
|
||||||
let braceCount = 0
|
|
||||||
let inString = false
|
|
||||||
let escaped = false
|
|
||||||
let closeBrace = -1
|
|
||||||
|
|
||||||
for (let i = openBrace; i < buffer.length; i++) {
|
|
||||||
const char = buffer[i]
|
|
||||||
|
|
||||||
if (!inString) {
|
|
||||||
if (char === '"' && !escaped) {
|
|
||||||
inString = true
|
|
||||||
} else if (char === '{') {
|
|
||||||
braceCount++
|
|
||||||
} else if (char === '}') {
|
|
||||||
braceCount--
|
|
||||||
if (braceCount === 0) {
|
|
||||||
closeBrace = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (char === '"' && !escaped) {
|
|
||||||
inString = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
escaped = char === '\\' && !escaped
|
|
||||||
}
|
|
||||||
|
|
||||||
if (closeBrace !== -1) {
|
|
||||||
const jsonStr = buffer.substring(openBrace, closeBrace + 1)
|
|
||||||
|
|
||||||
try {
|
|
||||||
const data = JSON.parse(jsonStr)
|
|
||||||
|
|
||||||
if (data.usageMetadata) {
|
|
||||||
usageData = data.usageMetadata
|
|
||||||
}
|
|
||||||
|
|
||||||
const candidate = data.candidates?.[0]
|
|
||||||
|
|
||||||
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
|
|
||||||
logger.warn(
|
|
||||||
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
|
|
||||||
{
|
|
||||||
finishReason: candidate.finishReason,
|
|
||||||
hasContent: !!candidate?.content,
|
|
||||||
hasParts: !!candidate?.content?.parts,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
const textContent = extractTextContent(candidate)
|
|
||||||
if (textContent) {
|
|
||||||
fullContent += textContent
|
|
||||||
controller.enqueue(new TextEncoder().encode(textContent))
|
|
||||||
}
|
|
||||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
|
||||||
controller.close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if (candidate?.content?.parts) {
|
|
||||||
const functionCall = extractFunctionCall(candidate)
|
|
||||||
if (functionCall) {
|
|
||||||
logger.debug(
|
|
||||||
'Function call detected in stream, ending stream to execute tool',
|
|
||||||
{
|
|
||||||
functionName: functionCall.name,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
|
||||||
controller.close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
const content = extractTextContent(candidate)
|
|
||||||
if (content) {
|
|
||||||
fullContent += content
|
|
||||||
controller.enqueue(new TextEncoder().encode(content))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
logger.error('Error parsing JSON from stream', {
|
|
||||||
error: e instanceof Error ? e.message : String(e),
|
|
||||||
jsonPreview: jsonStr.substring(0, 200),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer = buffer.substring(closeBrace + 1)
|
|
||||||
searchIndex = 0
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
logger.error('Error reading Vertex AI stream', {
|
|
||||||
error: e instanceof Error ? e.message : String(e),
|
|
||||||
})
|
|
||||||
controller.error(e)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
async cancel() {
|
|
||||||
await reader.cancel()
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
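createReadableStreamFromVertexStream, removed here along with the rest of the hand-rolled Vertex utilities, wraps the raw fetch Response in a plain text stream and reports token usage once the stream closes. A small consumption sketch, assuming a Response obtained from the streaming endpoint shown earlier:

// Sketch of how the removed helper was consumed: accumulate decoded text
// chunks and pick up token usage from the completion callback.
async function collectVertexStream(response: Response): Promise<string> {
  let usage:
    | { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
    | undefined

  const stream = createReadableStreamFromVertexStream(response, (_content, u) => {
    usage = u // invoked once, after the last chunk has been emitted
  })

  const reader = stream.getReader()
  const decoder = new TextDecoder()
  let text = ''
  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    text += decoder.decode(value)
  }

  console.log('total tokens:', usage?.totalTokenCount)
  return text
}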
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build Vertex AI endpoint URL
|
|
||||||
*/
|
|
||||||
export function buildVertexEndpoint(
|
|
||||||
project: string,
|
|
||||||
location: string,
|
|
||||||
model: string,
|
|
||||||
isStreaming: boolean
|
|
||||||
): string {
|
|
||||||
const action = isStreaming ? 'streamGenerateContent' : 'generateContent'
|
|
||||||
|
|
||||||
if (location === 'global') {
|
|
||||||
return `https://aiplatform.googleapis.com/v1/projects/${project}/locations/global/publishers/google/models/${model}:${action}`
|
|
||||||
}
|
|
||||||
|
|
||||||
return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/google/models/${model}:${action}`
|
|
||||||
}
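buildVertexEndpoint only varies the host and the final action segment. For illustration (project, location, and model values below are placeholders):

// Illustrative values only — project and location are hypothetical.
buildVertexEndpoint('my-project', 'us-central1', 'gemini-2.5-pro', false)
// => 'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.5-pro:generateContent'

buildVertexEndpoint('my-project', 'global', 'gemini-2.5-pro', true)
// => 'https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-2.5-pro:streamGenerateContent'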
@@ -130,7 +130,7 @@ export const vllmProvider: ProviderConfig = {
         : undefined

     const payload: any = {
-      model: (request.model || getProviderDefaultModel('vllm')).replace(/^vllm\//, ''),
+      model: request.model.replace(/^vllm\//, ''),
       messages: allMessages,
     }

@@ -48,7 +48,7 @@ export const xAIProvider: ProviderConfig = {
       hasTools: !!request.tools?.length,
       toolCount: request.tools?.length || 0,
       hasResponseFormat: !!request.responseFormat,
-      model: request.model || 'grok-3-latest',
+      model: request.model,
       streaming: !!request.stream,
     })

@@ -87,7 +87,7 @@ export const xAIProvider: ProviderConfig = {
       )
     }
     const basePayload: any = {
-      model: request.model || 'grok-3-latest',
+      model: request.model,
       messages: allMessages,
     }

@@ -139,7 +139,7 @@ export const xAIProvider: ProviderConfig = {
         success: true,
         output: {
           content: '',
-          model: request.model || 'grok-3-latest',
+          model: request.model,
           tokens: { prompt: 0, completion: 0, total: 0 },
           toolCalls: undefined,
           providerTiming: {

@@ -505,7 +505,7 @@ export const xAIProvider: ProviderConfig = {
         success: true,
         output: {
           content: '',
-          model: request.model || 'grok-3-latest',
+          model: request.model,
           tokens: {
             prompt: tokens.prompt,
             completion: tokens.completion,