fix(search): remove full-text param from built-in search; fix Anthropic provider streaming (#2542)

* fix(search): remove full-text param from built-in search; fix Anthropic provider streaming

* rewrite Gemini provider with the official SDK + add thinking capability

* Vertex Gemini consolidation

* never silently fall back to a different model

* pass OAuth client through the googleAuthOptions param directly

* add server-side provider registry

* remove comments

* move OAuth selector below the model selector

---------

Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
Author: Waleed
Date: 2025-12-22 23:57:11 -08:00
Committed by: GitHub
Parent: f5245f3eca
Commit: b0748c82f9
31 changed files with 1607 additions and 2431 deletions

View File

@@ -56,7 +56,7 @@ export async function POST(request: NextRequest) {
query: validated.query,
type: 'auto',
useAutoprompt: true,
text: true,
highlights: true,
apiKey: env.EXA_API_KEY,
})
@@ -77,7 +77,7 @@ export async function POST(request: NextRequest) {
const results = (result.output.results || []).map((r: any, index: number) => ({
title: r.title || '',
link: r.url || '',
snippet: r.text || '',
snippet: Array.isArray(r.highlights) ? r.highlights.join(' ... ') : '',
date: r.publishedDate || undefined,
position: index + 1,
}))
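
For context, a minimal sketch of how the new highlights-based mapping behaves (the result object is made up; only the join expression mirrors the hunk above):

```typescript
// Hypothetical result object shaped like the `r` entries mapped above.
const r = {
  title: 'Example page',
  url: 'https://example.com',
  highlights: ['first relevant passage', 'second relevant passage'],
  publishedDate: '2025-01-01',
}

// Same expression as in the hunk: highlights are joined with ' ... '
// instead of returning the page's full text.
const snippet = Array.isArray(r.highlights) ? r.highlights.join(' ... ') : ''
// => 'first relevant passage ... second relevant passage'
```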

View File

@@ -43,6 +43,8 @@ const SCOPE_DESCRIPTIONS: Record<string, string> = {
'https://www.googleapis.com/auth/admin.directory.group.readonly': 'View Google Workspace groups',
'https://www.googleapis.com/auth/admin.directory.group.member.readonly':
'View Google Workspace group memberships',
'https://www.googleapis.com/auth/cloud-platform':
'Full access to Google Cloud resources for Vertex AI',
'read:confluence-content.all': 'Read all Confluence content',
'read:confluence-space.summary': 'Read Confluence space information',
'read:space:confluence': 'View Confluence spaces',

View File

@@ -9,8 +9,10 @@ import {
getMaxTemperature,
getProviderIcon,
getReasoningEffortValuesForModel,
getThinkingLevelsForModel,
getVerbosityValuesForModel,
MODELS_WITH_REASONING_EFFORT,
MODELS_WITH_THINKING,
MODELS_WITH_VERBOSITY,
providers,
supportsTemperature,
@@ -108,7 +110,19 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
})
},
},
{
id: 'vertexCredential',
title: 'Google Cloud Account',
type: 'oauth-input',
serviceId: 'vertex-ai',
requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'],
placeholder: 'Select Google Cloud account',
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'reasoningEffort',
title: 'Reasoning Effort',
@@ -215,6 +229,57 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
value: MODELS_WITH_VERBOSITY,
},
},
{
id: 'thinkingLevel',
title: 'Thinking Level',
type: 'dropdown',
placeholder: 'Select thinking level...',
options: [
{ label: 'minimal', id: 'minimal' },
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
],
dependsOn: ['model'],
fetchOptions: async (blockId: string) => {
const { useSubBlockStore } = await import('@/stores/workflows/subblock/store')
const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store')
const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId
if (!activeWorkflowId) {
return [
{ label: 'low', id: 'low' },
{ label: 'high', id: 'high' },
]
}
const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId]
const blockValues = workflowValues?.[blockId]
const modelValue = blockValues?.model as string
if (!modelValue) {
return [
{ label: 'low', id: 'low' },
{ label: 'high', id: 'high' },
]
}
const validOptions = getThinkingLevelsForModel(modelValue)
if (!validOptions) {
return [
{ label: 'low', id: 'low' },
{ label: 'high', id: 'high' },
]
}
return validOptions.map((opt) => ({ label: opt, id: opt }))
},
value: () => 'high',
condition: {
field: 'model',
value: MODELS_WITH_THINKING,
},
},
{
id: 'azureEndpoint',
@@ -275,17 +340,21 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
password: true,
connectionDroppable: false,
required: true,
// Hide API key for hosted models, Ollama models, and vLLM models
// Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth)
condition: isHosted
? {
field: 'model',
value: getHostedModels(),
value: [...getHostedModels(), ...providers.vertex.models],
not: true, // Show for all models EXCEPT those listed
}
: () => ({
field: 'model',
value: [...getCurrentOllamaModels(), ...getCurrentVLLMModels()],
not: true, // Show for all models EXCEPT Ollama and vLLM models
value: [
...getCurrentOllamaModels(),
...getCurrentVLLMModels(),
...providers.vertex.models,
],
not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models
}),
},
{
@@ -609,6 +678,7 @@ Example 3 (Array Input):
temperature: { type: 'number', description: 'Response randomness level' },
reasoningEffort: { type: 'string', description: 'Reasoning effort level for GPT-5 models' },
verbosity: { type: 'string', description: 'Verbosity level for GPT-5 models' },
thinkingLevel: { type: 'string', description: 'Thinking level for Gemini 3 models' },
tools: { type: 'json', description: 'Available tools configuration' },
},
outputs: {

View File

@@ -1,8 +1,9 @@
import { db } from '@sim/db'
import { mcpServers } from '@sim/db/schema'
import { account, mcpServers } from '@sim/db/schema'
import { and, eq, inArray, isNull } from 'drizzle-orm'
import { createLogger } from '@/lib/logs/console/logger'
import { createMcpToolId } from '@/lib/mcp/utils'
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
import { getAllBlocks } from '@/blocks'
import type { BlockOutput } from '@/blocks/types'
import { AGENT, BlockType, DEFAULTS, HTTP } from '@/executor/constants'
@@ -919,6 +920,7 @@ export class AgentBlockHandler implements BlockHandler {
azureApiVersion: inputs.azureApiVersion,
vertexProject: inputs.vertexProject,
vertexLocation: inputs.vertexLocation,
vertexCredential: inputs.vertexCredential,
responseFormat,
workflowId: ctx.workflowId,
workspaceId: ctx.workspaceId,
@@ -997,7 +999,17 @@ export class AgentBlockHandler implements BlockHandler {
responseFormat: any,
providerStartTime: number
) {
const finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey)
let finalApiKey: string
// For Vertex AI, resolve OAuth credential to access token
if (providerId === 'vertex' && providerRequest.vertexCredential) {
finalApiKey = await this.resolveVertexCredential(
providerRequest.vertexCredential,
ctx.workflowId
)
} else {
finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey)
}
const { blockData, blockNameMapping } = collectBlockData(ctx)
@@ -1024,7 +1036,6 @@ export class AgentBlockHandler implements BlockHandler {
blockNameMapping,
})
this.logExecutionSuccess(providerId, model, ctx, block, providerStartTime, response)
return this.processProviderResponse(response, block, responseFormat)
}
@@ -1049,15 +1060,6 @@ export class AgentBlockHandler implements BlockHandler {
throw new Error(errorMessage)
}
this.logExecutionSuccess(
providerRequest.provider,
providerRequest.model,
ctx,
block,
providerStartTime,
'HTTP response'
)
const contentType = response.headers.get('Content-Type')
if (contentType?.includes(HTTP.CONTENT_TYPE.EVENT_STREAM)) {
return this.handleStreamingResponse(response, block, ctx, inputs)
@@ -1117,21 +1119,33 @@ export class AgentBlockHandler implements BlockHandler {
}
}
private logExecutionSuccess(
provider: string,
model: string,
ctx: ExecutionContext,
block: SerializedBlock,
startTime: number,
response: any
) {
const executionTime = Date.now() - startTime
const responseType =
response instanceof ReadableStream
? 'stream'
: response && typeof response === 'object' && 'stream' in response
? 'streaming-execution'
: 'json'
/**
* Resolves a Vertex AI OAuth credential to an access token
*/
private async resolveVertexCredential(credentialId: string, workflowId: string): Promise<string> {
const requestId = `vertex-${Date.now()}`
logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`)
// Get the credential - we need to find the owner
// Since we're in a workflow context, we can query the credential directly
const credential = await db.query.account.findFirst({
where: eq(account.id, credentialId),
})
if (!credential) {
throw new Error(`Vertex AI credential not found: ${credentialId}`)
}
// Refresh the token if needed
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
if (!accessToken) {
throw new Error('Failed to get Vertex AI access token')
}
logger.info(`[${requestId}] Successfully resolved Vertex AI credential`)
return accessToken
}
private handleExecutionError(

View File

@@ -21,6 +21,7 @@ export interface AgentInputs {
azureApiVersion?: string
vertexProject?: string
vertexLocation?: string
vertexCredential?: string
reasoningEffort?: string
verbosity?: string
}

View File

@@ -579,6 +579,21 @@ export const auth = betterAuth({
redirectURI: `${getBaseUrl()}/api/auth/oauth2/callback/google-groups`,
},
{
providerId: 'vertex-ai',
clientId: env.GOOGLE_CLIENT_ID as string,
clientSecret: env.GOOGLE_CLIENT_SECRET as string,
discoveryUrl: 'https://accounts.google.com/.well-known/openid-configuration',
accessType: 'offline',
scopes: [
'https://www.googleapis.com/auth/userinfo.email',
'https://www.googleapis.com/auth/userinfo.profile',
'https://www.googleapis.com/auth/cloud-platform',
],
prompt: 'consent',
redirectURI: `${getBaseUrl()}/api/auth/oauth2/callback/vertex-ai`,
},
{
providerId: 'microsoft-teams',
clientId: env.MICROSOFT_CLIENT_ID as string,

View File

@@ -41,7 +41,7 @@ function filterUserFile(data: any): any {
const DISPLAY_FILTERS = [filterUserFile]
export function filterForDisplay(data: any): any {
const seen = new WeakSet()
const seen = new Set<object>()
return filterForDisplayInternal(data, seen, 0)
}
@@ -49,7 +49,7 @@ function getObjectType(data: unknown): string {
return Object.prototype.toString.call(data).slice(8, -1)
}
function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: number): any {
function filterForDisplayInternal(data: any, seen: Set<object>, depth: number): any {
try {
if (data === null || data === undefined) {
return data
@@ -93,6 +93,7 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
return '[Unknown Type]'
}
// True circular reference: object is an ancestor in the current path
if (seen.has(data)) {
return '[Circular Reference]'
}
@@ -131,18 +132,24 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
return `[ArrayBuffer: ${(data as ArrayBuffer).byteLength} bytes]`
case 'Map': {
seen.add(data)
const obj: Record<string, any> = {}
for (const [key, value] of (data as Map<any, any>).entries()) {
const keyStr = typeof key === 'string' ? key : String(key)
obj[keyStr] = filterForDisplayInternal(value, seen, depth + 1)
}
seen.delete(data)
return obj
}
case 'Set':
return Array.from(data as Set<any>).map((item) =>
case 'Set': {
seen.add(data)
const result = Array.from(data as Set<any>).map((item) =>
filterForDisplayInternal(item, seen, depth + 1)
)
seen.delete(data)
return result
}
case 'WeakMap':
return '[WeakMap]'
@@ -161,17 +168,22 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
return `[${objectType}: ${(data as ArrayBufferView).byteLength} bytes]`
}
// Add to current path before processing children
seen.add(data)
for (const filterFn of DISPLAY_FILTERS) {
const result = filterFn(data)
if (result !== data) {
return filterForDisplayInternal(result, seen, depth + 1)
const filtered = filterFn(data)
if (filtered !== data) {
const result = filterForDisplayInternal(filtered, seen, depth + 1)
seen.delete(data)
return result
}
}
if (Array.isArray(data)) {
return data.map((item) => filterForDisplayInternal(item, seen, depth + 1))
const result = data.map((item) => filterForDisplayInternal(item, seen, depth + 1))
seen.delete(data)
return result
}
const result: Record<string, any> = {}
@@ -182,6 +194,8 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
result[key] = '[Error accessing property]'
}
}
// Remove from current path after processing children
seen.delete(data)
return result
} catch {
return '[Unserializable]'
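
A small sketch of the behavior this Set-based path tracking is aiming for, assuming filterForDisplay is imported from this module (hypothetical path): an object that is its own ancestor is cut off as '[Circular Reference]', while an object merely referenced twice as a sibling is serialized both times.

```typescript
import { filterForDisplay } from './display-filters' // hypothetical module path

// Shared (non-circular) reference: the same object appears twice as a sibling.
const shared = { value: 42 }
console.log(filterForDisplay({ a: shared, b: shared }))
// => { a: { value: 42 }, b: { value: 42 } }  (both occurrences kept)

// True cycle: the object is its own ancestor in the traversal path.
const node: Record<string, unknown> = { name: 'root' }
node.self = node
console.log(filterForDisplay(node))
// => { name: 'root', self: '[Circular Reference]' }
```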

View File

@@ -32,6 +32,7 @@ import {
SlackIcon,
SpotifyIcon,
TrelloIcon,
VertexIcon,
WealthboxIcon,
WebflowIcon,
WordpressIcon,
@@ -80,6 +81,7 @@ export type OAuthService =
| 'google-vault'
| 'google-forms'
| 'google-groups'
| 'vertex-ai'
| 'github'
| 'x'
| 'confluence'
@@ -237,6 +239,16 @@ export const OAUTH_PROVIDERS: Record<string, OAuthProviderConfig> = {
],
scopeHints: ['admin.directory.group'],
},
'vertex-ai': {
id: 'vertex-ai',
name: 'Vertex AI',
description: 'Access Google Cloud Vertex AI for Gemini models with OAuth.',
providerId: 'vertex-ai',
icon: (props) => VertexIcon(props),
baseProviderIcon: (props) => VertexIcon(props),
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
scopeHints: ['cloud-platform', 'vertex', 'aiplatform'],
},
},
defaultService: 'gmail',
},
@@ -1099,6 +1111,12 @@ export function parseProvider(provider: OAuthProvider): ProviderConfig {
featureType: 'microsoft-planner',
}
}
if (provider === 'vertex-ai') {
return {
baseProvider: 'google',
featureType: 'vertex-ai',
}
}
// Handle compound providers (e.g., 'google-email' -> { baseProvider: 'google', featureType: 'email' })
const [base, feature] = provider.split('-')

View File

@@ -58,7 +58,7 @@ export const anthropicProvider: ProviderConfig = {
throw new Error('API key is required for Anthropic')
}
const modelId = request.model || 'claude-3-7-sonnet-20250219'
const modelId = request.model
const useNativeStructuredOutputs = !!(
request.responseFormat && supportsNativeStructuredOutputs(modelId)
)
@@ -174,7 +174,7 @@ export const anthropicProvider: ProviderConfig = {
}
const payload: any = {
model: request.model || 'claude-3-7-sonnet-20250219',
model: request.model,
messages,
system: systemPrompt,
max_tokens: Number.parseInt(String(request.maxTokens)) || 1024,
@@ -561,37 +561,93 @@ export const anthropicProvider: ProviderConfig = {
throw error
}
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
return {
content,
model: request.model || 'claude-3-7-sonnet-20250219',
tokens,
toolCalls:
toolCalls.length > 0
? toolCalls.map((tc) => ({
name: tc.name,
arguments: tc.arguments as Record<string, any>,
startTime: tc.startTime,
endTime: tc.endTime,
duration: tc.duration,
result: tc.result,
}))
: undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
const streamingPayload = {
...payload,
messages: currentMessages,
stream: true,
tool_choice: undefined,
}
const streamResponse: any = await anthropic.messages.create(streamingPayload)
const streamingResult = {
stream: createReadableStreamFromAnthropicStream(
streamResponse,
(streamContent, usage) => {
streamingResult.execution.output.content = streamContent
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.input_tokens,
completion: tokens.completion + usage.output_tokens,
total: tokens.total + usage.input_tokens + usage.output_tokens,
}
const streamCost = calculateCost(
request.model,
usage.input_tokens,
usage.output_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
}
}
),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,
total: tokens.total,
},
toolCalls:
toolCalls.length > 0
? {
list: toolCalls,
count: toolCalls.length,
}
: undefined,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
cost: {
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
isStreaming: true,
},
}
return streamingResult as StreamingExecution
} catch (error) {
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
@@ -934,7 +990,7 @@ export const anthropicProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'claude-3-7-sonnet-20250219',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,
@@ -978,7 +1034,7 @@ export const anthropicProvider: ProviderConfig = {
return {
content,
model: request.model || 'claude-3-7-sonnet-20250219',
model: request.model,
tokens,
toolCalls:
toolCalls.length > 0

View File

@@ -39,7 +39,7 @@ export const azureOpenAIProvider: ProviderConfig = {
request: ProviderRequest
): Promise<ProviderResponse | StreamingExecution> => {
logger.info('Preparing Azure OpenAI request', {
model: request.model || 'azure/gpt-4o',
model: request.model,
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
@@ -95,7 +95,7 @@ export const azureOpenAIProvider: ProviderConfig = {
}))
: undefined
const deploymentName = (request.model || 'azure/gpt-4o').replace('azure/', '')
const deploymentName = request.model.replace('azure/', '')
const payload: any = {
model: deploymentName,
messages: allMessages,

View File

@@ -73,7 +73,7 @@ export const cerebrasProvider: ProviderConfig = {
: undefined
const payload: any = {
model: (request.model || 'cerebras/llama-3.3-70b').replace('cerebras/', ''),
model: request.model.replace('cerebras/', ''),
messages: allMessages,
}
if (request.temperature !== undefined) payload.temperature = request.temperature
@@ -145,7 +145,7 @@ export const cerebrasProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'cerebras/llama-3.3-70b',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
@@ -470,7 +470,7 @@ export const cerebrasProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'cerebras/llama-3.3-70b',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,

View File

@@ -105,7 +105,7 @@ export const deepseekProvider: ProviderConfig = {
: toolChoice.type === 'any'
? `force:${toolChoice.any?.name || 'unknown'}`
: 'unknown',
model: request.model || 'deepseek-v3',
model: request.model,
})
}
}
@@ -145,7 +145,7 @@ export const deepseekProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'deepseek-chat',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
@@ -469,7 +469,7 @@ export const deepseekProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'deepseek-chat',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,

View File

@@ -0,0 +1,58 @@
import { GoogleGenAI } from '@google/genai'
import { createLogger } from '@/lib/logs/console/logger'
import type { GeminiClientConfig } from './types'
const logger = createLogger('GeminiClient')
/**
* Creates a GoogleGenAI client configured for either Google Gemini API or Vertex AI
*
* For Google Gemini API:
* - Uses API key authentication
*
* For Vertex AI:
* - Uses OAuth access token via HTTP Authorization header
* - Requires project and location
*/
export function createGeminiClient(config: GeminiClientConfig): GoogleGenAI {
if (config.vertexai) {
if (!config.project) {
throw new Error('Vertex AI requires a project ID')
}
if (!config.accessToken) {
throw new Error('Vertex AI requires an access token')
}
const location = config.location ?? 'us-central1'
logger.info('Creating Vertex AI client', {
project: config.project,
location,
hasAccessToken: !!config.accessToken,
})
// Create client with Vertex AI configuration
// Use httpOptions.headers to pass the access token directly
return new GoogleGenAI({
vertexai: true,
project: config.project,
location,
httpOptions: {
headers: {
Authorization: `Bearer ${config.accessToken}`,
},
},
})
}
// Google Gemini API with API key
if (!config.apiKey) {
throw new Error('Google Gemini API requires an API key')
}
logger.info('Creating Google Gemini client')
return new GoogleGenAI({
apiKey: config.apiKey,
})
}
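
A minimal usage sketch for the two client modes, based on the GeminiClientConfig fields introduced in this PR (module path and all values are placeholders):

```typescript
import { createGeminiClient } from '@/providers/google/gemini' // module path assumed; the new index.ts re-exports it

// Google Gemini API: API-key authentication.
const geminiClient = createGeminiClient({ apiKey: process.env.GEMINI_API_KEY! })

// Vertex AI: OAuth access token plus project; location defaults to 'us-central1'.
// In the app, the token comes from resolveVertexCredential / refreshTokenIfNeeded.
const accessToken = '<oauth-access-token>' // placeholder
const vertexClient = createGeminiClient({
  vertexai: true,
  project: 'my-gcp-project', // placeholder
  location: 'us-central1',
  accessToken,
})
```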

View File

@@ -0,0 +1,680 @@
import {
type Content,
FunctionCallingConfigMode,
type FunctionDeclaration,
type GenerateContentConfig,
type GenerateContentResponse,
type GoogleGenAI,
type Part,
type Schema,
type ThinkingConfig,
type ToolConfig,
} from '@google/genai'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import {
checkForForcedToolUsage,
cleanSchemaForGemini,
convertToGeminiFormat,
convertUsageMetadata,
createReadableStreamFromGeminiStream,
extractFunctionCallPart,
extractTextContent,
mapToThinkingLevel,
} from '@/providers/google/utils'
import { getThinkingCapability } from '@/providers/models'
import type { FunctionCallResponse, ProviderRequest, ProviderResponse } from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
} from '@/providers/utils'
import { executeTool } from '@/tools'
import type { ExecutionState, GeminiProviderType, GeminiUsage, ParsedFunctionCall } from './types'
/**
* Creates initial execution state
*/
function createInitialState(
contents: Content[],
initialUsage: GeminiUsage,
firstResponseTime: number,
initialCallTime: number,
model: string,
toolConfig: ToolConfig | undefined
): ExecutionState {
const initialCost = calculateCost(
model,
initialUsage.promptTokenCount,
initialUsage.candidatesTokenCount
)
return {
contents,
tokens: {
prompt: initialUsage.promptTokenCount,
completion: initialUsage.candidatesTokenCount,
total: initialUsage.totalTokenCount,
},
cost: initialCost,
toolCalls: [],
toolResults: [],
iterationCount: 0,
modelTime: firstResponseTime,
toolsTime: 0,
timeSegments: [
{
type: 'model',
name: 'Initial response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
],
usedForcedTools: [],
currentToolConfig: toolConfig,
}
}
/**
* Executes a tool call and updates state
*/
async function executeToolCall(
functionCallPart: Part,
functionCall: ParsedFunctionCall,
request: ProviderRequest,
state: ExecutionState,
forcedTools: string[],
logger: ReturnType<typeof createLogger>
): Promise<{ success: boolean; state: ExecutionState }> {
const toolCallStartTime = Date.now()
const toolName = functionCall.name
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) {
logger.warn(`Tool ${toolName} not found in registry, skipping`)
return { success: false, state }
}
try {
const { toolParams, executionParams } = prepareToolExecution(tool, functionCall.args, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const duration = toolCallEndTime - toolCallStartTime
const resultContent: Record<string, unknown> = result.success
? (result.output as Record<string, unknown>)
: { error: true, message: result.error || 'Tool execution failed', tool: toolName }
const toolCall: FunctionCallResponse = {
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration,
result: resultContent,
}
const updatedContents: Content[] = [
...state.contents,
{
role: 'model',
parts: [functionCallPart],
},
{
role: 'user',
parts: [
{
functionResponse: {
name: functionCall.name,
response: resultContent,
},
},
],
},
]
const forcedToolCheck = checkForForcedToolUsage(
[{ name: functionCall.name, args: functionCall.args }],
state.currentToolConfig,
forcedTools,
state.usedForcedTools
)
return {
success: true,
state: {
...state,
contents: updatedContents,
toolCalls: [...state.toolCalls, toolCall],
toolResults: result.success
? [...state.toolResults, result.output as Record<string, unknown>]
: state.toolResults,
toolsTime: state.toolsTime + duration,
timeSegments: [
...state.timeSegments,
{
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration,
},
],
usedForcedTools: forcedToolCheck?.usedForcedTools ?? state.usedForcedTools,
currentToolConfig: forcedToolCheck?.nextToolConfig ?? state.currentToolConfig,
},
}
} catch (error) {
logger.error('Error processing function call:', {
error: error instanceof Error ? error.message : String(error),
functionName: toolName,
})
return { success: false, state }
}
}
/**
* Updates state with model response metadata
*/
function updateStateWithResponse(
state: ExecutionState,
response: GenerateContentResponse,
model: string,
startTime: number,
endTime: number
): ExecutionState {
const usage = convertUsageMetadata(response.usageMetadata)
const cost = calculateCost(model, usage.promptTokenCount, usage.candidatesTokenCount)
const duration = endTime - startTime
return {
...state,
tokens: {
prompt: state.tokens.prompt + usage.promptTokenCount,
completion: state.tokens.completion + usage.candidatesTokenCount,
total: state.tokens.total + usage.totalTokenCount,
},
cost: {
input: state.cost.input + cost.input,
output: state.cost.output + cost.output,
total: state.cost.total + cost.total,
pricing: cost.pricing, // Use latest pricing
},
modelTime: state.modelTime + duration,
timeSegments: [
...state.timeSegments,
{
type: 'model',
name: `Model response (iteration ${state.iterationCount + 1})`,
startTime,
endTime,
duration,
},
],
iterationCount: state.iterationCount + 1,
}
}
/**
* Builds config for next iteration
*/
function buildNextConfig(
baseConfig: GenerateContentConfig,
state: ExecutionState,
forcedTools: string[],
request: ProviderRequest,
logger: ReturnType<typeof createLogger>
): GenerateContentConfig {
const nextConfig = { ...baseConfig }
const allForcedToolsUsed =
forcedTools.length > 0 && state.usedForcedTools.length === forcedTools.length
if (allForcedToolsUsed && request.responseFormat) {
nextConfig.tools = undefined
nextConfig.toolConfig = undefined
nextConfig.responseMimeType = 'application/json'
nextConfig.responseSchema = cleanSchemaForGemini(request.responseFormat.schema) as Schema
logger.info('Using structured output for final response after tool execution')
} else if (state.currentToolConfig) {
nextConfig.toolConfig = state.currentToolConfig
} else {
nextConfig.toolConfig = { functionCallingConfig: { mode: FunctionCallingConfigMode.AUTO } }
}
return nextConfig
}
/**
* Creates streaming execution result template
*/
function createStreamingResult(
providerStartTime: number,
providerStartTimeISO: string,
firstResponseTime: number,
initialCallTime: number,
state?: ExecutionState
): StreamingExecution {
return {
stream: undefined as unknown as ReadableStream<Uint8Array>,
execution: {
success: true,
output: {
content: '',
model: '',
tokens: state?.tokens ?? { prompt: 0, completion: 0, total: 0 },
toolCalls: state?.toolCalls.length
? { list: state.toolCalls, count: state.toolCalls.length }
: undefined,
toolResults: state?.toolResults,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
modelTime: state?.modelTime ?? firstResponseTime,
toolsTime: state?.toolsTime ?? 0,
firstResponseTime,
iterations: (state?.iterationCount ?? 0) + 1,
timeSegments: state?.timeSegments ?? [
{
type: 'model',
name: 'Initial streaming response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
],
},
cost: state?.cost ?? {
input: 0,
output: 0,
total: 0,
pricing: { input: 0, output: 0, updatedAt: new Date().toISOString() },
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
isStreaming: true,
},
}
}
/**
* Configuration for executing a Gemini request
*/
export interface GeminiExecutionConfig {
ai: GoogleGenAI
model: string
request: ProviderRequest
providerType: GeminiProviderType
}
/**
* Executes a request using the Gemini API
*
* This is the shared core logic for both Google and Vertex AI providers.
* The only difference is how the GoogleGenAI client is configured.
*/
export async function executeGeminiRequest(
config: GeminiExecutionConfig
): Promise<ProviderResponse | StreamingExecution> {
const { ai, model, request, providerType } = config
const logger = createLogger(providerType === 'google' ? 'GoogleProvider' : 'VertexProvider')
logger.info(`Preparing ${providerType} Gemini request`, {
model,
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
toolCount: request.tools?.length ?? 0,
hasResponseFormat: !!request.responseFormat,
streaming: !!request.stream,
})
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
// Build configuration
const geminiConfig: GenerateContentConfig = {}
if (request.temperature !== undefined) {
geminiConfig.temperature = request.temperature
}
if (request.maxTokens !== undefined) {
geminiConfig.maxOutputTokens = request.maxTokens
}
if (systemInstruction) {
geminiConfig.systemInstruction = systemInstruction
}
// Handle response format (only when no tools)
if (request.responseFormat && !tools?.length) {
geminiConfig.responseMimeType = 'application/json'
geminiConfig.responseSchema = cleanSchemaForGemini(request.responseFormat.schema) as Schema
logger.info('Using Gemini native structured output format')
} else if (request.responseFormat && tools?.length) {
logger.warn('Gemini does not support responseFormat with tools. Structured output ignored.')
}
// Configure thinking for models that support it
const thinkingCapability = getThinkingCapability(model)
if (thinkingCapability) {
const level = request.thinkingLevel ?? thinkingCapability.default ?? 'high'
const thinkingConfig: ThinkingConfig = {
includeThoughts: false,
thinkingLevel: mapToThinkingLevel(level),
}
geminiConfig.thinkingConfig = thinkingConfig
}
// Prepare tools
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
let toolConfig: ToolConfig | undefined
if (tools?.length) {
const functionDeclarations: FunctionDeclaration[] = tools.map((t) => ({
name: t.name,
description: t.description,
parameters: t.parameters,
}))
preparedTools = prepareToolsWithUsageControl(
functionDeclarations,
request.tools,
logger,
'google'
)
const { tools: filteredTools, toolConfig: preparedToolConfig } = preparedTools
if (filteredTools?.length) {
geminiConfig.tools = [{ functionDeclarations: filteredTools as FunctionDeclaration[] }]
if (preparedToolConfig) {
toolConfig = {
functionCallingConfig: {
mode:
{
AUTO: FunctionCallingConfigMode.AUTO,
ANY: FunctionCallingConfigMode.ANY,
NONE: FunctionCallingConfigMode.NONE,
}[preparedToolConfig.functionCallingConfig.mode] ?? FunctionCallingConfigMode.AUTO,
allowedFunctionNames: preparedToolConfig.functionCallingConfig.allowedFunctionNames,
},
}
geminiConfig.toolConfig = toolConfig
}
logger.info('Gemini request with tools:', {
toolCount: filteredTools.length,
model,
tools: filteredTools.map((t) => (t as FunctionDeclaration).name),
})
}
}
const initialCallTime = Date.now()
const shouldStream = request.stream && !tools?.length
// Streaming without tools
if (shouldStream) {
logger.info('Handling Gemini streaming response')
const streamGenerator = await ai.models.generateContentStream({
model,
contents,
config: geminiConfig,
})
const firstResponseTime = Date.now() - initialCallTime
const streamingResult = createStreamingResult(
providerStartTime,
providerStartTimeISO,
firstResponseTime,
initialCallTime
)
streamingResult.execution.output.model = model
streamingResult.stream = createReadableStreamFromGeminiStream(
streamGenerator,
(content: string, usage: GeminiUsage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.promptTokenCount,
completion: usage.candidatesTokenCount,
total: usage.totalTokenCount,
}
const costResult = calculateCost(
model,
usage.promptTokenCount,
usage.candidatesTokenCount
)
streamingResult.execution.output.cost = costResult
const streamEndTime = Date.now()
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = new Date(
streamEndTime
).toISOString()
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
const segments = streamingResult.execution.output.providerTiming.timeSegments
if (segments?.[0]) {
segments[0].endTime = streamEndTime
segments[0].duration = streamEndTime - providerStartTime
}
}
}
)
return streamingResult
}
// Non-streaming request
const response = await ai.models.generateContent({ model, contents, config: geminiConfig })
const firstResponseTime = Date.now() - initialCallTime
// Check for UNEXPECTED_TOOL_CALL
const candidate = response.candidates?.[0]
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn('Gemini returned UNEXPECTED_TOOL_CALL - model attempted to call unknown tool')
}
const initialUsage = convertUsageMetadata(response.usageMetadata)
let state = createInitialState(
contents,
initialUsage,
firstResponseTime,
initialCallTime,
model,
toolConfig
)
const forcedTools = preparedTools?.forcedTools ?? []
let currentResponse = response
let content = ''
// Tool execution loop
const functionCalls = response.functionCalls
if (functionCalls?.length) {
logger.info(`Received function call from Gemini: ${functionCalls[0].name}`)
while (state.iterationCount < MAX_TOOL_ITERATIONS) {
const functionCallPart = extractFunctionCallPart(currentResponse.candidates?.[0])
if (!functionCallPart?.functionCall) {
content = extractTextContent(currentResponse.candidates?.[0])
break
}
const functionCall: ParsedFunctionCall = {
name: functionCallPart.functionCall.name ?? '',
args: (functionCallPart.functionCall.args ?? {}) as Record<string, unknown>,
}
logger.info(
`Processing function call: ${functionCall.name} (iteration ${state.iterationCount + 1})`
)
const { success, state: updatedState } = await executeToolCall(
functionCallPart,
functionCall,
request,
state,
forcedTools,
logger
)
if (!success) {
content = extractTextContent(currentResponse.candidates?.[0])
break
}
state = { ...updatedState, iterationCount: updatedState.iterationCount + 1 }
const nextConfig = buildNextConfig(geminiConfig, state, forcedTools, request, logger)
// Stream final response if requested
if (request.stream) {
const checkResponse = await ai.models.generateContent({
model,
contents: state.contents,
config: nextConfig,
})
state = updateStateWithResponse(state, checkResponse, model, Date.now() - 100, Date.now())
if (checkResponse.functionCalls?.length) {
currentResponse = checkResponse
continue
}
logger.info('No more function calls, streaming final response')
if (request.responseFormat) {
nextConfig.tools = undefined
nextConfig.toolConfig = undefined
nextConfig.responseMimeType = 'application/json'
nextConfig.responseSchema = cleanSchemaForGemini(
request.responseFormat.schema
) as Schema
}
// Capture accumulated cost before streaming
const accumulatedCost = {
input: state.cost.input,
output: state.cost.output,
total: state.cost.total,
}
const accumulatedTokens = { ...state.tokens }
const streamGenerator = await ai.models.generateContentStream({
model,
contents: state.contents,
config: nextConfig,
})
const streamingResult = createStreamingResult(
providerStartTime,
providerStartTimeISO,
firstResponseTime,
initialCallTime,
state
)
streamingResult.execution.output.model = model
streamingResult.stream = createReadableStreamFromGeminiStream(
streamGenerator,
(streamContent: string, usage: GeminiUsage) => {
streamingResult.execution.output.content = streamContent
streamingResult.execution.output.tokens = {
prompt: accumulatedTokens.prompt + usage.promptTokenCount,
completion: accumulatedTokens.completion + usage.candidatesTokenCount,
total: accumulatedTokens.total + usage.totalTokenCount,
}
const streamCost = calculateCost(
model,
usage.promptTokenCount,
usage.candidatesTokenCount
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
pricing: streamCost.pricing,
}
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = new Date().toISOString()
streamingResult.execution.output.providerTiming.duration =
Date.now() - providerStartTime
}
}
)
return streamingResult
}
// Non-streaming: get next response
const nextModelStartTime = Date.now()
const nextResponse = await ai.models.generateContent({
model,
contents: state.contents,
config: nextConfig,
})
state = updateStateWithResponse(state, nextResponse, model, nextModelStartTime, Date.now())
currentResponse = nextResponse
}
if (!content) {
content = extractTextContent(currentResponse.candidates?.[0])
}
} else {
content = extractTextContent(candidate)
}
const providerEndTime = Date.now()
return {
content,
model,
tokens: state.tokens,
toolCalls: state.toolCalls.length ? state.toolCalls : undefined,
toolResults: state.toolResults.length ? state.toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: new Date(providerEndTime).toISOString(),
duration: providerEndTime - providerStartTime,
modelTime: state.modelTime,
toolsTime: state.toolsTime,
firstResponseTime,
iterations: state.iterationCount + 1,
timeSegments: state.timeSegments,
},
cost: state.cost,
}
} catch (error) {
const providerEndTime = Date.now()
const duration = providerEndTime - providerStartTime
logger.error('Error in Gemini request:', {
error: error instanceof Error ? error.message : String(error),
stack: error instanceof Error ? error.stack : undefined,
})
const enhancedError = error instanceof Error ? error : new Error(String(error))
Object.assign(enhancedError, {
timing: {
startTime: providerStartTimeISO,
endTime: new Date(providerEndTime).toISOString(),
duration,
},
})
throw enhancedError
}
}

View File

@@ -0,0 +1,18 @@
/**
* Shared Gemini execution core
*
* This module provides the shared execution logic for both Google Gemini API
* and Vertex AI providers. The only difference between providers is how the
* GoogleGenAI client is configured (API key vs OAuth).
*/
export { createGeminiClient } from './client'
export { executeGeminiRequest, type GeminiExecutionConfig } from './core'
export type {
ExecutionState,
ForcedToolResult,
GeminiClientConfig,
GeminiProviderType,
GeminiUsage,
ParsedFunctionCall,
} from './types'
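
A sketch of how a provider might drive the shared core, using only fields that appear elsewhere in this diff (module path assumed; the ProviderRequest literal is abbreviated and cast, since its full shape is not shown here):

```typescript
import { createGeminiClient, executeGeminiRequest } from '@/providers/google/gemini' // module path assumed
import type { ProviderRequest } from '@/providers/types'

async function runGoogleGemini() {
  const ai = createGeminiClient({ apiKey: process.env.GEMINI_API_KEY! })

  // Abbreviated request: only fields that appear in this diff; the real
  // ProviderRequest carries more, hence the cast.
  const request = {
    model: 'gemini-2.5-flash', // placeholder model id
    systemPrompt: 'You are a helpful assistant.',
    messages: [{ role: 'user', content: 'Summarize the latest results.' }],
    stream: false,
  } as unknown as ProviderRequest

  // providerType ('google' | 'vertex') only affects logging and model lookup.
  return executeGeminiRequest({ ai, model: request.model, request, providerType: 'google' })
}
```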

View File

@@ -0,0 +1,64 @@
import type { Content, ToolConfig } from '@google/genai'
import type { FunctionCallResponse, ModelPricing, TimeSegment } from '@/providers/types'
/**
* Usage metadata from Gemini responses
*/
export interface GeminiUsage {
promptTokenCount: number
candidatesTokenCount: number
totalTokenCount: number
}
/**
* Parsed function call from Gemini response
*/
export interface ParsedFunctionCall {
name: string
args: Record<string, unknown>
}
/**
* Accumulated state during tool execution loop
*/
export interface ExecutionState {
contents: Content[]
tokens: { prompt: number; completion: number; total: number }
cost: { input: number; output: number; total: number; pricing: ModelPricing }
toolCalls: FunctionCallResponse[]
toolResults: Record<string, unknown>[]
iterationCount: number
modelTime: number
toolsTime: number
timeSegments: TimeSegment[]
usedForcedTools: string[]
currentToolConfig: ToolConfig | undefined
}
/**
* Result from forced tool usage check
*/
export interface ForcedToolResult {
hasUsedForcedTool: boolean
usedForcedTools: string[]
nextToolConfig: ToolConfig | undefined
}
/**
* Configuration for creating a Gemini client
*/
export interface GeminiClientConfig {
/** For Google Gemini API */
apiKey?: string
/** For Vertex AI */
vertexai?: boolean
project?: string
location?: string
/** OAuth access token for Vertex AI */
accessToken?: string
}
/**
* Provider type for logging and model lookup
*/
export type GeminiProviderType = 'google' | 'vertex'

File diff suppressed because it is too large

View File

@@ -1,61 +1,89 @@
import type { Candidate } from '@google/genai'
import {
type Candidate,
type Content,
type FunctionCall,
FunctionCallingConfigMode,
type GenerateContentResponse,
type GenerateContentResponseUsageMetadata,
type Part,
type Schema,
type SchemaUnion,
ThinkingLevel,
type ToolConfig,
Type,
} from '@google/genai'
import { createLogger } from '@/lib/logs/console/logger'
import type { ProviderRequest } from '@/providers/types'
import { trackForcedToolUsage } from '@/providers/utils'
const logger = createLogger('GoogleUtils')
/**
* Usage metadata for Google Gemini responses
*/
export interface GeminiUsage {
promptTokenCount: number
candidatesTokenCount: number
totalTokenCount: number
}
/**
* Parsed function call from Gemini response
*/
export interface ParsedFunctionCall {
name: string
args: Record<string, unknown>
}
/**
* Removes additionalProperties from a schema object (not supported by Gemini)
*/
export function cleanSchemaForGemini(schema: any): any {
export function cleanSchemaForGemini(schema: SchemaUnion): SchemaUnion {
if (schema === null || schema === undefined) return schema
if (typeof schema !== 'object') return schema
if (Array.isArray(schema)) {
return schema.map((item) => cleanSchemaForGemini(item))
}
const cleanedSchema: any = {}
const cleanedSchema: Record<string, unknown> = {}
const schemaObj = schema as Record<string, unknown>
for (const key in schema) {
for (const key in schemaObj) {
if (key === 'additionalProperties') continue
cleanedSchema[key] = cleanSchemaForGemini(schema[key])
cleanedSchema[key] = cleanSchemaForGemini(schemaObj[key] as SchemaUnion)
}
return cleanedSchema
}
/**
* Extracts text content from a Gemini response candidate, handling structured output
* Extracts text content from a Gemini response candidate.
* Filters out thought parts (model reasoning) from the output.
*/
export function extractTextContent(candidate: Candidate | undefined): string {
if (!candidate?.content?.parts) return ''
if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) {
const text = candidate.content.parts[0].text
if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) {
try {
JSON.parse(text)
return text
} catch (_e) {}
}
}
const textParts = candidate.content.parts.filter(
(part): part is Part & { text: string } => Boolean(part.text) && part.thought !== true
)
return candidate.content.parts
.filter((part: any) => part.text)
.map((part: any) => part.text)
.join('\n')
if (textParts.length === 0) return ''
if (textParts.length === 1) return textParts[0].text
return textParts.map((part) => part.text).join('\n')
}
/**
* Extracts a function call from a Gemini response candidate
* Extracts the first function call from a Gemini response candidate
*/
export function extractFunctionCall(
candidate: Candidate | undefined
): { name: string; args: any } | null {
export function extractFunctionCall(candidate: Candidate | undefined): ParsedFunctionCall | null {
if (!candidate?.content?.parts) return null
for (const part of candidate.content.parts) {
if (part.functionCall) {
return {
name: part.functionCall.name ?? '',
args: part.functionCall.args ?? {},
args: (part.functionCall.args ?? {}) as Record<string, unknown>,
}
}
}
@@ -63,16 +91,55 @@ export function extractFunctionCall(
return null
}
/**
* Extracts the full Part containing the function call (preserves thoughtSignature)
*/
export function extractFunctionCallPart(candidate: Candidate | undefined): Part | null {
if (!candidate?.content?.parts) return null
for (const part of candidate.content.parts) {
if (part.functionCall) {
return part
}
}
return null
}
/**
* Converts usage metadata from SDK response to our format
*/
export function convertUsageMetadata(
usageMetadata: GenerateContentResponseUsageMetadata | undefined
): GeminiUsage {
const promptTokenCount = usageMetadata?.promptTokenCount ?? 0
const candidatesTokenCount = usageMetadata?.candidatesTokenCount ?? 0
return {
promptTokenCount,
candidatesTokenCount,
totalTokenCount: usageMetadata?.totalTokenCount ?? promptTokenCount + candidatesTokenCount,
}
}
/**
* Tool definition for Gemini format
*/
export interface GeminiToolDef {
name: string
description: string
parameters: Schema
}
/**
* Converts OpenAI-style request format to Gemini format
*/
export function convertToGeminiFormat(request: ProviderRequest): {
contents: any[]
tools: any[] | undefined
systemInstruction: any | undefined
contents: Content[]
tools: GeminiToolDef[] | undefined
systemInstruction: Content | undefined
} {
const contents: any[] = []
let systemInstruction
const contents: Content[] = []
let systemInstruction: Content | undefined
if (request.systemPrompt) {
systemInstruction = { parts: [{ text: request.systemPrompt }] }
@@ -82,13 +149,13 @@ export function convertToGeminiFormat(request: ProviderRequest): {
contents.push({ role: 'user', parts: [{ text: request.context }] })
}
if (request.messages && request.messages.length > 0) {
if (request.messages?.length) {
for (const message of request.messages) {
if (message.role === 'system') {
if (!systemInstruction) {
systemInstruction = { parts: [{ text: message.content }] }
} else {
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}`
systemInstruction = { parts: [{ text: message.content ?? '' }] }
} else if (systemInstruction.parts?.[0] && 'text' in systemInstruction.parts[0]) {
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text}\n${message.content}`
}
} else if (message.role === 'user' || message.role === 'assistant') {
const geminiRole = message.role === 'user' ? 'user' : 'model'
@@ -97,60 +164,200 @@ export function convertToGeminiFormat(request: ProviderRequest): {
contents.push({ role: geminiRole, parts: [{ text: message.content }] })
}
if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) {
if (message.role === 'assistant' && message.tool_calls?.length) {
const functionCalls = message.tool_calls.map((toolCall) => ({
functionCall: {
name: toolCall.function?.name,
args: JSON.parse(toolCall.function?.arguments || '{}'),
args: JSON.parse(toolCall.function?.arguments || '{}') as Record<string, unknown>,
},
}))
contents.push({ role: 'model', parts: functionCalls })
}
} else if (message.role === 'tool') {
if (!message.name) {
logger.warn('Tool message missing function name, skipping')
continue
}
let responseData: Record<string, unknown>
try {
responseData = JSON.parse(message.content ?? '{}')
} catch {
responseData = { output: message.content }
}
contents.push({
role: 'user',
parts: [{ text: `Function result: ${message.content}` }],
parts: [
{
functionResponse: {
id: message.tool_call_id,
name: message.name,
response: responseData,
},
},
],
})
}
}
}
const tools = request.tools?.map((tool) => {
const tools = request.tools?.map((tool): GeminiToolDef => {
const toolParameters = { ...(tool.parameters || {}) }
if (toolParameters.properties) {
const properties = { ...toolParameters.properties }
const required = toolParameters.required ? [...toolParameters.required] : []
// Remove default values from properties (not supported by Gemini)
for (const key in properties) {
const prop = properties[key] as any
const prop = properties[key] as Record<string, unknown>
if (prop.default !== undefined) {
const { default: _, ...cleanProp } = prop
properties[key] = cleanProp
}
}
const parameters = {
type: toolParameters.type || 'object',
properties,
const parameters: Schema = {
type: (toolParameters.type as Schema['type']) || Type.OBJECT,
properties: properties as Record<string, Schema>,
...(required.length > 0 ? { required } : {}),
}
return {
name: tool.id,
description: tool.description || `Execute the ${tool.id} function`,
parameters: cleanSchemaForGemini(parameters),
parameters: cleanSchemaForGemini(parameters) as Schema,
}
}
return {
name: tool.id,
description: tool.description || `Execute the ${tool.id} function`,
parameters: cleanSchemaForGemini(toolParameters),
parameters: cleanSchemaForGemini(toolParameters) as Schema,
}
})
return { contents, tools, systemInstruction }
}
/**
* Creates a ReadableStream from a Google Gemini streaming response
*/
export function createReadableStreamFromGeminiStream(
stream: AsyncGenerator<GenerateContentResponse>,
onComplete?: (content: string, usage: GeminiUsage) => void
): ReadableStream<Uint8Array> {
let fullContent = ''
let usage: GeminiUsage = { promptTokenCount: 0, candidatesTokenCount: 0, totalTokenCount: 0 }
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of stream) {
if (chunk.usageMetadata) {
usage = convertUsageMetadata(chunk.usageMetadata)
}
const text = chunk.text
if (text) {
fullContent += text
controller.enqueue(new TextEncoder().encode(text))
}
}
onComplete?.(fullContent, usage)
controller.close()
} catch (error) {
logger.error('Error reading Google Gemini stream', {
error: error instanceof Error ? error.message : String(error),
})
controller.error(error)
}
},
})
}
/**
* Maps string mode to FunctionCallingConfigMode enum
*/
function mapToFunctionCallingMode(mode: string): FunctionCallingConfigMode {
switch (mode) {
case 'AUTO':
return FunctionCallingConfigMode.AUTO
case 'ANY':
return FunctionCallingConfigMode.ANY
case 'NONE':
return FunctionCallingConfigMode.NONE
default:
return FunctionCallingConfigMode.AUTO
}
}
/**
* Maps string level to ThinkingLevel enum
*/
export function mapToThinkingLevel(level: string): ThinkingLevel {
switch (level.toLowerCase()) {
case 'minimal':
return ThinkingLevel.MINIMAL
case 'low':
return ThinkingLevel.LOW
case 'medium':
return ThinkingLevel.MEDIUM
case 'high':
return ThinkingLevel.HIGH
default:
return ThinkingLevel.HIGH
}
}
/**
* Result of checking forced tool usage
*/
export interface ForcedToolResult {
hasUsedForcedTool: boolean
usedForcedTools: string[]
nextToolConfig: ToolConfig | undefined
}
/**
* Checks for forced tool usage in Google Gemini responses
*/
export function checkForForcedToolUsage(
functionCalls: FunctionCall[] | undefined,
toolConfig: ToolConfig | undefined,
forcedTools: string[],
usedForcedTools: string[]
): ForcedToolResult | null {
if (!functionCalls?.length) return null
const adaptedToolCalls = functionCalls.map((fc) => ({
name: fc.name ?? '',
arguments: (fc.args ?? {}) as Record<string, unknown>,
}))
const result = trackForcedToolUsage(
adaptedToolCalls,
toolConfig,
logger,
'google',
forcedTools,
usedForcedTools
)
if (!result) return null
const nextToolConfig: ToolConfig | undefined = result.nextToolConfig?.functionCallingConfig?.mode
? {
functionCallingConfig: {
mode: mapToFunctionCallingMode(result.nextToolConfig.functionCallingConfig.mode),
allowedFunctionNames: result.nextToolConfig.functionCallingConfig.allowedFunctionNames,
},
}
: undefined
return {
hasUsedForcedTool: result.hasUsedForcedTool,
usedForcedTools: result.usedForcedTools,
nextToolConfig,
}
}
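
A small sketch of what cleanSchemaForGemini does to a JSON-schema-style object (the input schema is made up):

```typescript
import { cleanSchemaForGemini } from '@/providers/google/utils'

// A made-up JSON-schema-style object; `additionalProperties` is not supported by Gemini.
const schema = {
  type: 'object',
  additionalProperties: false,
  properties: {
    city: { type: 'string' },
    tags: { type: 'array', items: { type: 'string' }, additionalProperties: false },
  },
}

// Recursively strips every `additionalProperties` key and leaves the rest intact.
// The cast is only to satisfy the SchemaUnion parameter type in this sketch.
const cleaned = cleanSchemaForGemini(schema as any)
console.log(cleaned)
// => { type: 'object', properties: { city: {...}, tags: { type: 'array', items: {...} } } }
```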

View File

@@ -69,10 +69,7 @@ export const groqProvider: ProviderConfig = {
: undefined
const payload: any = {
model: (request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct').replace(
'groq/',
''
),
model: request.model.replace('groq/', ''),
messages: allMessages,
}
@@ -109,7 +106,7 @@ export const groqProvider: ProviderConfig = {
toolChoice: payload.tool_choice,
forcedToolsCount: forcedTools.length,
hasFilteredTools,
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
model: request.model,
})
}
}
@@ -149,7 +146,7 @@ export const groqProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
@@ -393,7 +390,7 @@ export const groqProvider: ProviderConfig = {
const streamingPayload = {
...payload,
messages: currentMessages,
tool_choice: 'auto',
tool_choice: originalToolChoice || 'auto',
stream: true,
}
@@ -425,7 +422,7 @@ export const groqProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,

View File

@@ -1,11 +1,11 @@
import { getCostMultiplier } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import type { ProviderRequest, ProviderResponse } from '@/providers/types'
import { getProviderExecutor } from '@/providers/registry'
import type { ProviderId, ProviderRequest, ProviderResponse } from '@/providers/types'
import {
calculateCost,
generateStructuredOutputInstructions,
getProvider,
shouldBillModelUsage,
supportsTemperature,
} from '@/providers/utils'
@@ -40,7 +40,7 @@ export async function executeProviderRequest(
providerId: string,
request: ProviderRequest
): Promise<ProviderResponse | ReadableStream | StreamingExecution> {
const provider = getProvider(providerId)
const provider = await getProviderExecutor(providerId as ProviderId)
if (!provider) {
throw new Error(`Provider not found: ${providerId}`)
}

View File

@@ -36,7 +36,7 @@ export const mistralProvider: ProviderConfig = {
request: ProviderRequest
): Promise<ProviderResponse | StreamingExecution> => {
logger.info('Preparing Mistral request', {
model: request.model || 'mistral-large-latest',
model: request.model,
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
@@ -86,7 +86,7 @@ export const mistralProvider: ProviderConfig = {
: undefined
const payload: any = {
model: request.model || 'mistral-large-latest',
model: request.model,
messages: allMessages,
}
@@ -126,7 +126,7 @@ export const mistralProvider: ProviderConfig = {
: toolChoice.type === 'any'
? `force:${toolChoice.any?.name || 'unknown'}`
: 'unknown',
model: request.model || 'mistral-large-latest',
model: request.model,
})
}
}

View File

@@ -39,6 +39,10 @@ export interface ModelCapabilities {
verbosity?: {
values: string[]
}
thinking?: {
levels: string[]
default?: string
}
}
export interface ModelDefinition {
@@ -730,6 +734,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 2 },
thinking: {
levels: ['low', 'high'],
default: 'high',
},
},
contextWindow: 1000000,
},
@@ -743,6 +751,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 2 },
thinking: {
levels: ['minimal', 'low', 'medium', 'high'],
default: 'high',
},
},
contextWindow: 1000000,
},
@@ -832,6 +844,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 2 },
thinking: {
levels: ['low', 'high'],
default: 'high',
},
},
contextWindow: 1000000,
},
@@ -845,6 +861,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 2 },
thinking: {
levels: ['minimal', 'low', 'medium', 'high'],
default: 'high',
},
},
contextWindow: 1000000,
},
@@ -1864,3 +1884,49 @@ export function supportsNativeStructuredOutputs(modelId: string): boolean {
}
return false
}
/**
* Check if a model supports thinking/reasoning features.
* Returns the thinking capability config if supported, null otherwise.
*/
export function getThinkingCapability(
modelId: string
): { levels: string[]; default?: string } | null {
const normalizedModelId = modelId.toLowerCase()
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
for (const model of provider.models) {
if (model.capabilities.thinking) {
const baseModelId = model.id.toLowerCase()
if (normalizedModelId === baseModelId || normalizedModelId.startsWith(`${baseModelId}-`)) {
return model.capabilities.thinking
}
}
}
}
return null
}
/**
 * Get all models that support the thinking capability
 */
export function getModelsWithThinking(): string[] {
const models: string[] = []
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
for (const model of provider.models) {
if (model.capabilities.thinking) {
models.push(model.id)
}
}
}
return models
}
/**
* Get the thinking levels for a specific model
* Returns the valid levels for that model, or null if the model doesn't support thinking
*/
export function getThinkingLevelsForModel(modelId: string): string[] | null {
const capability = getThinkingCapability(modelId)
return capability?.levels ?? null
}
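
A short sketch of how the new helpers can validate a user-selected thinking level. The model id is a placeholder and resolveThinkingLevel is a hypothetical helper:

import { getThinkingCapability, getThinkingLevelsForModel } from '@/providers/models'

// Placeholder model id; returns e.g. ['low', 'high'] for supported models, null otherwise.
const levels = getThinkingLevelsForModel('gemini-2.5-pro')
const capability = getThinkingCapability('gemini-2.5-pro')

// Hypothetical helper: accept the requested level only if the model supports it,
// otherwise fall back to the model's default.
function resolveThinkingLevel(requested?: string): string | undefined {
  if (requested && levels?.includes(requested)) return requested
  return capability?.default
}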

View File

@@ -33,7 +33,7 @@ export const openaiProvider: ProviderConfig = {
request: ProviderRequest
): Promise<ProviderResponse | StreamingExecution> => {
logger.info('Preparing OpenAI request', {
model: request.model || 'gpt-4o',
model: request.model,
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
@@ -76,7 +76,7 @@ export const openaiProvider: ProviderConfig = {
: undefined
const payload: any = {
model: request.model || 'gpt-4o',
model: request.model,
messages: allMessages,
}
@@ -121,7 +121,7 @@ export const openaiProvider: ProviderConfig = {
: toolChoice.type === 'any'
? `force:${toolChoice.any?.name || 'unknown'}`
: 'unknown',
model: request.model || 'gpt-4o',
model: request.model,
})
}
}

View File

@@ -78,7 +78,7 @@ export const openRouterProvider: ProviderConfig = {
baseURL: 'https://openrouter.ai/api/v1',
})
const requestedModel = (request.model || '').replace(/^openrouter\//, '')
const requestedModel = request.model.replace(/^openrouter\//, '')
logger.info('Preparing OpenRouter request', {
model: requestedModel,

View File

@@ -0,0 +1,59 @@
import { createLogger } from '@/lib/logs/console/logger'
import { anthropicProvider } from '@/providers/anthropic'
import { azureOpenAIProvider } from '@/providers/azure-openai'
import { cerebrasProvider } from '@/providers/cerebras'
import { deepseekProvider } from '@/providers/deepseek'
import { googleProvider } from '@/providers/google'
import { groqProvider } from '@/providers/groq'
import { mistralProvider } from '@/providers/mistral'
import { ollamaProvider } from '@/providers/ollama'
import { openaiProvider } from '@/providers/openai'
import { openRouterProvider } from '@/providers/openrouter'
import type { ProviderConfig, ProviderId } from '@/providers/types'
import { vertexProvider } from '@/providers/vertex'
import { vllmProvider } from '@/providers/vllm'
import { xAIProvider } from '@/providers/xai'
const logger = createLogger('ProviderRegistry')
const providerRegistry: Record<ProviderId, ProviderConfig> = {
openai: openaiProvider,
anthropic: anthropicProvider,
google: googleProvider,
vertex: vertexProvider,
deepseek: deepseekProvider,
xai: xAIProvider,
cerebras: cerebrasProvider,
groq: groqProvider,
vllm: vllmProvider,
mistral: mistralProvider,
'azure-openai': azureOpenAIProvider,
openrouter: openRouterProvider,
ollama: ollamaProvider,
}
export async function getProviderExecutor(
providerId: ProviderId
): Promise<ProviderConfig | undefined> {
const provider = providerRegistry[providerId]
if (!provider) {
logger.error(`Provider not found: ${providerId}`)
return undefined
}
return provider
}
export async function initializeProviders(): Promise<void> {
for (const [id, provider] of Object.entries(providerRegistry)) {
if (provider.initialize) {
try {
await provider.initialize()
logger.info(`Initialized provider: ${id}`)
} catch (error) {
logger.error(`Failed to initialize ${id} provider`, {
error: error instanceof Error ? error.message : 'Unknown error',
})
}
}
}
}
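
A usage sketch for the new server-side registry, assuming a module where top-level await is available:

import { getProviderExecutor, initializeProviders } from '@/providers/registry'

// Run once during server startup so providers that define initialize() are ready.
await initializeProviders()

// Per request: resolve the concrete executor by id.
const executor = await getProviderExecutor('anthropic')
if (!executor) {
  throw new Error('anthropic provider is not registered')
}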

View File

@@ -164,6 +164,7 @@ export interface ProviderRequest {
vertexLocation?: string
reasoningEffort?: string
verbosity?: string
thinkingLevel?: string
}
export const providers: Record<string, ProviderConfig> = {}
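
An illustrative request exercising the new thinkingLevel field. The model id and prompt are placeholders, and Partial is used only to keep the sketch minimal:

import type { ProviderRequest } from '@/providers/types'

// Placeholder values; thinkingLevel should be one of the levels the chosen model supports.
const request: Partial<ProviderRequest> = {
  model: 'gemini-2.5-flash',
  systemPrompt: 'You are a helpful assistant.',
  thinkingLevel: 'high',
  stream: true,
}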

View File

@@ -3,13 +3,6 @@ import type { CompletionUsage } from 'openai/resources/completions'
import { getEnv, isTruthy } from '@/lib/core/config/env'
import { isHosted } from '@/lib/core/config/feature-flags'
import { createLogger, type Logger } from '@/lib/logs/console/logger'
import { anthropicProvider } from '@/providers/anthropic'
import { azureOpenAIProvider } from '@/providers/azure-openai'
import { cerebrasProvider } from '@/providers/cerebras'
import { deepseekProvider } from '@/providers/deepseek'
import { googleProvider } from '@/providers/google'
import { groqProvider } from '@/providers/groq'
import { mistralProvider } from '@/providers/mistral'
import {
getComputerUseModels,
getEmbeddingModelPricing,
@@ -20,117 +13,82 @@ import {
getModelsWithTemperatureSupport,
getModelsWithTempRange01,
getModelsWithTempRange02,
getModelsWithThinking,
getModelsWithVerbosity,
getProviderDefaultModel as getProviderDefaultModelFromDefinitions,
getProviderModels as getProviderModelsFromDefinitions,
getProvidersWithToolUsageControl,
getReasoningEffortValuesForModel as getReasoningEffortValuesForModelFromDefinitions,
getThinkingLevelsForModel as getThinkingLevelsForModelFromDefinitions,
getVerbosityValuesForModel as getVerbosityValuesForModelFromDefinitions,
PROVIDER_DEFINITIONS,
supportsTemperature as supportsTemperatureFromDefinitions,
supportsToolUsageControl as supportsToolUsageControlFromDefinitions,
updateOllamaModels as updateOllamaModelsInDefinitions,
} from '@/providers/models'
import { ollamaProvider } from '@/providers/ollama'
import { openaiProvider } from '@/providers/openai'
import { openRouterProvider } from '@/providers/openrouter'
import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types'
import { vertexProvider } from '@/providers/vertex'
import { vllmProvider } from '@/providers/vllm'
import { xAIProvider } from '@/providers/xai'
import type { ProviderId, ProviderToolConfig } from '@/providers/types'
import { useCustomToolsStore } from '@/stores/custom-tools/store'
import { useProvidersStore } from '@/stores/providers/store'
const logger = createLogger('ProviderUtils')
export const providers: Record<
ProviderId,
ProviderConfig & {
models: string[]
computerUseModels?: string[]
modelPatterns?: RegExp[]
/**
* Client-safe provider metadata.
* This object contains only model lists and patterns - no executeRequest implementations.
* For server-side execution, use @/providers/registry.
*/
export interface ProviderMetadata {
id: string
name: string
description: string
version: string
models: string[]
defaultModel: string
computerUseModels?: string[]
modelPatterns?: RegExp[]
}
/**
* Build provider metadata from PROVIDER_DEFINITIONS.
* This is client-safe as it doesn't import any provider implementations.
*/
function buildProviderMetadata(providerId: ProviderId): ProviderMetadata {
const def = PROVIDER_DEFINITIONS[providerId]
return {
id: providerId,
name: def?.name || providerId,
description: def?.description || '',
version: '1.0.0',
models: getProviderModelsFromDefinitions(providerId),
defaultModel: getProviderDefaultModelFromDefinitions(providerId),
modelPatterns: def?.modelPatterns,
}
> = {
}
export const providers: Record<ProviderId, ProviderMetadata> = {
openai: {
...openaiProvider,
models: getProviderModelsFromDefinitions('openai'),
...buildProviderMetadata('openai'),
computerUseModels: ['computer-use-preview'],
modelPatterns: PROVIDER_DEFINITIONS.openai.modelPatterns,
},
anthropic: {
...anthropicProvider,
models: getProviderModelsFromDefinitions('anthropic'),
...buildProviderMetadata('anthropic'),
computerUseModels: getComputerUseModels().filter((model) =>
getProviderModelsFromDefinitions('anthropic').includes(model)
),
modelPatterns: PROVIDER_DEFINITIONS.anthropic.modelPatterns,
},
google: {
...googleProvider,
models: getProviderModelsFromDefinitions('google'),
modelPatterns: PROVIDER_DEFINITIONS.google.modelPatterns,
},
vertex: {
...vertexProvider,
models: getProviderModelsFromDefinitions('vertex'),
modelPatterns: PROVIDER_DEFINITIONS.vertex.modelPatterns,
},
deepseek: {
...deepseekProvider,
models: getProviderModelsFromDefinitions('deepseek'),
modelPatterns: PROVIDER_DEFINITIONS.deepseek.modelPatterns,
},
xai: {
...xAIProvider,
models: getProviderModelsFromDefinitions('xai'),
modelPatterns: PROVIDER_DEFINITIONS.xai.modelPatterns,
},
cerebras: {
...cerebrasProvider,
models: getProviderModelsFromDefinitions('cerebras'),
modelPatterns: PROVIDER_DEFINITIONS.cerebras.modelPatterns,
},
groq: {
...groqProvider,
models: getProviderModelsFromDefinitions('groq'),
modelPatterns: PROVIDER_DEFINITIONS.groq.modelPatterns,
},
vllm: {
...vllmProvider,
models: getProviderModelsFromDefinitions('vllm'),
modelPatterns: PROVIDER_DEFINITIONS.vllm.modelPatterns,
},
mistral: {
...mistralProvider,
models: getProviderModelsFromDefinitions('mistral'),
modelPatterns: PROVIDER_DEFINITIONS.mistral.modelPatterns,
},
'azure-openai': {
...azureOpenAIProvider,
models: getProviderModelsFromDefinitions('azure-openai'),
modelPatterns: PROVIDER_DEFINITIONS['azure-openai'].modelPatterns,
},
openrouter: {
...openRouterProvider,
models: getProviderModelsFromDefinitions('openrouter'),
modelPatterns: PROVIDER_DEFINITIONS.openrouter.modelPatterns,
},
ollama: {
...ollamaProvider,
models: getProviderModelsFromDefinitions('ollama'),
modelPatterns: PROVIDER_DEFINITIONS.ollama.modelPatterns,
},
google: buildProviderMetadata('google'),
vertex: buildProviderMetadata('vertex'),
deepseek: buildProviderMetadata('deepseek'),
xai: buildProviderMetadata('xai'),
cerebras: buildProviderMetadata('cerebras'),
groq: buildProviderMetadata('groq'),
vllm: buildProviderMetadata('vllm'),
mistral: buildProviderMetadata('mistral'),
'azure-openai': buildProviderMetadata('azure-openai'),
openrouter: buildProviderMetadata('openrouter'),
ollama: buildProviderMetadata('ollama'),
}
Object.entries(providers).forEach(([id, provider]) => {
if (provider.initialize) {
provider.initialize().catch((error) => {
logger.error(`Failed to initialize ${id} provider`, {
error: error instanceof Error ? error.message : 'Unknown error',
})
})
}
})
export function updateOllamaProviderModels(models: string[]): void {
updateOllamaModelsInDefinitions(models)
providers.ollama.models = getProviderModelsFromDefinitions('ollama')
@@ -211,12 +169,12 @@ export function getProviderFromModel(model: string): ProviderId {
return 'ollama'
}
export function getProvider(id: string): ProviderConfig | undefined {
export function getProvider(id: string): ProviderMetadata | undefined {
const providerId = id.split('/')[0] as ProviderId
return providers[providerId]
}
export function getProviderConfigFromModel(model: string): ProviderConfig | undefined {
export function getProviderConfigFromModel(model: string): ProviderMetadata | undefined {
const providerId = getProviderFromModel(model)
return providers[providerId]
}
@@ -929,6 +887,7 @@ export const MODELS_TEMP_RANGE_0_1 = getModelsWithTempRange01()
export const MODELS_WITH_TEMPERATURE_SUPPORT = getModelsWithTemperatureSupport()
export const MODELS_WITH_REASONING_EFFORT = getModelsWithReasoningEffort()
export const MODELS_WITH_VERBOSITY = getModelsWithVerbosity()
export const MODELS_WITH_THINKING = getModelsWithThinking()
export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = getProvidersWithToolUsageControl()
export function supportsTemperature(model: string): boolean {
@@ -963,6 +922,14 @@ export function getVerbosityValuesForModel(model: string): string[] | null {
return getVerbosityValuesForModelFromDefinitions(model)
}
/**
* Get thinking levels for a specific model
* Returns the valid levels for that model, or null if the model doesn't support thinking
*/
export function getThinkingLevelsForModel(model: string): string[] | null {
return getThinkingLevelsForModelFromDefinitions(model)
}
/**
* Prepare tool execution parameters, separating tool parameters from system parameters
*/
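
A client-safe usage sketch of the metadata-only providers map; the model id is a placeholder:

import { getProvider, getProviderFromModel, providers } from '@/providers/utils'

// Metadata only - no provider executeRequest implementations are pulled into the client bundle.
const providerId = getProviderFromModel('gpt-4o')
const metadata = getProvider(providerId)

console.log(metadata?.name, metadata?.models.length)
console.log(providers.openai.defaultModel)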

View File

@@ -1,33 +1,23 @@
import { GoogleGenAI } from '@google/genai'
import { OAuth2Client } from 'google-auth-library'
import { env } from '@/lib/core/config/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import {
cleanSchemaForGemini,
convertToGeminiFormat,
extractFunctionCall,
extractTextContent,
} from '@/providers/google/utils'
import { executeGeminiRequest } from '@/providers/gemini/core'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
ProviderRequest,
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
} from '@/providers/utils'
import { buildVertexEndpoint, createReadableStreamFromVertexStream } from '@/providers/vertex/utils'
import { executeTool } from '@/tools'
import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types'
const logger = createLogger('VertexProvider')
/**
* Vertex AI provider configuration
* Vertex AI provider
*
* Uses the @google/genai SDK with Vertex AI backend and OAuth authentication.
* Shares core execution logic with Google Gemini provider.
*
* Authentication:
* - Uses OAuth access token passed via googleAuthOptions.authClient
* - Token refresh is handled at the OAuth layer before calling this provider
*/
export const vertexProvider: ProviderConfig = {
id: 'vertex',
@@ -55,869 +45,35 @@ export const vertexProvider: ProviderConfig = {
)
}
logger.info('Preparing Vertex AI request', {
model: request.model || 'vertex/gemini-2.5-pro',
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
toolCount: request.tools?.length || 0,
hasResponseFormat: !!request.responseFormat,
streaming: !!request.stream,
// Strip 'vertex/' prefix from model name if present
const model = request.model.replace('vertex/', '')
logger.info('Creating Vertex AI client', {
project: vertexProject,
location: vertexLocation,
model,
})
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
const requestedModel = (request.model || 'vertex/gemini-2.5-pro').replace('vertex/', '')
const payload: any = {
contents,
generationConfig: {},
}
if (request.temperature !== undefined && request.temperature !== null) {
payload.generationConfig.temperature = request.temperature
}
if (request.maxTokens !== undefined) {
payload.generationConfig.maxOutputTokens = request.maxTokens
}
if (systemInstruction) {
payload.systemInstruction = systemInstruction
}
if (request.responseFormat && !tools?.length) {
const responseFormatSchema = request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
payload.generationConfig.responseMimeType = 'application/json'
payload.generationConfig.responseSchema = cleanSchema
logger.info('Using Vertex AI native structured output format', {
hasSchema: !!cleanSchema,
mimeType: 'application/json',
})
} else if (request.responseFormat && tools?.length) {
logger.warn(
'Vertex AI does not support structured output (responseFormat) with function calling (tools). Structured output will be ignored.'
)
}
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
if (tools?.length) {
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'google')
const { tools: filteredTools, toolConfig } = preparedTools
if (filteredTools?.length) {
payload.tools = [
{
functionDeclarations: filteredTools,
},
]
if (toolConfig) {
payload.toolConfig = toolConfig
}
logger.info('Vertex AI request with tools:', {
toolCount: filteredTools.length,
model: requestedModel,
tools: filteredTools.map((t) => t.name),
hasToolConfig: !!toolConfig,
toolConfig: toolConfig,
})
}
}
const initialCallTime = Date.now()
const shouldStream = !!(request.stream && !tools?.length)
const endpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
shouldStream
)
if (request.stream && tools?.length) {
logger.info('Streaming disabled for initial request due to tools presence', {
toolCount: tools.length,
willStreamAfterTools: true,
})
}
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(payload),
})
if (!response.ok) {
const responseText = await response.text()
logger.error('Vertex AI API error details:', {
status: response.status,
statusText: response.statusText,
responseBody: responseText,
})
throw new Error(`Vertex AI API error: ${response.status} ${response.statusText}`)
}
const firstResponseTime = Date.now() - initialCallTime
if (shouldStream) {
logger.info('Handling Vertex AI streaming response')
const streamingResult: StreamingExecution = {
stream: null as any,
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: {
prompt: 0,
completion: 0,
total: 0,
},
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: firstResponseTime,
modelTime: firstResponseTime,
toolsTime: 0,
firstResponseTime,
iterations: 1,
timeSegments: [
{
type: 'model',
name: 'Initial streaming response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
],
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: firstResponseTime,
},
isStreaming: true,
},
}
streamingResult.stream = createReadableStreamFromVertexStream(
response,
(content, usage) => {
streamingResult.execution.output.content = content
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
streamEndTime
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
streamEndTime - providerStartTime
}
}
const promptTokens = usage?.promptTokenCount || 0
const completionTokens = usage?.candidatesTokenCount || 0
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
streamingResult.execution.output.tokens = {
prompt: promptTokens,
completion: completionTokens,
total: totalTokens,
}
const costResult = calculateCost(request.model, promptTokens, completionTokens)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
}
)
return streamingResult
}
let geminiResponse = await response.json()
if (payload.generationConfig?.responseSchema) {
const candidate = geminiResponse.candidates?.[0]
if (candidate?.content?.parts?.[0]?.text) {
const text = candidate.content.parts[0].text
try {
JSON.parse(text)
logger.info('Successfully received structured JSON output')
} catch (_e) {
logger.warn('Failed to parse structured output as JSON')
}
}
}
let content = ''
let tokens = {
prompt: 0,
completion: 0,
total: 0,
}
const toolCalls = []
const toolResults = []
let iterationCount = 0
const originalToolConfig = preparedTools?.toolConfig
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
let hasUsedForcedTool = false
let currentToolConfig = originalToolConfig
const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => {
if (currentToolConfig && forcedTools.length > 0) {
const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }]
const result = trackForcedToolUsage(
toolCallsForTracking,
currentToolConfig,
logger,
'google',
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
if (result.nextToolConfig) {
currentToolConfig = result.nextToolConfig
logger.info('Updated tool config for next iteration', {
hasNextToolConfig: !!currentToolConfig,
usedForcedTools: usedForcedTools,
})
}
}
}
let modelTime = firstResponseTime
let toolsTime = 0
const timeSegments: TimeSegment[] = [
{
type: 'model',
name: 'Initial response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
]
try {
const candidate = geminiResponse.candidates?.[0]
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn(
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
{
finishReason: candidate.finishReason,
hasContent: !!candidate?.content,
hasParts: !!candidate?.content?.parts,
}
)
content = extractTextContent(candidate)
}
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.info(`Received function call from Vertex AI: ${functionCall.name}`)
while (iterationCount < MAX_TOOL_ITERATIONS) {
const latestResponse = geminiResponse.candidates?.[0]
const latestFunctionCall = extractFunctionCall(latestResponse)
if (!latestFunctionCall) {
content = extractTextContent(latestResponse)
break
}
logger.info(
`Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
const toolsStartTime = Date.now()
try {
const toolName = latestFunctionCall.name
const toolArgs = latestFunctionCall.args || {}
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) {
logger.warn(`Tool ${toolName} not found in registry, skipping`)
break
}
const toolCallStartTime = Date.now()
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
timeSegments.push({
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
})
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
tool: toolName,
}
}
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
result: resultContent,
success: result.success,
})
const simplifiedMessages = [
...(contents.filter((m) => m.role === 'user').length > 0
? [contents.filter((m) => m.role === 'user')[0]]
: [contents[0]]),
{
role: 'model',
parts: [
{
functionCall: {
name: latestFunctionCall.name,
args: latestFunctionCall.args,
},
},
],
},
{
role: 'user',
parts: [
{
text: `Function ${latestFunctionCall.name} result: ${JSON.stringify(resultContent)}`,
},
],
},
]
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
checkForForcedToolUsage(latestFunctionCall)
const nextModelStartTime = Date.now()
try {
if (request.stream) {
const streamingPayload = {
...payload,
contents: simplifiedMessages,
}
const allForcedToolsUsed =
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
if (allForcedToolsUsed && request.responseFormat) {
streamingPayload.tools = undefined
streamingPayload.toolConfig = undefined
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!streamingPayload.generationConfig) {
streamingPayload.generationConfig = {}
}
streamingPayload.generationConfig.responseMimeType = 'application/json'
streamingPayload.generationConfig.responseSchema = cleanSchema
logger.info('Using structured output for final response after tool execution')
} else {
if (currentToolConfig) {
streamingPayload.toolConfig = currentToolConfig
} else {
streamingPayload.toolConfig = { functionCallingConfig: { mode: 'AUTO' } }
}
}
const checkPayload = {
...streamingPayload,
}
checkPayload.stream = undefined
const checkEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
false
)
const checkResponse = await fetch(checkEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(checkPayload),
})
if (!checkResponse.ok) {
const errorBody = await checkResponse.text()
logger.error('Error in Vertex AI check request:', {
status: checkResponse.status,
statusText: checkResponse.statusText,
responseBody: errorBody,
})
throw new Error(
`Vertex AI API check error: ${checkResponse.status} ${checkResponse.statusText}`
)
}
const checkResult = await checkResponse.json()
const checkCandidate = checkResult.candidates?.[0]
const checkFunctionCall = extractFunctionCall(checkCandidate)
if (checkFunctionCall) {
logger.info(
'Function call detected in follow-up, handling in non-streaming mode',
{
functionName: checkFunctionCall.name,
}
)
geminiResponse = checkResult
if (checkResult.usageMetadata) {
tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0
tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0
tokens.total +=
(checkResult.usageMetadata.promptTokenCount || 0) +
(checkResult.usageMetadata.candidatesTokenCount || 0)
}
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
modelTime += thisModelTime
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})
iterationCount++
continue
}
logger.info('No function call detected, proceeding with streaming response')
if (request.responseFormat) {
streamingPayload.tools = undefined
streamingPayload.toolConfig = undefined
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!streamingPayload.generationConfig) {
streamingPayload.generationConfig = {}
}
streamingPayload.generationConfig.responseMimeType = 'application/json'
streamingPayload.generationConfig.responseSchema = cleanSchema
logger.info(
'Using structured output for final streaming response after tool execution'
)
}
const streamEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
true
)
const streamingResponse = await fetch(streamEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(streamingPayload),
})
if (!streamingResponse.ok) {
const errorBody = await streamingResponse.text()
logger.error('Error in Vertex AI streaming follow-up request:', {
status: streamingResponse.status,
statusText: streamingResponse.statusText,
responseBody: errorBody,
})
throw new Error(
`Vertex AI API streaming error: ${streamingResponse.status} ${streamingResponse.statusText}`
)
}
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
modelTime += thisModelTime
timeSegments.push({
type: 'model',
name: 'Final streaming response after tool calls',
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})
const streamingExecution: StreamingExecution = {
stream: null as any,
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens,
toolCalls:
toolCalls.length > 0
? {
list: toolCalls,
count: toolCalls.length,
}
: undefined,
toolResults,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
modelTime,
toolsTime,
firstResponseTime,
iterations: iterationCount + 1,
timeSegments,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
isStreaming: true,
},
}
streamingExecution.stream = createReadableStreamFromVertexStream(
streamingResponse,
(content, usage) => {
streamingExecution.execution.output.content = content
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
if (streamingExecution.execution.output.providerTiming) {
streamingExecution.execution.output.providerTiming.endTime =
streamEndTimeISO
streamingExecution.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
}
const promptTokens = usage?.promptTokenCount || 0
const completionTokens = usage?.candidatesTokenCount || 0
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
const existingTokens = streamingExecution.execution.output.tokens || {
prompt: 0,
completion: 0,
total: 0,
}
const existingPrompt = existingTokens.prompt || 0
const existingCompletion = existingTokens.completion || 0
const existingTotal = existingTokens.total || 0
streamingExecution.execution.output.tokens = {
prompt: existingPrompt + promptTokens,
completion: existingCompletion + completionTokens,
total: existingTotal + totalTokens,
}
const accumulatedCost = calculateCost(
request.model,
existingPrompt,
existingCompletion
)
const streamCost = calculateCost(
request.model,
promptTokens,
completionTokens
)
streamingExecution.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}
)
return streamingExecution
}
const nextPayload = {
...payload,
contents: simplifiedMessages,
}
const allForcedToolsUsed =
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
if (allForcedToolsUsed && request.responseFormat) {
nextPayload.tools = undefined
nextPayload.toolConfig = undefined
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!nextPayload.generationConfig) {
nextPayload.generationConfig = {}
}
nextPayload.generationConfig.responseMimeType = 'application/json'
nextPayload.generationConfig.responseSchema = cleanSchema
logger.info(
'Using structured output for final non-streaming response after tool execution'
)
} else {
if (currentToolConfig) {
nextPayload.toolConfig = currentToolConfig
}
}
const nextEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
false
)
const nextResponse = await fetch(nextEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(nextPayload),
})
if (!nextResponse.ok) {
const errorBody = await nextResponse.text()
logger.error('Error in Vertex AI follow-up request:', {
status: nextResponse.status,
statusText: nextResponse.statusText,
responseBody: errorBody,
iterationCount,
})
break
}
geminiResponse = await nextResponse.json()
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})
modelTime += thisModelTime
const nextCandidate = geminiResponse.candidates?.[0]
const nextFunctionCall = extractFunctionCall(nextCandidate)
if (!nextFunctionCall) {
if (request.responseFormat) {
const finalPayload = {
...payload,
contents: nextPayload.contents,
tools: undefined,
toolConfig: undefined,
}
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!finalPayload.generationConfig) {
finalPayload.generationConfig = {}
}
finalPayload.generationConfig.responseMimeType = 'application/json'
finalPayload.generationConfig.responseSchema = cleanSchema
logger.info('Making final request with structured output after tool execution')
const finalEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
false
)
const finalResponse = await fetch(finalEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(finalPayload),
})
if (finalResponse.ok) {
const finalResult = await finalResponse.json()
const finalCandidate = finalResult.candidates?.[0]
content = extractTextContent(finalCandidate)
if (finalResult.usageMetadata) {
tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0
tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0
tokens.total +=
(finalResult.usageMetadata.promptTokenCount || 0) +
(finalResult.usageMetadata.candidatesTokenCount || 0)
}
} else {
logger.warn(
'Failed to get structured output, falling back to regular response'
)
content = extractTextContent(nextCandidate)
}
} else {
content = extractTextContent(nextCandidate)
}
break
}
iterationCount++
} catch (error) {
logger.error('Error in Vertex AI follow-up request:', {
error: error instanceof Error ? error.message : String(error),
iterationCount,
})
break
}
} catch (error) {
logger.error('Error processing function call:', {
error: error instanceof Error ? error.message : String(error),
functionName: latestFunctionCall?.name || 'unknown',
})
break
}
}
} else {
content = extractTextContent(candidate)
}
} catch (error) {
logger.error('Error processing Vertex AI response:', {
error: error instanceof Error ? error.message : String(error),
iterationCount,
})
if (!content && toolCalls.length > 0) {
content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.`
}
}
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
if (geminiResponse.usageMetadata) {
tokens = {
prompt: geminiResponse.usageMetadata.promptTokenCount || 0,
completion: geminiResponse.usageMetadata.candidatesTokenCount || 0,
total:
(geminiResponse.usageMetadata.promptTokenCount || 0) +
(geminiResponse.usageMetadata.candidatesTokenCount || 0),
}
}
return {
content,
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
}
} catch (error) {
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
logger.error('Error in Vertex AI request:', {
error: error instanceof Error ? error.message : String(error),
duration: totalDuration,
})
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
}
throw enhancedError
}
// Create an OAuth2Client and set the access token
// This allows us to use an OAuth access token with the SDK
const authClient = new OAuth2Client()
authClient.setCredentials({ access_token: request.apiKey })
// Create client with Vertex AI configuration
const ai = new GoogleGenAI({
vertexai: true,
project: vertexProject,
location: vertexLocation,
googleAuthOptions: {
authClient,
},
})
return executeGeminiRequest({
ai,
model,
request,
providerType: 'vertex',
})
},
}
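
A standalone sketch of the SDK wiring used above. The project, location, token, and model values are placeholders, and the generateContent call and text accessor follow the published @google/genai API rather than anything specific to this repo:

import { GoogleGenAI } from '@google/genai'
import { OAuth2Client } from 'google-auth-library'

// Wrap an OAuth access token in an OAuth2Client so the SDK can authenticate against Vertex AI.
const authClient = new OAuth2Client()
authClient.setCredentials({ access_token: '<oauth-access-token>' })

const ai = new GoogleGenAI({
  vertexai: true,
  project: 'my-gcp-project',
  location: 'us-central1',
  googleAuthOptions: { authClient },
})

const result = await ai.models.generateContent({
  model: 'gemini-2.5-pro',
  contents: 'Say hello from Vertex AI',
})
console.log(result.text)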

View File

@@ -1,231 +0,0 @@
import { createLogger } from '@/lib/logs/console/logger'
import { extractFunctionCall, extractTextContent } from '@/providers/google/utils'
const logger = createLogger('VertexUtils')
/**
* Creates a ReadableStream from Vertex AI's Gemini stream response
*/
export function createReadableStreamFromVertexStream(
response: Response,
onComplete?: (
content: string,
usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
) => void
): ReadableStream<Uint8Array> {
const reader = response.body?.getReader()
if (!reader) {
throw new Error('Failed to get reader from response body')
}
return new ReadableStream({
async start(controller) {
try {
let buffer = ''
let fullContent = ''
let usageData: {
promptTokenCount?: number
candidatesTokenCount?: number
totalTokenCount?: number
} | null = null
while (true) {
const { done, value } = await reader.read()
if (done) {
if (buffer.trim()) {
try {
const data = JSON.parse(buffer.trim())
if (data.usageMetadata) {
usageData = data.usageMetadata
}
const candidate = data.candidates?.[0]
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in final buffer, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
} catch (e) {
if (buffer.trim().startsWith('[')) {
try {
const dataArray = JSON.parse(buffer.trim())
if (Array.isArray(dataArray)) {
for (const item of dataArray) {
if (item.usageMetadata) {
usageData = item.usageMetadata
}
const candidate = item.candidates?.[0]
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in array item, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
}
}
} catch (arrayError) {}
}
}
}
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
break
}
const text = new TextDecoder().decode(value)
buffer += text
let searchIndex = 0
while (searchIndex < buffer.length) {
const openBrace = buffer.indexOf('{', searchIndex)
if (openBrace === -1) break
let braceCount = 0
let inString = false
let escaped = false
let closeBrace = -1
for (let i = openBrace; i < buffer.length; i++) {
const char = buffer[i]
if (!inString) {
if (char === '"' && !escaped) {
inString = true
} else if (char === '{') {
braceCount++
} else if (char === '}') {
braceCount--
if (braceCount === 0) {
closeBrace = i
break
}
}
} else {
if (char === '"' && !escaped) {
inString = false
}
}
escaped = char === '\\' && !escaped
}
if (closeBrace !== -1) {
const jsonStr = buffer.substring(openBrace, closeBrace + 1)
try {
const data = JSON.parse(jsonStr)
if (data.usageMetadata) {
usageData = data.usageMetadata
}
const candidate = data.candidates?.[0]
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn(
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
{
finishReason: candidate.finishReason,
hasContent: !!candidate?.content,
hasParts: !!candidate?.content?.parts,
}
)
const textContent = extractTextContent(candidate)
if (textContent) {
fullContent += textContent
controller.enqueue(new TextEncoder().encode(textContent))
}
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in stream, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
} catch (e) {
logger.error('Error parsing JSON from stream', {
error: e instanceof Error ? e.message : String(e),
jsonPreview: jsonStr.substring(0, 200),
})
}
buffer = buffer.substring(closeBrace + 1)
searchIndex = 0
} else {
break
}
}
}
} catch (e) {
logger.error('Error reading Vertex AI stream', {
error: e instanceof Error ? e.message : String(e),
})
controller.error(e)
}
},
async cancel() {
await reader.cancel()
},
})
}
/**
* Build Vertex AI endpoint URL
*/
export function buildVertexEndpoint(
project: string,
location: string,
model: string,
isStreaming: boolean
): string {
const action = isStreaming ? 'streamGenerateContent' : 'generateContent'
if (location === 'global') {
return `https://aiplatform.googleapis.com/v1/projects/${project}/locations/global/publishers/google/models/${model}:${action}`
}
return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/google/models/${model}:${action}`
}
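
For reference, worked examples of the URLs the removed helper produced (placeholder project, location, and model); endpoint construction is now handled inside the @google/genai SDK:

// Regional endpoint, non-streaming:
buildVertexEndpoint('my-project', 'us-central1', 'gemini-2.5-pro', false)
// => 'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.5-pro:generateContent'

// Global endpoint, streaming:
buildVertexEndpoint('my-project', 'global', 'gemini-2.5-pro', true)
// => 'https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-2.5-pro:streamGenerateContent'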

View File

@@ -130,7 +130,7 @@ export const vllmProvider: ProviderConfig = {
: undefined
const payload: any = {
model: (request.model || getProviderDefaultModel('vllm')).replace(/^vllm\//, ''),
model: request.model.replace(/^vllm\//, ''),
messages: allMessages,
}

View File

@@ -48,7 +48,7 @@ export const xAIProvider: ProviderConfig = {
hasTools: !!request.tools?.length,
toolCount: request.tools?.length || 0,
hasResponseFormat: !!request.responseFormat,
model: request.model || 'grok-3-latest',
model: request.model,
streaming: !!request.stream,
})
@@ -87,7 +87,7 @@ export const xAIProvider: ProviderConfig = {
)
}
const basePayload: any = {
model: request.model || 'grok-3-latest',
model: request.model,
messages: allMessages,
}
@@ -139,7 +139,7 @@ export const xAIProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'grok-3-latest',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
@@ -505,7 +505,7 @@ export const xAIProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'grok-3-latest',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,