Compare commits


13 Commits

Author SHA1 Message Date
Vikhyath Mondreti
e12dd204ed v0.5.41: memory fixes, copilot improvements, knowledgebase improvements, LLM providers standardization 2025-12-23 00:15:18 -08:00
Vikhyath Mondreti
621f9a40c7 improvement(landing): free usage limit (#2547) 2025-12-23 00:07:13 -08:00
Siddharth Ganesan
3100daa346 feat(copilot): add tools to access block outputs and upstream references (#2546)
* Add copilot references tools

* Minor fixes

* Omit vars field in block outputs when id is provided
2025-12-23 00:06:24 -08:00
Priyanshu Solanki
c252e885af improvement(logs): fixed logs for parallel and loop execution flow (#2468)
* fixed logs for parallel and loop execution flow

* Fix array check for collection

* fixed for empty loop and parallel blocks and showing input on dashboard

* extracted utility functions

* fixed the referencing errors and making sure it propagates to the console

* fix parallel

* fix tests

---------

Co-authored-by: priyanshu.solanki <priyanshu.solanki@saviynt.com>
Co-authored-by: Siddharth Ganesan <siddharthganesan@gmail.com>
Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
2025-12-23 00:02:02 -08:00
Waleed
b0748c82f9 fix(search): removed full text param from built-in search, anthropic provider streaming fix (#2542)
* fix(search): removed full text param from built-in search, anthropic provider streaming fix

* rewrite gemini provider with official sdk + add thinking capability

* vertex gemini consolidation

* never silently use a different model

* pass oauth client through the googleAuthOptions param directly

* make server side provider registry

* remove comments

* take oauth selector below model selector

---------

Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
2025-12-22 23:57:11 -08:00
Waleed
f5245f3eca fix(billing): add line items for wand (#2543)
* fix(billing): add line items for wand

* ack pr comment
2025-12-22 23:06:14 -08:00
Vikhyath Mondreti
3d9d9cbc54 v0.5.40: supabase ops to allow non-public schemas, jira uuid 2025-12-21 22:28:05 -08:00
Waleed
0f4ec962ad v0.5.39: notion, workflow variables fixes 2025-12-20 20:44:00 -08:00
Waleed
4827866f9a v0.5.38: snap to grid, copilot ux improvements, billing line items 2025-12-20 17:24:38 -08:00
Waleed
3e697d9ed9 v0.5.37: redaction utils consolidation, logs updates, autoconnect improvements, additional kb tag types 2025-12-19 22:31:55 -08:00
Martin Yankov
4431a1a484 fix(helm): add custom egress rules to realtime network policy (#2481)
The realtime service network policy was missing the custom egress rules section
that allows configuration of additional egress rules via values.yaml. This caused
the realtime pods to be unable to connect to external databases (e.g., PostgreSQL
on port 5432) when using external database configurations.

The app network policy already had this section, but the realtime network policy
was missing it, creating an inconsistency and preventing the realtime service
from accessing external databases configured via networkPolicy.egress values.

This fix adds the same custom egress rules template section to the realtime
network policy, matching the app network policy behavior and allowing users to
configure database connectivity via values.yaml.
2025-12-19 18:59:08 -08:00
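For context on the fix(helm) change above, here is a minimal values.yaml sketch of the kind of custom egress rule the realtime network policy now honors. The networkPolicy.egress key comes from the commit message; the CIDR and port values are illustrative assumptions, not the chart's defaults.

networkPolicy:
  enabled: true
  egress:
    # Allow realtime pods to reach an external PostgreSQL instance.
    # The CIDR below is a placeholder; substitute your database's address range.
    - to:
        - ipBlock:
            cidr: 10.0.0.0/8
      ports:
        - protocol: TCP
          port: 5432

With the fix, this section is rendered into the realtime network policy as well, matching the existing behavior of the app network policy.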
Waleed
4d1a9a3f22 v0.5.36: hitl improvements, opengraph, slack fixes, one-click unsubscribe, auth checks, new db indexes 2025-12-19 01:27:49 -08:00
Vikhyath Mondreti
eb07a080fb v0.5.35: helm updates, copilot improvements, 404 for docs, salesforce fixes, subflow resize clamping 2025-12-18 16:23:19 -08:00
50 changed files with 2663 additions and 2576 deletions

View File

@@ -41,7 +41,7 @@ interface PricingTier {
* Free plan features with consistent icons
*/
const FREE_PLAN_FEATURES: PricingFeature[] = [
{ icon: DollarSign, text: '$10 usage limit' },
{ icon: DollarSign, text: '$20 usage limit' },
{ icon: HardDrive, text: '5GB file storage' },
{ icon: Workflow, text: 'Public template access' },
{ icon: Database, text: 'Limited log retention' },

View File

@@ -56,7 +56,7 @@ export async function POST(request: NextRequest) {
query: validated.query,
type: 'auto',
useAutoprompt: true,
text: true,
highlights: true,
apiKey: env.EXA_API_KEY,
})
@@ -77,7 +77,7 @@ export async function POST(request: NextRequest) {
const results = (result.output.results || []).map((r: any, index: number) => ({
title: r.title || '',
link: r.url || '',
snippet: r.text || '',
snippet: Array.isArray(r.highlights) ? r.highlights.join(' ... ') : '',
date: r.publishedDate || undefined,
position: index + 1,
}))

View File

@@ -69,7 +69,7 @@ function safeStringify(value: unknown): string {
}
async function updateUserStatsForWand(
workflowId: string,
userId: string,
usage: {
prompt_tokens?: number
completion_tokens?: number
@@ -88,21 +88,6 @@ async function updateUserStatsForWand(
}
try {
const [workflowRecord] = await db
.select({ userId: workflow.userId, workspaceId: workflow.workspaceId })
.from(workflow)
.where(eq(workflow.id, workflowId))
.limit(1)
if (!workflowRecord?.userId) {
logger.warn(
`[${requestId}] No user found for workflow ${workflowId}, cannot update user stats`
)
return
}
const userId = workflowRecord.userId
const workspaceId = workflowRecord.workspaceId
const totalTokens = usage.total_tokens || 0
const promptTokens = usage.prompt_tokens || 0
const completionTokens = usage.completion_tokens || 0
@@ -146,8 +131,6 @@ async function updateUserStatsForWand(
inputTokens: promptTokens,
outputTokens: completionTokens,
cost: costToStore,
workspaceId: workspaceId ?? undefined,
workflowId,
})
await checkAndBillOverageThreshold(userId)
@@ -325,6 +308,11 @@ export async function POST(req: NextRequest) {
if (data === '[DONE]') {
logger.info(`[${requestId}] Received [DONE] signal`)
if (finalUsage) {
await updateUserStatsForWand(session.user.id, finalUsage, requestId)
}
controller.enqueue(
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
)
@@ -353,10 +341,6 @@ export async function POST(req: NextRequest) {
`[${requestId}] Received usage data: ${JSON.stringify(parsed.usage)}`
)
}
if (chunkCount % 10 === 0) {
logger.debug(`[${requestId}] Processed ${chunkCount} chunks`)
}
} catch (parseError) {
logger.debug(
`[${requestId}] Skipped non-JSON line: ${data.substring(0, 100)}`
@@ -365,12 +349,6 @@ export async function POST(req: NextRequest) {
}
}
}
logger.info(`[${requestId}] Wand generation streaming completed successfully`)
if (finalUsage && workflowId) {
await updateUserStatsForWand(workflowId, finalUsage, requestId)
}
} catch (streamError: any) {
logger.error(`[${requestId}] Streaming error`, {
name: streamError?.name,
@@ -438,8 +416,8 @@ export async function POST(req: NextRequest) {
logger.info(`[${requestId}] Wand generation successful`)
if (completion.usage && workflowId) {
await updateUserStatsForWand(workflowId, completion.usage, requestId)
if (completion.usage) {
await updateUserStatsForWand(session.user.id, completion.usage, requestId)
}
return NextResponse.json({ success: true, content: generatedContent })

View File

@@ -43,6 +43,8 @@ const SCOPE_DESCRIPTIONS: Record<string, string> = {
'https://www.googleapis.com/auth/admin.directory.group.readonly': 'View Google Workspace groups',
'https://www.googleapis.com/auth/admin.directory.group.member.readonly':
'View Google Workspace group memberships',
'https://www.googleapis.com/auth/cloud-platform':
'Full access to Google Cloud resources for Vertex AI',
'read:confluence-content.all': 'Read all Confluence content',
'read:confluence-space.summary': 'Read Confluence space information',
'read:space:confluence': 'View Confluence spaces',

View File

@@ -9,8 +9,10 @@ import {
getMaxTemperature,
getProviderIcon,
getReasoningEffortValuesForModel,
getThinkingLevelsForModel,
getVerbosityValuesForModel,
MODELS_WITH_REASONING_EFFORT,
MODELS_WITH_THINKING,
MODELS_WITH_VERBOSITY,
providers,
supportsTemperature,
@@ -108,7 +110,19 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
})
},
},
{
id: 'vertexCredential',
title: 'Google Cloud Account',
type: 'oauth-input',
serviceId: 'vertex-ai',
requiredScopes: ['https://www.googleapis.com/auth/cloud-platform'],
placeholder: 'Select Google Cloud account',
required: true,
condition: {
field: 'model',
value: providers.vertex.models,
},
},
{
id: 'reasoningEffort',
title: 'Reasoning Effort',
@@ -215,6 +229,57 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
value: MODELS_WITH_VERBOSITY,
},
},
{
id: 'thinkingLevel',
title: 'Thinking Level',
type: 'dropdown',
placeholder: 'Select thinking level...',
options: [
{ label: 'minimal', id: 'minimal' },
{ label: 'low', id: 'low' },
{ label: 'medium', id: 'medium' },
{ label: 'high', id: 'high' },
],
dependsOn: ['model'],
fetchOptions: async (blockId: string) => {
const { useSubBlockStore } = await import('@/stores/workflows/subblock/store')
const { useWorkflowRegistry } = await import('@/stores/workflows/registry/store')
const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId
if (!activeWorkflowId) {
return [
{ label: 'low', id: 'low' },
{ label: 'high', id: 'high' },
]
}
const workflowValues = useSubBlockStore.getState().workflowValues[activeWorkflowId]
const blockValues = workflowValues?.[blockId]
const modelValue = blockValues?.model as string
if (!modelValue) {
return [
{ label: 'low', id: 'low' },
{ label: 'high', id: 'high' },
]
}
const validOptions = getThinkingLevelsForModel(modelValue)
if (!validOptions) {
return [
{ label: 'low', id: 'low' },
{ label: 'high', id: 'high' },
]
}
return validOptions.map((opt) => ({ label: opt, id: opt }))
},
value: () => 'high',
condition: {
field: 'model',
value: MODELS_WITH_THINKING,
},
},
{
id: 'azureEndpoint',
@@ -275,17 +340,21 @@ export const AgentBlock: BlockConfig<AgentResponse> = {
password: true,
connectionDroppable: false,
required: true,
// Hide API key for hosted models, Ollama models, and vLLM models
// Hide API key for hosted models, Ollama models, vLLM models, and Vertex models (uses OAuth)
condition: isHosted
? {
field: 'model',
value: getHostedModels(),
value: [...getHostedModels(), ...providers.vertex.models],
not: true, // Show for all models EXCEPT those listed
}
: () => ({
field: 'model',
value: [...getCurrentOllamaModels(), ...getCurrentVLLMModels()],
not: true, // Show for all models EXCEPT Ollama and vLLM models
value: [
...getCurrentOllamaModels(),
...getCurrentVLLMModels(),
...providers.vertex.models,
],
not: true, // Show for all models EXCEPT Ollama, vLLM, and Vertex models
}),
},
{
@@ -609,6 +678,7 @@ Example 3 (Array Input):
temperature: { type: 'number', description: 'Response randomness level' },
reasoningEffort: { type: 'string', description: 'Reasoning effort level for GPT-5 models' },
verbosity: { type: 'string', description: 'Verbosity level for GPT-5 models' },
thinkingLevel: { type: 'string', description: 'Thinking level for Gemini 3 models' },
tools: { type: 'json', description: 'Available tools configuration' },
},
outputs: {

View File

@@ -128,6 +128,8 @@ export const DEFAULTS = {
BLOCK_TITLE: 'Untitled Block',
WORKFLOW_NAME: 'Workflow',
MAX_LOOP_ITERATIONS: 1000,
MAX_FOREACH_ITEMS: 1000,
MAX_PARALLEL_BRANCHES: 20,
MAX_WORKFLOW_DEPTH: 10,
EXECUTION_TIME: 0,
TOKENS: {

View File

@@ -4,6 +4,7 @@ import { LoopConstructor } from '@/executor/dag/construction/loops'
import { NodeConstructor } from '@/executor/dag/construction/nodes'
import { PathConstructor } from '@/executor/dag/construction/paths'
import type { DAGEdge, NodeMetadata } from '@/executor/dag/types'
import { buildSentinelStartId, extractBaseBlockId } from '@/executor/utils/subflow-utils'
import type {
SerializedBlock,
SerializedLoop,
@@ -79,6 +80,9 @@ export class DAGBuilder {
}
}
// Validate loop and parallel structure
this.validateSubflowStructure(dag)
logger.info('DAG built', {
totalNodes: dag.nodes.size,
loopCount: dag.loopConfigs.size,
@@ -105,4 +109,43 @@ export class DAGBuilder {
}
}
}
/**
* Validates that loops and parallels have proper internal structure.
* Throws an error if a loop/parallel has no blocks inside or no connections from start.
*/
private validateSubflowStructure(dag: DAG): void {
for (const [id, config] of dag.loopConfigs) {
this.validateSubflow(dag, id, config.nodes, 'Loop')
}
for (const [id, config] of dag.parallelConfigs) {
this.validateSubflow(dag, id, config.nodes, 'Parallel')
}
}
private validateSubflow(
dag: DAG,
id: string,
nodes: string[] | undefined,
type: 'Loop' | 'Parallel'
): void {
if (!nodes || nodes.length === 0) {
throw new Error(
`${type} has no blocks inside. Add at least one block to the ${type.toLowerCase()}.`
)
}
const sentinelStartNode = dag.nodes.get(buildSentinelStartId(id))
if (!sentinelStartNode) return
const hasConnections = Array.from(sentinelStartNode.outgoingEdges.values()).some((edge) =>
nodes.includes(extractBaseBlockId(edge.target))
)
if (!hasConnections) {
throw new Error(
`${type} start is not connected to any blocks. Connect a block to the ${type.toLowerCase()} start.`
)
}
}
}

View File

@@ -63,8 +63,10 @@ export class DAGExecutor {
const resolver = new VariableResolver(this.workflow, this.workflowVariables, state)
const loopOrchestrator = new LoopOrchestrator(dag, state, resolver)
loopOrchestrator.setContextExtensions(this.contextExtensions)
const parallelOrchestrator = new ParallelOrchestrator(dag, state)
parallelOrchestrator.setResolver(resolver)
parallelOrchestrator.setContextExtensions(this.contextExtensions)
const allHandlers = createBlockHandlers()
const blockExecutor = new BlockExecutor(allHandlers, resolver, this.contextExtensions, state)
const edgeManager = new EdgeManager(dag)

View File

@@ -14,6 +14,8 @@ export interface LoopScope {
condition?: string
loopType?: 'for' | 'forEach' | 'while' | 'doWhile'
skipFirstConditionCheck?: boolean
/** Error message if loop validation failed (e.g., exceeded max iterations) */
validationError?: string
}
export interface ParallelScope {
@@ -23,6 +25,8 @@ export interface ParallelScope {
completedCount: number
totalExpectedNodes: number
items?: any[]
/** Error message if parallel validation failed (e.g., exceeded max branches) */
validationError?: string
}
export class ExecutionState implements BlockStateController {

View File

@@ -1,8 +1,9 @@
import { db } from '@sim/db'
import { mcpServers } from '@sim/db/schema'
import { account, mcpServers } from '@sim/db/schema'
import { and, eq, inArray, isNull } from 'drizzle-orm'
import { createLogger } from '@/lib/logs/console/logger'
import { createMcpToolId } from '@/lib/mcp/utils'
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
import { getAllBlocks } from '@/blocks'
import type { BlockOutput } from '@/blocks/types'
import { AGENT, BlockType, DEFAULTS, HTTP } from '@/executor/constants'
@@ -919,6 +920,7 @@ export class AgentBlockHandler implements BlockHandler {
azureApiVersion: inputs.azureApiVersion,
vertexProject: inputs.vertexProject,
vertexLocation: inputs.vertexLocation,
vertexCredential: inputs.vertexCredential,
responseFormat,
workflowId: ctx.workflowId,
workspaceId: ctx.workspaceId,
@@ -997,7 +999,17 @@ export class AgentBlockHandler implements BlockHandler {
responseFormat: any,
providerStartTime: number
) {
const finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey)
let finalApiKey: string
// For Vertex AI, resolve OAuth credential to access token
if (providerId === 'vertex' && providerRequest.vertexCredential) {
finalApiKey = await this.resolveVertexCredential(
providerRequest.vertexCredential,
ctx.workflowId
)
} else {
finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey)
}
const { blockData, blockNameMapping } = collectBlockData(ctx)
@@ -1024,7 +1036,6 @@ export class AgentBlockHandler implements BlockHandler {
blockNameMapping,
})
this.logExecutionSuccess(providerId, model, ctx, block, providerStartTime, response)
return this.processProviderResponse(response, block, responseFormat)
}
@@ -1049,15 +1060,6 @@ export class AgentBlockHandler implements BlockHandler {
throw new Error(errorMessage)
}
this.logExecutionSuccess(
providerRequest.provider,
providerRequest.model,
ctx,
block,
providerStartTime,
'HTTP response'
)
const contentType = response.headers.get('Content-Type')
if (contentType?.includes(HTTP.CONTENT_TYPE.EVENT_STREAM)) {
return this.handleStreamingResponse(response, block, ctx, inputs)
@@ -1117,21 +1119,33 @@ export class AgentBlockHandler implements BlockHandler {
}
}
private logExecutionSuccess(
provider: string,
model: string,
ctx: ExecutionContext,
block: SerializedBlock,
startTime: number,
response: any
) {
const executionTime = Date.now() - startTime
const responseType =
response instanceof ReadableStream
? 'stream'
: response && typeof response === 'object' && 'stream' in response
? 'streaming-execution'
: 'json'
/**
* Resolves a Vertex AI OAuth credential to an access token
*/
private async resolveVertexCredential(credentialId: string, workflowId: string): Promise<string> {
const requestId = `vertex-${Date.now()}`
logger.info(`[${requestId}] Resolving Vertex AI credential: ${credentialId}`)
// Get the credential - we need to find the owner
// Since we're in a workflow context, we can query the credential directly
const credential = await db.query.account.findFirst({
where: eq(account.id, credentialId),
})
if (!credential) {
throw new Error(`Vertex AI credential not found: ${credentialId}`)
}
// Refresh the token if needed
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
if (!accessToken) {
throw new Error('Failed to get Vertex AI access token')
}
logger.info(`[${requestId}] Successfully resolved Vertex AI credential`)
return accessToken
}
private handleExecutionError(

View File

@@ -21,6 +21,7 @@ export interface AgentInputs {
azureApiVersion?: string
vertexProject?: string
vertexLocation?: string
vertexCredential?: string
reasoningEffort?: string
verbosity?: string
}

View File

@@ -5,14 +5,17 @@ import { buildLoopIndexCondition, DEFAULTS, EDGE } from '@/executor/constants'
import type { DAG } from '@/executor/dag/builder'
import type { EdgeManager } from '@/executor/execution/edge-manager'
import type { LoopScope } from '@/executor/execution/state'
import type { BlockStateController } from '@/executor/execution/types'
import type { BlockStateController, ContextExtensions } from '@/executor/execution/types'
import type { ExecutionContext, NormalizedBlockOutput } from '@/executor/types'
import type { LoopConfigWithNodes } from '@/executor/types/loop'
import { replaceValidReferences } from '@/executor/utils/reference-validation'
import {
addSubflowErrorLog,
buildSentinelEndId,
buildSentinelStartId,
extractBaseBlockId,
resolveArrayInput,
validateMaxCount,
} from '@/executor/utils/subflow-utils'
import type { VariableResolver } from '@/executor/variables/resolver'
import type { SerializedLoop } from '@/serializer/types'
@@ -32,6 +35,7 @@ export interface LoopContinuationResult {
export class LoopOrchestrator {
private edgeManager: EdgeManager | null = null
private contextExtensions: ContextExtensions | null = null
constructor(
private dag: DAG,
@@ -39,6 +43,10 @@ export class LoopOrchestrator {
private resolver: VariableResolver
) {}
setContextExtensions(contextExtensions: ContextExtensions): void {
this.contextExtensions = contextExtensions
}
setEdgeManager(edgeManager: EdgeManager): void {
this.edgeManager = edgeManager
}
@@ -48,7 +56,6 @@ export class LoopOrchestrator {
if (!loopConfig) {
throw new Error(`Loop config not found: ${loopId}`)
}
const scope: LoopScope = {
iteration: 0,
currentIterationOutputs: new Map(),
@@ -58,15 +65,70 @@ export class LoopOrchestrator {
const loopType = loopConfig.loopType
switch (loopType) {
case 'for':
case 'for': {
scope.loopType = 'for'
scope.maxIterations = loopConfig.iterations || DEFAULTS.MAX_LOOP_ITERATIONS
const requestedIterations = loopConfig.iterations || DEFAULTS.MAX_LOOP_ITERATIONS
const iterationError = validateMaxCount(
requestedIterations,
DEFAULTS.MAX_LOOP_ITERATIONS,
'For loop iterations'
)
if (iterationError) {
logger.error(iterationError, { loopId, requestedIterations })
this.addLoopErrorLog(ctx, loopId, loopType, iterationError, {
iterations: requestedIterations,
})
scope.maxIterations = 0
scope.validationError = iterationError
scope.condition = buildLoopIndexCondition(0)
ctx.loopExecutions?.set(loopId, scope)
throw new Error(iterationError)
}
scope.maxIterations = requestedIterations
scope.condition = buildLoopIndexCondition(scope.maxIterations)
break
}
case 'forEach': {
scope.loopType = 'forEach'
const items = this.resolveForEachItems(ctx, loopConfig.forEachItems)
let items: any[]
try {
items = this.resolveForEachItems(ctx, loopConfig.forEachItems)
} catch (error) {
const errorMessage = `ForEach loop resolution failed: ${error instanceof Error ? error.message : String(error)}`
logger.error(errorMessage, { loopId, forEachItems: loopConfig.forEachItems })
this.addLoopErrorLog(ctx, loopId, loopType, errorMessage, {
forEachItems: loopConfig.forEachItems,
})
scope.items = []
scope.maxIterations = 0
scope.validationError = errorMessage
scope.condition = buildLoopIndexCondition(0)
ctx.loopExecutions?.set(loopId, scope)
throw new Error(errorMessage)
}
const sizeError = validateMaxCount(
items.length,
DEFAULTS.MAX_FOREACH_ITEMS,
'ForEach loop collection size'
)
if (sizeError) {
logger.error(sizeError, { loopId, collectionSize: items.length })
this.addLoopErrorLog(ctx, loopId, loopType, sizeError, {
forEachItems: loopConfig.forEachItems,
collectionSize: items.length,
})
scope.items = []
scope.maxIterations = 0
scope.validationError = sizeError
scope.condition = buildLoopIndexCondition(0)
ctx.loopExecutions?.set(loopId, scope)
throw new Error(sizeError)
}
scope.items = items
scope.maxIterations = items.length
scope.item = items[0]
@@ -79,15 +141,35 @@ export class LoopOrchestrator {
scope.condition = loopConfig.whileCondition
break
case 'doWhile':
case 'doWhile': {
scope.loopType = 'doWhile'
if (loopConfig.doWhileCondition) {
scope.condition = loopConfig.doWhileCondition
} else {
scope.maxIterations = loopConfig.iterations || DEFAULTS.MAX_LOOP_ITERATIONS
const requestedIterations = loopConfig.iterations || DEFAULTS.MAX_LOOP_ITERATIONS
const iterationError = validateMaxCount(
requestedIterations,
DEFAULTS.MAX_LOOP_ITERATIONS,
'Do-While loop iterations'
)
if (iterationError) {
logger.error(iterationError, { loopId, requestedIterations })
this.addLoopErrorLog(ctx, loopId, loopType, iterationError, {
iterations: requestedIterations,
})
scope.maxIterations = 0
scope.validationError = iterationError
scope.condition = buildLoopIndexCondition(0)
ctx.loopExecutions?.set(loopId, scope)
throw new Error(iterationError)
}
scope.maxIterations = requestedIterations
scope.condition = buildLoopIndexCondition(scope.maxIterations)
}
break
}
default:
throw new Error(`Unknown loop type: ${loopType}`)
@@ -100,6 +182,23 @@ export class LoopOrchestrator {
return scope
}
private addLoopErrorLog(
ctx: ExecutionContext,
loopId: string,
loopType: string,
errorMessage: string,
inputData?: any
): void {
addSubflowErrorLog(
ctx,
loopId,
'loop',
errorMessage,
{ loopType, ...inputData },
this.contextExtensions
)
}
storeLoopNodeOutput(
ctx: ExecutionContext,
loopId: string,
@@ -412,54 +511,6 @@ export class LoopOrchestrator {
}
private resolveForEachItems(ctx: ExecutionContext, items: any): any[] {
if (Array.isArray(items)) {
return items
}
if (typeof items === 'object' && items !== null) {
return Object.entries(items)
}
if (typeof items === 'string') {
if (items.startsWith('<') && items.endsWith('>')) {
const resolved = this.resolver.resolveSingleReference(ctx, '', items)
if (Array.isArray(resolved)) {
return resolved
}
return []
}
try {
const normalized = items.replace(/'/g, '"')
const parsed = JSON.parse(normalized)
if (Array.isArray(parsed)) {
return parsed
}
return []
} catch (error) {
logger.error('Failed to parse forEach items', { items, error })
return []
}
}
try {
const resolved = this.resolver.resolveInputs(ctx, 'loop_foreach_items', { items }).items
if (Array.isArray(resolved)) {
return resolved
}
logger.warn('ForEach items did not resolve to array', {
items,
resolved,
})
return []
} catch (error: any) {
logger.error('Error resolving forEach items, returning empty array:', {
error: error.message,
})
return []
}
return resolveArrayInput(ctx, items, this.resolver)
}
}

View File

@@ -1,15 +1,19 @@
import { createLogger } from '@/lib/logs/console/logger'
import { DEFAULTS } from '@/executor/constants'
import type { DAG, DAGNode } from '@/executor/dag/builder'
import type { ParallelScope } from '@/executor/execution/state'
import type { BlockStateWriter } from '@/executor/execution/types'
import type { BlockStateWriter, ContextExtensions } from '@/executor/execution/types'
import type { ExecutionContext, NormalizedBlockOutput } from '@/executor/types'
import type { ParallelConfigWithNodes } from '@/executor/types/parallel'
import {
addSubflowErrorLog,
buildBranchNodeId,
calculateBranchCount,
extractBaseBlockId,
extractBranchIndex,
parseDistributionItems,
resolveArrayInput,
validateMaxCount,
} from '@/executor/utils/subflow-utils'
import type { VariableResolver } from '@/executor/variables/resolver'
import type { SerializedParallel } from '@/serializer/types'
@@ -32,6 +36,7 @@ export interface ParallelAggregationResult {
export class ParallelOrchestrator {
private resolver: VariableResolver | null = null
private contextExtensions: ContextExtensions | null = null
constructor(
private dag: DAG,
@@ -42,6 +47,10 @@ export class ParallelOrchestrator {
this.resolver = resolver
}
setContextExtensions(contextExtensions: ContextExtensions): void {
this.contextExtensions = contextExtensions
}
initializeParallelScope(
ctx: ExecutionContext,
parallelId: string,
@@ -49,11 +58,42 @@ export class ParallelOrchestrator {
terminalNodesCount = 1
): ParallelScope {
const parallelConfig = this.dag.parallelConfigs.get(parallelId)
const items = parallelConfig ? this.resolveDistributionItems(ctx, parallelConfig) : undefined
// If we have more items than pre-built branches, expand the DAG
let items: any[] | undefined
if (parallelConfig) {
try {
items = this.resolveDistributionItems(ctx, parallelConfig)
} catch (error) {
const errorMessage = `Parallel Items did not resolve: ${error instanceof Error ? error.message : String(error)}`
logger.error(errorMessage, {
parallelId,
distribution: parallelConfig.distribution,
})
this.addParallelErrorLog(ctx, parallelId, errorMessage, {
distribution: parallelConfig.distribution,
})
this.setErrorScope(ctx, parallelId, errorMessage)
throw new Error(errorMessage)
}
}
const actualBranchCount = items && items.length > totalBranches ? items.length : totalBranches
const branchError = validateMaxCount(
actualBranchCount,
DEFAULTS.MAX_PARALLEL_BRANCHES,
'Parallel branch count'
)
if (branchError) {
logger.error(branchError, { parallelId, actualBranchCount })
this.addParallelErrorLog(ctx, parallelId, branchError, {
distribution: parallelConfig?.distribution,
branchCount: actualBranchCount,
})
this.setErrorScope(ctx, parallelId, branchError)
throw new Error(branchError)
}
const scope: ParallelScope = {
parallelId,
totalBranches: actualBranchCount,
@@ -108,6 +148,38 @@ export class ParallelOrchestrator {
return scope
}
private addParallelErrorLog(
ctx: ExecutionContext,
parallelId: string,
errorMessage: string,
inputData?: any
): void {
addSubflowErrorLog(
ctx,
parallelId,
'parallel',
errorMessage,
inputData || {},
this.contextExtensions
)
}
private setErrorScope(ctx: ExecutionContext, parallelId: string, errorMessage: string): void {
const scope: ParallelScope = {
parallelId,
totalBranches: 0,
branchOutputs: new Map(),
completedCount: 0,
totalExpectedNodes: 0,
items: [],
validationError: errorMessage,
}
if (!ctx.parallelExecutions) {
ctx.parallelExecutions = new Map()
}
ctx.parallelExecutions.set(parallelId, scope)
}
/**
* Dynamically expand the DAG to include additional branch nodes when
* the resolved item count exceeds the pre-built branch count.
@@ -291,63 +363,19 @@ export class ParallelOrchestrator {
}
}
/**
* Resolve distribution items at runtime, handling references like <previousBlock.items>
* This mirrors how LoopOrchestrator.resolveForEachItems works.
*/
private resolveDistributionItems(ctx: ExecutionContext, config: SerializedParallel): any[] {
const rawItems = config.distribution
if (rawItems === undefined || rawItems === null) {
if (config.parallelType === 'count') {
return []
}
// Already an array - return as-is
if (Array.isArray(rawItems)) {
return rawItems
if (
config.distribution === undefined ||
config.distribution === null ||
config.distribution === ''
) {
return []
}
// Object - convert to entries array (consistent with loop forEach behavior)
if (typeof rawItems === 'object') {
return Object.entries(rawItems)
}
// String handling
if (typeof rawItems === 'string') {
// Resolve references at runtime using the variable resolver
if (rawItems.startsWith('<') && rawItems.endsWith('>') && this.resolver) {
const resolved = this.resolver.resolveSingleReference(ctx, '', rawItems)
if (Array.isArray(resolved)) {
return resolved
}
if (typeof resolved === 'object' && resolved !== null) {
return Object.entries(resolved)
}
logger.warn('Distribution reference did not resolve to array or object', {
rawItems,
resolved,
})
return []
}
// Try to parse as JSON
try {
const normalized = rawItems.replace(/'/g, '"')
const parsed = JSON.parse(normalized)
if (Array.isArray(parsed)) {
return parsed
}
if (typeof parsed === 'object' && parsed !== null) {
return Object.entries(parsed)
}
return []
} catch (error) {
logger.error('Failed to parse distribution items', { rawItems, error })
return []
}
}
return []
return resolveArrayInput(ctx, config.distribution, this.resolver)
}
handleParallelBranchCompletion(

View File

@@ -1,5 +1,8 @@
import { createLogger } from '@/lib/logs/console/logger'
import { LOOP, PARALLEL, PARSING, REFERENCE } from '@/executor/constants'
import type { ContextExtensions } from '@/executor/execution/types'
import type { BlockLog, ExecutionContext } from '@/executor/types'
import type { VariableResolver } from '@/executor/variables/resolver'
import type { SerializedParallel } from '@/serializer/types'
const logger = createLogger('SubflowUtils')
@@ -132,3 +135,131 @@ export function normalizeNodeId(nodeId: string): string {
}
return nodeId
}
/**
* Validates that a count doesn't exceed a maximum limit.
* Returns an error message if validation fails, undefined otherwise.
*/
export function validateMaxCount(count: number, max: number, itemType: string): string | undefined {
if (count > max) {
return `${itemType} (${count}) exceeds maximum allowed (${max}). Execution blocked.`
}
return undefined
}
/**
* Resolves array input at runtime. Handles arrays, objects, references, and JSON strings.
* Used by both loop forEach and parallel distribution resolution.
* Throws an error if resolution fails.
*/
export function resolveArrayInput(
ctx: ExecutionContext,
items: any,
resolver: VariableResolver | null
): any[] {
if (Array.isArray(items)) {
return items
}
if (typeof items === 'object' && items !== null) {
return Object.entries(items)
}
if (typeof items === 'string') {
if (items.startsWith(REFERENCE.START) && items.endsWith(REFERENCE.END) && resolver) {
try {
const resolved = resolver.resolveSingleReference(ctx, '', items)
if (Array.isArray(resolved)) {
return resolved
}
if (typeof resolved === 'object' && resolved !== null) {
return Object.entries(resolved)
}
throw new Error(`Reference "${items}" did not resolve to an array or object`)
} catch (error) {
if (error instanceof Error && error.message.startsWith('Reference "')) {
throw error
}
throw new Error(
`Failed to resolve reference "${items}": ${error instanceof Error ? error.message : String(error)}`
)
}
}
try {
const normalized = items.replace(/'/g, '"')
const parsed = JSON.parse(normalized)
if (Array.isArray(parsed)) {
return parsed
}
if (typeof parsed === 'object' && parsed !== null) {
return Object.entries(parsed)
}
throw new Error(`Parsed value is not an array or object`)
} catch (error) {
if (error instanceof Error && error.message.startsWith('Parsed value')) {
throw error
}
throw new Error(`Failed to parse items as JSON: "${items}"`)
}
}
if (resolver) {
try {
const resolved = resolver.resolveInputs(ctx, 'subflow_items', { items }).items
if (Array.isArray(resolved)) {
return resolved
}
throw new Error(`Resolved items is not an array`)
} catch (error) {
if (error instanceof Error && error.message.startsWith('Resolved items')) {
throw error
}
throw new Error(
`Failed to resolve items: ${error instanceof Error ? error.message : String(error)}`
)
}
}
return []
}
/**
* Creates and logs an error for a subflow (loop or parallel).
*/
export function addSubflowErrorLog(
ctx: ExecutionContext,
blockId: string,
blockType: 'loop' | 'parallel',
errorMessage: string,
inputData: Record<string, any>,
contextExtensions: ContextExtensions | null
): void {
const now = new Date().toISOString()
const block = ctx.workflow?.blocks?.find((b) => b.id === blockId)
const blockName = block?.metadata?.name || (blockType === 'loop' ? 'Loop' : 'Parallel')
const blockLog: BlockLog = {
blockId,
blockName,
blockType,
startedAt: now,
endedAt: now,
durationMs: 0,
success: false,
error: errorMessage,
input: inputData,
output: { error: errorMessage },
...(blockType === 'loop' ? { loopId: blockId } : { parallelId: blockId }),
}
ctx.blockLogs.push(blockLog)
if (contextExtensions?.onBlockComplete) {
contextExtensions.onBlockComplete(blockId, blockName, blockType, {
input: inputData,
output: { error: errorMessage },
executionTime: 0,
})
}
}

View File

@@ -579,6 +579,21 @@ export const auth = betterAuth({
redirectURI: `${getBaseUrl()}/api/auth/oauth2/callback/google-groups`,
},
{
providerId: 'vertex-ai',
clientId: env.GOOGLE_CLIENT_ID as string,
clientSecret: env.GOOGLE_CLIENT_SECRET as string,
discoveryUrl: 'https://accounts.google.com/.well-known/openid-configuration',
accessType: 'offline',
scopes: [
'https://www.googleapis.com/auth/userinfo.email',
'https://www.googleapis.com/auth/userinfo.profile',
'https://www.googleapis.com/auth/cloud-platform',
],
prompt: 'consent',
redirectURI: `${getBaseUrl()}/api/auth/oauth2/callback/vertex-ai`,
},
{
providerId: 'microsoft-teams',
clientId: env.MICROSOFT_CLIENT_ID as string,

View File

@@ -36,6 +36,8 @@ export const ToolIds = z.enum([
'manage_custom_tool',
'manage_mcp_tool',
'sleep',
'get_block_outputs',
'get_block_upstream_references',
])
export type ToolId = z.infer<typeof ToolIds>
@@ -277,6 +279,24 @@ export const ToolArgSchemas = {
.max(180)
.describe('The number of seconds to sleep (0-180, max 3 minutes)'),
}),
get_block_outputs: z.object({
blockIds: z
.array(z.string())
.optional()
.describe(
'Optional array of block UUIDs. If provided, returns outputs only for those blocks. If not provided, returns outputs for all blocks in the workflow.'
),
}),
get_block_upstream_references: z.object({
blockIds: z
.array(z.string())
.min(1)
.describe(
'Array of block UUIDs. Returns all upstream references (block outputs and variables) accessible to each block based on workflow connections.'
),
}),
} as const
export type ToolArgSchemaMap = typeof ToolArgSchemas
@@ -346,6 +366,11 @@ export const ToolSSESchemas = {
manage_custom_tool: toolCallSSEFor('manage_custom_tool', ToolArgSchemas.manage_custom_tool),
manage_mcp_tool: toolCallSSEFor('manage_mcp_tool', ToolArgSchemas.manage_mcp_tool),
sleep: toolCallSSEFor('sleep', ToolArgSchemas.sleep),
get_block_outputs: toolCallSSEFor('get_block_outputs', ToolArgSchemas.get_block_outputs),
get_block_upstream_references: toolCallSSEFor(
'get_block_upstream_references',
ToolArgSchemas.get_block_upstream_references
),
} as const
export type ToolSSESchemaMap = typeof ToolSSESchemas
@@ -603,6 +628,60 @@ export const ToolResultSchemas = {
seconds: z.number(),
message: z.string().optional(),
}),
get_block_outputs: z.object({
blocks: z.array(
z.object({
blockId: z.string(),
blockName: z.string(),
blockType: z.string(),
outputs: z.array(z.string()),
insideSubflowOutputs: z.array(z.string()).optional(),
outsideSubflowOutputs: z.array(z.string()).optional(),
})
),
variables: z.array(
z.object({
id: z.string(),
name: z.string(),
type: z.string(),
tag: z.string(),
})
),
}),
get_block_upstream_references: z.object({
results: z.array(
z.object({
blockId: z.string(),
blockName: z.string(),
insideSubflows: z
.array(
z.object({
blockId: z.string(),
blockName: z.string(),
blockType: z.string(),
})
)
.optional(),
accessibleBlocks: z.array(
z.object({
blockId: z.string(),
blockName: z.string(),
blockType: z.string(),
outputs: z.array(z.string()),
accessContext: z.enum(['inside', 'outside']).optional(),
})
),
variables: z.array(
z.object({
id: z.string(),
name: z.string(),
type: z.string(),
tag: z.string(),
})
),
})
),
}),
} as const
export type ToolResultSchemaMap = typeof ToolResultSchemas

View File

@@ -0,0 +1,142 @@
import {
extractFieldsFromSchema,
parseResponseFormatSafely,
} from '@/lib/core/utils/response-format'
import { getBlockOutputPaths } from '@/lib/workflows/blocks/block-outputs'
import { getBlock } from '@/blocks'
import { useVariablesStore } from '@/stores/panel/variables/store'
import type { Variable } from '@/stores/panel/variables/types'
import { useSubBlockStore } from '@/stores/workflows/subblock/store'
import { normalizeName } from '@/stores/workflows/utils'
import type { BlockState, Loop, Parallel } from '@/stores/workflows/workflow/types'
export interface WorkflowContext {
workflowId: string
blocks: Record<string, BlockState>
loops: Record<string, Loop>
parallels: Record<string, Parallel>
subBlockValues: Record<string, Record<string, any>>
}
export interface VariableOutput {
id: string
name: string
type: string
tag: string
}
export function getWorkflowSubBlockValues(workflowId: string): Record<string, Record<string, any>> {
const subBlockStore = useSubBlockStore.getState()
return subBlockStore.workflowValues[workflowId] ?? {}
}
export function getMergedSubBlocks(
blocks: Record<string, BlockState>,
subBlockValues: Record<string, Record<string, any>>,
targetBlockId: string
): Record<string, any> {
const base = blocks[targetBlockId]?.subBlocks || {}
const live = subBlockValues?.[targetBlockId] || {}
const merged: Record<string, any> = { ...base }
for (const [subId, liveVal] of Object.entries(live)) {
merged[subId] = { ...(base[subId] || {}), value: liveVal }
}
return merged
}
export function getSubBlockValue(
blocks: Record<string, BlockState>,
subBlockValues: Record<string, Record<string, any>>,
targetBlockId: string,
subBlockId: string
): any {
const live = subBlockValues?.[targetBlockId]?.[subBlockId]
if (live !== undefined) return live
return blocks[targetBlockId]?.subBlocks?.[subBlockId]?.value
}
export function getWorkflowVariables(workflowId: string): VariableOutput[] {
const getVariablesByWorkflowId = useVariablesStore.getState().getVariablesByWorkflowId
const workflowVariables = getVariablesByWorkflowId(workflowId)
const validVariables = workflowVariables.filter(
(variable: Variable) => variable.name.trim() !== ''
)
return validVariables.map((variable: Variable) => ({
id: variable.id,
name: variable.name,
type: variable.type,
tag: `variable.${normalizeName(variable.name)}`,
}))
}
export function getSubflowInsidePaths(
blockType: 'loop' | 'parallel',
blockId: string,
loops: Record<string, Loop>,
parallels: Record<string, Parallel>
): string[] {
const paths = ['index']
if (blockType === 'loop') {
const loopType = loops[blockId]?.loopType || 'for'
if (loopType === 'forEach') {
paths.push('currentItem', 'items')
}
} else {
const parallelType = parallels[blockId]?.parallelType || 'count'
if (parallelType === 'collection') {
paths.push('currentItem', 'items')
}
}
return paths
}
export function computeBlockOutputPaths(block: BlockState, ctx: WorkflowContext): string[] {
const { blocks, loops, parallels, subBlockValues } = ctx
const blockConfig = getBlock(block.type)
const mergedSubBlocks = getMergedSubBlocks(blocks, subBlockValues, block.id)
if (block.type === 'loop' || block.type === 'parallel') {
const insidePaths = getSubflowInsidePaths(block.type, block.id, loops, parallels)
return ['results', ...insidePaths]
}
if (block.type === 'evaluator') {
const metricsValue = getSubBlockValue(blocks, subBlockValues, block.id, 'metrics')
if (metricsValue && Array.isArray(metricsValue) && metricsValue.length > 0) {
const validMetrics = metricsValue.filter((metric: { name?: string }) => metric?.name)
return validMetrics.map((metric: { name: string }) => metric.name.toLowerCase())
}
return getBlockOutputPaths(block.type, mergedSubBlocks)
}
if (block.type === 'variables') {
const variablesValue = getSubBlockValue(blocks, subBlockValues, block.id, 'variables')
if (variablesValue && Array.isArray(variablesValue) && variablesValue.length > 0) {
const validAssignments = variablesValue.filter((assignment: { variableName?: string }) =>
assignment?.variableName?.trim()
)
return validAssignments.map((assignment: { variableName: string }) =>
assignment.variableName.trim()
)
}
return []
}
if (blockConfig) {
const responseFormatValue = mergedSubBlocks?.responseFormat?.value
const responseFormat = parseResponseFormatSafely(responseFormatValue, block.id)
if (responseFormat) {
const schemaFields = extractFieldsFromSchema(responseFormat)
if (schemaFields.length > 0) {
return schemaFields.map((field) => field.name)
}
}
}
return getBlockOutputPaths(block.type, mergedSubBlocks, block.triggerMode)
}
export function formatOutputsWithPrefix(paths: string[], blockName: string): string[] {
const normalizedName = normalizeName(blockName)
return paths.map((path) => `${normalizedName}.${path}`)
}

View File

@@ -0,0 +1,144 @@
import { Loader2, Tag, X, XCircle } from 'lucide-react'
import {
BaseClientTool,
type BaseClientToolMetadata,
ClientToolCallState,
} from '@/lib/copilot/tools/client/base-tool'
import {
computeBlockOutputPaths,
formatOutputsWithPrefix,
getSubflowInsidePaths,
getWorkflowSubBlockValues,
getWorkflowVariables,
} from '@/lib/copilot/tools/client/workflow/block-output-utils'
import {
GetBlockOutputsResult,
type GetBlockOutputsResultType,
} from '@/lib/copilot/tools/shared/schemas'
import { createLogger } from '@/lib/logs/console/logger'
import { useWorkflowRegistry } from '@/stores/workflows/registry/store'
import { normalizeName } from '@/stores/workflows/utils'
import { useWorkflowStore } from '@/stores/workflows/workflow/store'
const logger = createLogger('GetBlockOutputsClientTool')
interface GetBlockOutputsArgs {
blockIds?: string[]
}
export class GetBlockOutputsClientTool extends BaseClientTool {
static readonly id = 'get_block_outputs'
constructor(toolCallId: string) {
super(toolCallId, GetBlockOutputsClientTool.id, GetBlockOutputsClientTool.metadata)
}
static readonly metadata: BaseClientToolMetadata = {
displayNames: {
[ClientToolCallState.generating]: { text: 'Getting block outputs', icon: Loader2 },
[ClientToolCallState.pending]: { text: 'Getting block outputs', icon: Tag },
[ClientToolCallState.executing]: { text: 'Getting block outputs', icon: Loader2 },
[ClientToolCallState.aborted]: { text: 'Aborted getting outputs', icon: XCircle },
[ClientToolCallState.success]: { text: 'Retrieved block outputs', icon: Tag },
[ClientToolCallState.error]: { text: 'Failed to get outputs', icon: X },
[ClientToolCallState.rejected]: { text: 'Skipped getting outputs', icon: XCircle },
},
getDynamicText: (params, state) => {
const blockIds = params?.blockIds
if (blockIds && Array.isArray(blockIds) && blockIds.length > 0) {
const count = blockIds.length
switch (state) {
case ClientToolCallState.success:
return `Retrieved outputs for ${count} block${count > 1 ? 's' : ''}`
case ClientToolCallState.executing:
case ClientToolCallState.generating:
case ClientToolCallState.pending:
return `Getting outputs for ${count} block${count > 1 ? 's' : ''}`
case ClientToolCallState.error:
return `Failed to get outputs for ${count} block${count > 1 ? 's' : ''}`
}
}
return undefined
},
}
async execute(args?: GetBlockOutputsArgs): Promise<void> {
try {
this.setState(ClientToolCallState.executing)
const { activeWorkflowId } = useWorkflowRegistry.getState()
if (!activeWorkflowId) {
await this.markToolComplete(400, 'No active workflow found')
this.setState(ClientToolCallState.error)
return
}
const workflowStore = useWorkflowStore.getState()
const blocks = workflowStore.blocks || {}
const loops = workflowStore.loops || {}
const parallels = workflowStore.parallels || {}
const subBlockValues = getWorkflowSubBlockValues(activeWorkflowId)
const ctx = { workflowId: activeWorkflowId, blocks, loops, parallels, subBlockValues }
const targetBlockIds =
args?.blockIds && args.blockIds.length > 0 ? args.blockIds : Object.keys(blocks)
const blockOutputs: GetBlockOutputsResultType['blocks'] = []
for (const blockId of targetBlockIds) {
const block = blocks[blockId]
if (!block?.type) continue
const blockName = block.name || block.type
const normalizedBlockName = normalizeName(blockName)
let insideSubflowOutputs: string[] | undefined
let outsideSubflowOutputs: string[] | undefined
const blockOutput: GetBlockOutputsResultType['blocks'][0] = {
blockId,
blockName,
blockType: block.type,
outputs: [],
}
if (block.type === 'loop' || block.type === 'parallel') {
const insidePaths = getSubflowInsidePaths(block.type, blockId, loops, parallels)
blockOutput.insideSubflowOutputs = formatOutputsWithPrefix(insidePaths, blockName)
blockOutput.outsideSubflowOutputs = formatOutputsWithPrefix(['results'], blockName)
} else {
const outputPaths = computeBlockOutputPaths(block, ctx)
blockOutput.outputs = formatOutputsWithPrefix(outputPaths, blockName)
}
blockOutputs.push(blockOutput)
}
const includeVariables = !args?.blockIds || args.blockIds.length === 0
const resultData: {
blocks: typeof blockOutputs
variables?: ReturnType<typeof getWorkflowVariables>
} = {
blocks: blockOutputs,
}
if (includeVariables) {
resultData.variables = getWorkflowVariables(activeWorkflowId)
}
const result = GetBlockOutputsResult.parse(resultData)
logger.info('Retrieved block outputs', {
blockCount: blockOutputs.length,
variableCount: resultData.variables?.length ?? 0,
})
await this.markToolComplete(200, 'Retrieved block outputs', result)
this.setState(ClientToolCallState.success)
} catch (error: any) {
const message = error instanceof Error ? error.message : String(error)
logger.error('Error in tool execution', { toolCallId: this.toolCallId, error, message })
await this.markToolComplete(500, message || 'Failed to get block outputs')
this.setState(ClientToolCallState.error)
}
}
}

View File

@@ -0,0 +1,227 @@
import { GitBranch, Loader2, X, XCircle } from 'lucide-react'
import {
BaseClientTool,
type BaseClientToolMetadata,
ClientToolCallState,
} from '@/lib/copilot/tools/client/base-tool'
import {
computeBlockOutputPaths,
formatOutputsWithPrefix,
getSubflowInsidePaths,
getWorkflowSubBlockValues,
getWorkflowVariables,
} from '@/lib/copilot/tools/client/workflow/block-output-utils'
import {
GetBlockUpstreamReferencesResult,
type GetBlockUpstreamReferencesResultType,
} from '@/lib/copilot/tools/shared/schemas'
import { createLogger } from '@/lib/logs/console/logger'
import { BlockPathCalculator } from '@/lib/workflows/blocks/block-path-calculator'
import { useWorkflowRegistry } from '@/stores/workflows/registry/store'
import { useWorkflowStore } from '@/stores/workflows/workflow/store'
import type { Loop, Parallel } from '@/stores/workflows/workflow/types'
const logger = createLogger('GetBlockUpstreamReferencesClientTool')
interface GetBlockUpstreamReferencesArgs {
blockIds: string[]
}
export class GetBlockUpstreamReferencesClientTool extends BaseClientTool {
static readonly id = 'get_block_upstream_references'
constructor(toolCallId: string) {
super(
toolCallId,
GetBlockUpstreamReferencesClientTool.id,
GetBlockUpstreamReferencesClientTool.metadata
)
}
static readonly metadata: BaseClientToolMetadata = {
displayNames: {
[ClientToolCallState.generating]: { text: 'Getting upstream references', icon: Loader2 },
[ClientToolCallState.pending]: { text: 'Getting upstream references', icon: GitBranch },
[ClientToolCallState.executing]: { text: 'Getting upstream references', icon: Loader2 },
[ClientToolCallState.aborted]: { text: 'Aborted getting references', icon: XCircle },
[ClientToolCallState.success]: { text: 'Retrieved upstream references', icon: GitBranch },
[ClientToolCallState.error]: { text: 'Failed to get references', icon: X },
[ClientToolCallState.rejected]: { text: 'Skipped getting references', icon: XCircle },
},
getDynamicText: (params, state) => {
const blockIds = params?.blockIds
if (blockIds && Array.isArray(blockIds) && blockIds.length > 0) {
const count = blockIds.length
switch (state) {
case ClientToolCallState.success:
return `Retrieved references for ${count} block${count > 1 ? 's' : ''}`
case ClientToolCallState.executing:
case ClientToolCallState.generating:
case ClientToolCallState.pending:
return `Getting references for ${count} block${count > 1 ? 's' : ''}`
case ClientToolCallState.error:
return `Failed to get references for ${count} block${count > 1 ? 's' : ''}`
}
}
return undefined
},
}
async execute(args?: GetBlockUpstreamReferencesArgs): Promise<void> {
try {
this.setState(ClientToolCallState.executing)
if (!args?.blockIds || args.blockIds.length === 0) {
await this.markToolComplete(400, 'blockIds array is required')
this.setState(ClientToolCallState.error)
return
}
const { activeWorkflowId } = useWorkflowRegistry.getState()
if (!activeWorkflowId) {
await this.markToolComplete(400, 'No active workflow found')
this.setState(ClientToolCallState.error)
return
}
const workflowStore = useWorkflowStore.getState()
const blocks = workflowStore.blocks || {}
const edges = workflowStore.edges || []
const loops = workflowStore.loops || {}
const parallels = workflowStore.parallels || {}
const subBlockValues = getWorkflowSubBlockValues(activeWorkflowId)
const ctx = { workflowId: activeWorkflowId, blocks, loops, parallels, subBlockValues }
const variableOutputs = getWorkflowVariables(activeWorkflowId)
const graphEdges = edges.map((edge) => ({ source: edge.source, target: edge.target }))
const results: GetBlockUpstreamReferencesResultType['results'] = []
for (const blockId of args.blockIds) {
const targetBlock = blocks[blockId]
if (!targetBlock) {
logger.warn(`Block ${blockId} not found`)
continue
}
const insideSubflows: { blockId: string; blockName: string; blockType: string }[] = []
const containingLoopIds = new Set<string>()
const containingParallelIds = new Set<string>()
Object.values(loops as Record<string, Loop>).forEach((loop) => {
if (loop?.nodes?.includes(blockId)) {
containingLoopIds.add(loop.id)
const loopBlock = blocks[loop.id]
if (loopBlock) {
insideSubflows.push({
blockId: loop.id,
blockName: loopBlock.name || loopBlock.type,
blockType: 'loop',
})
}
}
})
Object.values(parallels as Record<string, Parallel>).forEach((parallel) => {
if (parallel?.nodes?.includes(blockId)) {
containingParallelIds.add(parallel.id)
const parallelBlock = blocks[parallel.id]
if (parallelBlock) {
insideSubflows.push({
blockId: parallel.id,
blockName: parallelBlock.name || parallelBlock.type,
blockType: 'parallel',
})
}
}
})
const ancestorIds = BlockPathCalculator.findAllPathNodes(graphEdges, blockId)
const accessibleIds = new Set<string>(ancestorIds)
accessibleIds.add(blockId)
const starterBlock = Object.values(blocks).find(
(b) => b.type === 'starter' || b.type === 'start_trigger'
)
if (starterBlock && ancestorIds.includes(starterBlock.id)) {
accessibleIds.add(starterBlock.id)
}
containingLoopIds.forEach((loopId) => {
accessibleIds.add(loopId)
loops[loopId]?.nodes?.forEach((nodeId) => accessibleIds.add(nodeId))
})
containingParallelIds.forEach((parallelId) => {
accessibleIds.add(parallelId)
parallels[parallelId]?.nodes?.forEach((nodeId) => accessibleIds.add(nodeId))
})
const accessibleBlocks: GetBlockUpstreamReferencesResultType['results'][0]['accessibleBlocks'] =
[]
for (const accessibleBlockId of accessibleIds) {
const block = blocks[accessibleBlockId]
if (!block?.type) continue
const canSelfReference = block.type === 'approval' || block.type === 'human_in_the_loop'
if (accessibleBlockId === blockId && !canSelfReference) continue
const blockName = block.name || block.type
let accessContext: 'inside' | 'outside' | undefined
let outputPaths: string[]
if (block.type === 'loop' || block.type === 'parallel') {
const isInside =
(block.type === 'loop' && containingLoopIds.has(accessibleBlockId)) ||
(block.type === 'parallel' && containingParallelIds.has(accessibleBlockId))
accessContext = isInside ? 'inside' : 'outside'
outputPaths = isInside
? getSubflowInsidePaths(block.type, accessibleBlockId, loops, parallels)
: ['results']
} else {
outputPaths = computeBlockOutputPaths(block, ctx)
}
const formattedOutputs = formatOutputsWithPrefix(outputPaths, blockName)
const entry: GetBlockUpstreamReferencesResultType['results'][0]['accessibleBlocks'][0] = {
blockId: accessibleBlockId,
blockName,
blockType: block.type,
outputs: formattedOutputs,
}
if (accessContext) entry.accessContext = accessContext
accessibleBlocks.push(entry)
}
const resultEntry: GetBlockUpstreamReferencesResultType['results'][0] = {
blockId,
blockName: targetBlock.name || targetBlock.type,
accessibleBlocks,
variables: variableOutputs,
}
if (insideSubflows.length > 0) resultEntry.insideSubflows = insideSubflows
results.push(resultEntry)
}
const result = GetBlockUpstreamReferencesResult.parse({ results })
logger.info('Retrieved upstream references', {
blockIds: args.blockIds,
resultCount: results.length,
})
await this.markToolComplete(200, 'Retrieved upstream references', result)
this.setState(ClientToolCallState.success)
} catch (error: any) {
const message = error instanceof Error ? error.message : String(error)
logger.error('Error in tool execution', { toolCallId: this.toolCallId, error, message })
await this.markToolComplete(500, message || 'Failed to get upstream references')
this.setState(ClientToolCallState.error)
}
}
}

View File

@@ -104,3 +104,71 @@ export const KnowledgeBaseResultSchema = z.object({
data: z.any().optional(),
})
export type KnowledgeBaseResult = z.infer<typeof KnowledgeBaseResultSchema>
export const GetBlockOutputsInput = z.object({
blockIds: z.array(z.string()).optional(),
})
export const GetBlockOutputsResult = z.object({
blocks: z.array(
z.object({
blockId: z.string(),
blockName: z.string(),
blockType: z.string(),
outputs: z.array(z.string()),
insideSubflowOutputs: z.array(z.string()).optional(),
outsideSubflowOutputs: z.array(z.string()).optional(),
})
),
variables: z
.array(
z.object({
id: z.string(),
name: z.string(),
type: z.string(),
tag: z.string(),
})
)
.optional(),
})
export type GetBlockOutputsInputType = z.infer<typeof GetBlockOutputsInput>
export type GetBlockOutputsResultType = z.infer<typeof GetBlockOutputsResult>
export const GetBlockUpstreamReferencesInput = z.object({
blockIds: z.array(z.string()).min(1),
})
export const GetBlockUpstreamReferencesResult = z.object({
results: z.array(
z.object({
blockId: z.string(),
blockName: z.string(),
insideSubflows: z
.array(
z.object({
blockId: z.string(),
blockName: z.string(),
blockType: z.string(),
})
)
.optional(),
accessibleBlocks: z.array(
z.object({
blockId: z.string(),
blockName: z.string(),
blockType: z.string(),
outputs: z.array(z.string()),
accessContext: z.enum(['inside', 'outside']).optional(),
})
),
variables: z.array(
z.object({
id: z.string(),
name: z.string(),
type: z.string(),
tag: z.string(),
})
),
})
),
})
export type GetBlockUpstreamReferencesInputType = z.infer<typeof GetBlockUpstreamReferencesInput>
export type GetBlockUpstreamReferencesResultType = z.infer<typeof GetBlockUpstreamReferencesResult>

View File

@@ -41,7 +41,7 @@ function filterUserFile(data: any): any {
const DISPLAY_FILTERS = [filterUserFile]
export function filterForDisplay(data: any): any {
const seen = new WeakSet()
const seen = new Set<object>()
return filterForDisplayInternal(data, seen, 0)
}
@@ -49,7 +49,7 @@ function getObjectType(data: unknown): string {
return Object.prototype.toString.call(data).slice(8, -1)
}
function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: number): any {
function filterForDisplayInternal(data: any, seen: Set<object>, depth: number): any {
try {
if (data === null || data === undefined) {
return data
@@ -93,6 +93,7 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
return '[Unknown Type]'
}
// True circular reference: object is an ancestor in the current path
if (seen.has(data)) {
return '[Circular Reference]'
}
@@ -131,18 +132,24 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
return `[ArrayBuffer: ${(data as ArrayBuffer).byteLength} bytes]`
case 'Map': {
seen.add(data)
const obj: Record<string, any> = {}
for (const [key, value] of (data as Map<any, any>).entries()) {
const keyStr = typeof key === 'string' ? key : String(key)
obj[keyStr] = filterForDisplayInternal(value, seen, depth + 1)
}
seen.delete(data)
return obj
}
case 'Set':
return Array.from(data as Set<any>).map((item) =>
case 'Set': {
seen.add(data)
const result = Array.from(data as Set<any>).map((item) =>
filterForDisplayInternal(item, seen, depth + 1)
)
seen.delete(data)
return result
}
case 'WeakMap':
return '[WeakMap]'
@@ -161,17 +168,22 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
return `[${objectType}: ${(data as ArrayBufferView).byteLength} bytes]`
}
// Add to current path before processing children
seen.add(data)
for (const filterFn of DISPLAY_FILTERS) {
const result = filterFn(data)
if (result !== data) {
return filterForDisplayInternal(result, seen, depth + 1)
const filtered = filterFn(data)
if (filtered !== data) {
const result = filterForDisplayInternal(filtered, seen, depth + 1)
seen.delete(data)
return result
}
}
if (Array.isArray(data)) {
return data.map((item) => filterForDisplayInternal(item, seen, depth + 1))
const result = data.map((item) => filterForDisplayInternal(item, seen, depth + 1))
seen.delete(data)
return result
}
const result: Record<string, any> = {}
@@ -182,6 +194,8 @@ function filterForDisplayInternal(data: any, seen: WeakSet<object>, depth: numbe
result[key] = '[Error accessing property]'
}
}
// Remove from current path after processing children
seen.delete(data)
return result
} catch {
return '[Unserializable]'
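
Switching the tracker from an accumulating WeakSet to a Set that is populated before child traversal and cleared afterwards means it now describes the current ancestor path, so a value that is merely shared between siblings is no longer collapsed into '[Circular Reference]'. A small sketch, assuming filterForDisplay is imported from this module (the path is hypothetical):

import { filterForDisplay } from './display-filters'

const shared = { token: 'abc' }
const payload: any = { left: shared, right: shared } // shared but not circular
payload.self = payload                               // a genuine cycle

const filtered = filterForDisplay(payload)
console.log(filtered.left, filtered.right) // { token: 'abc' } { token: 'abc' } (both siblings survive)
console.log(filtered.self)                 // '[Circular Reference]' (true cycles are still cut)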

View File

@@ -471,8 +471,10 @@ function groupIterationBlocks(spans: TraceSpan[]): TraceSpan[] {
}
})
// Include loop/parallel spans that have errors (e.g., validation errors that blocked execution)
// These won't have iteration children, so they should appear directly in results
const nonIterationContainerSpans = normalSpans.filter(
(span) => span.type !== 'parallel' && span.type !== 'loop'
(span) => (span.type !== 'parallel' && span.type !== 'loop') || span.status === 'error'
)
if (iterationSpans.length > 0) {
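
The adjusted predicate keeps loop and parallel containers out of the flat results only when they actually produced iteration children; containers that errored before executing now surface directly. A toy check of the predicate alone (the span literals are hypothetical):

type SpanLike = { type: string; status?: string }

const keepOutsideIterationGroups = (span: SpanLike) =>
  (span.type !== 'parallel' && span.type !== 'loop') || span.status === 'error'

console.log(keepOutsideIterationGroups({ type: 'agent' }))                     // true
console.log(keepOutsideIterationGroups({ type: 'loop', status: 'success' }))   // false (folded into iteration groups)
console.log(keepOutsideIterationGroups({ type: 'parallel', status: 'error' })) // true (shown directly with its error)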

View File

@@ -32,6 +32,7 @@ import {
SlackIcon,
SpotifyIcon,
TrelloIcon,
VertexIcon,
WealthboxIcon,
WebflowIcon,
WordpressIcon,
@@ -80,6 +81,7 @@ export type OAuthService =
| 'google-vault'
| 'google-forms'
| 'google-groups'
| 'vertex-ai'
| 'github'
| 'x'
| 'confluence'
@@ -237,6 +239,16 @@ export const OAUTH_PROVIDERS: Record<string, OAuthProviderConfig> = {
],
scopeHints: ['admin.directory.group'],
},
'vertex-ai': {
id: 'vertex-ai',
name: 'Vertex AI',
description: 'Access Google Cloud Vertex AI for Gemini models with OAuth.',
providerId: 'vertex-ai',
icon: (props) => VertexIcon(props),
baseProviderIcon: (props) => VertexIcon(props),
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
scopeHints: ['cloud-platform', 'vertex', 'aiplatform'],
},
},
defaultService: 'gmail',
},
@@ -1099,6 +1111,12 @@ export function parseProvider(provider: OAuthProvider): ProviderConfig {
featureType: 'microsoft-planner',
}
}
if (provider === 'vertex-ai') {
return {
baseProvider: 'google',
featureType: 'vertex-ai',
}
}
// Handle compound providers (e.g., 'google-email' -> { baseProvider: 'google', featureType: 'email' })
const [base, feature] = provider.split('-')
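
Because 'vertex-ai' is special-cased before the generic split, it resolves to the Google base provider instead of a nonexistent 'vertex' base with feature 'ai'. A hedged usage sketch (the import path is illustrative):

import { parseProvider } from '@/lib/oauth/oauth'

parseProvider('vertex-ai')    // { baseProvider: 'google', featureType: 'vertex-ai' } (handled by the new branch)
parseProvider('google-email') // { baseProvider: 'google', featureType: 'email' } (falls through to the generic split)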

View File

@@ -58,7 +58,7 @@ export const anthropicProvider: ProviderConfig = {
throw new Error('API key is required for Anthropic')
}
const modelId = request.model || 'claude-3-7-sonnet-20250219'
const modelId = request.model
const useNativeStructuredOutputs = !!(
request.responseFormat && supportsNativeStructuredOutputs(modelId)
)
@@ -174,7 +174,7 @@ export const anthropicProvider: ProviderConfig = {
}
const payload: any = {
model: request.model || 'claude-3-7-sonnet-20250219',
model: request.model,
messages,
system: systemPrompt,
max_tokens: Number.parseInt(String(request.maxTokens)) || 1024,
@@ -561,37 +561,93 @@ export const anthropicProvider: ProviderConfig = {
throw error
}
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
return {
content,
model: request.model || 'claude-3-7-sonnet-20250219',
tokens,
toolCalls:
toolCalls.length > 0
? toolCalls.map((tc) => ({
name: tc.name,
arguments: tc.arguments as Record<string, any>,
startTime: tc.startTime,
endTime: tc.endTime,
duration: tc.duration,
result: tc.result,
}))
: undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
const streamingPayload = {
...payload,
messages: currentMessages,
stream: true,
tool_choice: undefined,
}
const streamResponse: any = await anthropic.messages.create(streamingPayload)
const streamingResult = {
stream: createReadableStreamFromAnthropicStream(
streamResponse,
(streamContent, usage) => {
streamingResult.execution.output.content = streamContent
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.input_tokens,
completion: tokens.completion + usage.output_tokens,
total: tokens.total + usage.input_tokens + usage.output_tokens,
}
const streamCost = calculateCost(
request.model,
usage.input_tokens,
usage.output_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
}
}
),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,
total: tokens.total,
},
toolCalls:
toolCalls.length > 0
? {
list: toolCalls,
count: toolCalls.length,
}
: undefined,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
cost: {
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
isStreaming: true,
},
}
return streamingResult as StreamingExecution
} catch (error) {
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
@@ -934,7 +990,7 @@ export const anthropicProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'claude-3-7-sonnet-20250219',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,
@@ -978,7 +1034,7 @@ export const anthropicProvider: ProviderConfig = {
return {
content,
model: request.model || 'claude-3-7-sonnet-20250219',
model: request.model,
tokens,
toolCalls:
toolCalls.length > 0
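
Removing the `|| 'claude-3-7-sonnet-20250219'` fallbacks means an unset model is no longer silently swapped for a default; the placeholder sketch below only illustrates that behavioral difference.

const request = { model: '' } // placeholder request

const before = request.model || 'claude-3-7-sonnet-20250219' // old: silently used the fallback model
const after = request.model                                  // new: the empty model reaches the API and fails loudly
console.log(before, after === '') // 'claude-3-7-sonnet-20250219' true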

View File

@@ -39,7 +39,7 @@ export const azureOpenAIProvider: ProviderConfig = {
request: ProviderRequest
): Promise<ProviderResponse | StreamingExecution> => {
logger.info('Preparing Azure OpenAI request', {
model: request.model || 'azure/gpt-4o',
model: request.model,
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
@@ -95,7 +95,7 @@ export const azureOpenAIProvider: ProviderConfig = {
}))
: undefined
const deploymentName = (request.model || 'azure/gpt-4o').replace('azure/', '')
const deploymentName = request.model.replace('azure/', '')
const payload: any = {
model: deploymentName,
messages: allMessages,

View File

@@ -73,7 +73,7 @@ export const cerebrasProvider: ProviderConfig = {
: undefined
const payload: any = {
model: (request.model || 'cerebras/llama-3.3-70b').replace('cerebras/', ''),
model: request.model.replace('cerebras/', ''),
messages: allMessages,
}
if (request.temperature !== undefined) payload.temperature = request.temperature
@@ -145,7 +145,7 @@ export const cerebrasProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'cerebras/llama-3.3-70b',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
@@ -470,7 +470,7 @@ export const cerebrasProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'cerebras/llama-3.3-70b',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,

View File

@@ -105,7 +105,7 @@ export const deepseekProvider: ProviderConfig = {
: toolChoice.type === 'any'
? `force:${toolChoice.any?.name || 'unknown'}`
: 'unknown',
model: request.model || 'deepseek-v3',
model: request.model,
})
}
}
@@ -145,7 +145,7 @@ export const deepseekProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'deepseek-chat',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
@@ -469,7 +469,7 @@ export const deepseekProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'deepseek-chat',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,

View File

@@ -0,0 +1,58 @@
import { GoogleGenAI } from '@google/genai'
import { createLogger } from '@/lib/logs/console/logger'
import type { GeminiClientConfig } from './types'
const logger = createLogger('GeminiClient')
/**
* Creates a GoogleGenAI client configured for either Google Gemini API or Vertex AI
*
* For Google Gemini API:
* - Uses API key authentication
*
* For Vertex AI:
* - Uses OAuth access token via HTTP Authorization header
* - Requires project and location
*/
export function createGeminiClient(config: GeminiClientConfig): GoogleGenAI {
if (config.vertexai) {
if (!config.project) {
throw new Error('Vertex AI requires a project ID')
}
if (!config.accessToken) {
throw new Error('Vertex AI requires an access token')
}
const location = config.location ?? 'us-central1'
logger.info('Creating Vertex AI client', {
project: config.project,
location,
hasAccessToken: !!config.accessToken,
})
// Create client with Vertex AI configuration
// Use httpOptions.headers to pass the access token directly
return new GoogleGenAI({
vertexai: true,
project: config.project,
location,
httpOptions: {
headers: {
Authorization: `Bearer ${config.accessToken}`,
},
},
})
}
// Google Gemini API with API key
if (!config.apiKey) {
throw new Error('Google Gemini API requires an API key')
}
logger.info('Creating Google Gemini client')
return new GoogleGenAI({
apiKey: config.apiKey,
})
}
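
A hedged usage sketch covering both modes; the import path, API key, project, and token are placeholders, and the location falls back to 'us-central1' when omitted.

import { createGeminiClient } from '@/providers/gemini'

// Gemini API: key-based auth.
const geminiClient = createGeminiClient({ apiKey: 'AIza-example-api-key' })

// Vertex AI: the OAuth token is sent via the Authorization header; project is required.
const vertexClient = createGeminiClient({
  vertexai: true,
  project: 'my-gcp-project',
  location: 'europe-west1',
  accessToken: 'ya29.example-oauth-token',
})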

View File

@@ -0,0 +1,680 @@
import {
type Content,
FunctionCallingConfigMode,
type FunctionDeclaration,
type GenerateContentConfig,
type GenerateContentResponse,
type GoogleGenAI,
type Part,
type Schema,
type ThinkingConfig,
type ToolConfig,
} from '@google/genai'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import {
checkForForcedToolUsage,
cleanSchemaForGemini,
convertToGeminiFormat,
convertUsageMetadata,
createReadableStreamFromGeminiStream,
extractFunctionCallPart,
extractTextContent,
mapToThinkingLevel,
} from '@/providers/google/utils'
import { getThinkingCapability } from '@/providers/models'
import type { FunctionCallResponse, ProviderRequest, ProviderResponse } from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
} from '@/providers/utils'
import { executeTool } from '@/tools'
import type { ExecutionState, GeminiProviderType, GeminiUsage, ParsedFunctionCall } from './types'
/**
* Creates initial execution state
*/
function createInitialState(
contents: Content[],
initialUsage: GeminiUsage,
firstResponseTime: number,
initialCallTime: number,
model: string,
toolConfig: ToolConfig | undefined
): ExecutionState {
const initialCost = calculateCost(
model,
initialUsage.promptTokenCount,
initialUsage.candidatesTokenCount
)
return {
contents,
tokens: {
prompt: initialUsage.promptTokenCount,
completion: initialUsage.candidatesTokenCount,
total: initialUsage.totalTokenCount,
},
cost: initialCost,
toolCalls: [],
toolResults: [],
iterationCount: 0,
modelTime: firstResponseTime,
toolsTime: 0,
timeSegments: [
{
type: 'model',
name: 'Initial response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
],
usedForcedTools: [],
currentToolConfig: toolConfig,
}
}
/**
* Executes a tool call and updates state
*/
async function executeToolCall(
functionCallPart: Part,
functionCall: ParsedFunctionCall,
request: ProviderRequest,
state: ExecutionState,
forcedTools: string[],
logger: ReturnType<typeof createLogger>
): Promise<{ success: boolean; state: ExecutionState }> {
const toolCallStartTime = Date.now()
const toolName = functionCall.name
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) {
logger.warn(`Tool ${toolName} not found in registry, skipping`)
return { success: false, state }
}
try {
const { toolParams, executionParams } = prepareToolExecution(tool, functionCall.args, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const duration = toolCallEndTime - toolCallStartTime
const resultContent: Record<string, unknown> = result.success
? (result.output as Record<string, unknown>)
: { error: true, message: result.error || 'Tool execution failed', tool: toolName }
const toolCall: FunctionCallResponse = {
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration,
result: resultContent,
}
const updatedContents: Content[] = [
...state.contents,
{
role: 'model',
parts: [functionCallPart],
},
{
role: 'user',
parts: [
{
functionResponse: {
name: functionCall.name,
response: resultContent,
},
},
],
},
]
const forcedToolCheck = checkForForcedToolUsage(
[{ name: functionCall.name, args: functionCall.args }],
state.currentToolConfig,
forcedTools,
state.usedForcedTools
)
return {
success: true,
state: {
...state,
contents: updatedContents,
toolCalls: [...state.toolCalls, toolCall],
toolResults: result.success
? [...state.toolResults, result.output as Record<string, unknown>]
: state.toolResults,
toolsTime: state.toolsTime + duration,
timeSegments: [
...state.timeSegments,
{
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration,
},
],
usedForcedTools: forcedToolCheck?.usedForcedTools ?? state.usedForcedTools,
currentToolConfig: forcedToolCheck?.nextToolConfig ?? state.currentToolConfig,
},
}
} catch (error) {
logger.error('Error processing function call:', {
error: error instanceof Error ? error.message : String(error),
functionName: toolName,
})
return { success: false, state }
}
}
/**
* Updates state with model response metadata
*/
function updateStateWithResponse(
state: ExecutionState,
response: GenerateContentResponse,
model: string,
startTime: number,
endTime: number
): ExecutionState {
const usage = convertUsageMetadata(response.usageMetadata)
const cost = calculateCost(model, usage.promptTokenCount, usage.candidatesTokenCount)
const duration = endTime - startTime
return {
...state,
tokens: {
prompt: state.tokens.prompt + usage.promptTokenCount,
completion: state.tokens.completion + usage.candidatesTokenCount,
total: state.tokens.total + usage.totalTokenCount,
},
cost: {
input: state.cost.input + cost.input,
output: state.cost.output + cost.output,
total: state.cost.total + cost.total,
pricing: cost.pricing, // Use latest pricing
},
modelTime: state.modelTime + duration,
timeSegments: [
...state.timeSegments,
{
type: 'model',
name: `Model response (iteration ${state.iterationCount + 1})`,
startTime,
endTime,
duration,
},
],
iterationCount: state.iterationCount + 1,
}
}
/**
* Builds config for next iteration
*/
function buildNextConfig(
baseConfig: GenerateContentConfig,
state: ExecutionState,
forcedTools: string[],
request: ProviderRequest,
logger: ReturnType<typeof createLogger>
): GenerateContentConfig {
const nextConfig = { ...baseConfig }
const allForcedToolsUsed =
forcedTools.length > 0 && state.usedForcedTools.length === forcedTools.length
if (allForcedToolsUsed && request.responseFormat) {
nextConfig.tools = undefined
nextConfig.toolConfig = undefined
nextConfig.responseMimeType = 'application/json'
nextConfig.responseSchema = cleanSchemaForGemini(request.responseFormat.schema) as Schema
logger.info('Using structured output for final response after tool execution')
} else if (state.currentToolConfig) {
nextConfig.toolConfig = state.currentToolConfig
} else {
nextConfig.toolConfig = { functionCallingConfig: { mode: FunctionCallingConfigMode.AUTO } }
}
return nextConfig
}
/**
* Creates streaming execution result template
*/
function createStreamingResult(
providerStartTime: number,
providerStartTimeISO: string,
firstResponseTime: number,
initialCallTime: number,
state?: ExecutionState
): StreamingExecution {
return {
stream: undefined as unknown as ReadableStream<Uint8Array>,
execution: {
success: true,
output: {
content: '',
model: '',
tokens: state?.tokens ?? { prompt: 0, completion: 0, total: 0 },
toolCalls: state?.toolCalls.length
? { list: state.toolCalls, count: state.toolCalls.length }
: undefined,
toolResults: state?.toolResults,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
modelTime: state?.modelTime ?? firstResponseTime,
toolsTime: state?.toolsTime ?? 0,
firstResponseTime,
iterations: (state?.iterationCount ?? 0) + 1,
timeSegments: state?.timeSegments ?? [
{
type: 'model',
name: 'Initial streaming response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
],
},
cost: state?.cost ?? {
input: 0,
output: 0,
total: 0,
pricing: { input: 0, output: 0, updatedAt: new Date().toISOString() },
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
isStreaming: true,
},
}
}
/**
* Configuration for executing a Gemini request
*/
export interface GeminiExecutionConfig {
ai: GoogleGenAI
model: string
request: ProviderRequest
providerType: GeminiProviderType
}
/**
* Executes a request using the Gemini API
*
* This is the shared core logic for both Google and Vertex AI providers.
* The only difference is how the GoogleGenAI client is configured.
*/
export async function executeGeminiRequest(
config: GeminiExecutionConfig
): Promise<ProviderResponse | StreamingExecution> {
const { ai, model, request, providerType } = config
const logger = createLogger(providerType === 'google' ? 'GoogleProvider' : 'VertexProvider')
logger.info(`Preparing ${providerType} Gemini request`, {
model,
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
toolCount: request.tools?.length ?? 0,
hasResponseFormat: !!request.responseFormat,
streaming: !!request.stream,
})
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
// Build configuration
const geminiConfig: GenerateContentConfig = {}
if (request.temperature !== undefined) {
geminiConfig.temperature = request.temperature
}
if (request.maxTokens !== undefined) {
geminiConfig.maxOutputTokens = request.maxTokens
}
if (systemInstruction) {
geminiConfig.systemInstruction = systemInstruction
}
// Handle response format (only when no tools)
if (request.responseFormat && !tools?.length) {
geminiConfig.responseMimeType = 'application/json'
geminiConfig.responseSchema = cleanSchemaForGemini(request.responseFormat.schema) as Schema
logger.info('Using Gemini native structured output format')
} else if (request.responseFormat && tools?.length) {
logger.warn('Gemini does not support responseFormat with tools. Structured output ignored.')
}
// Configure thinking for models that support it
const thinkingCapability = getThinkingCapability(model)
if (thinkingCapability) {
const level = request.thinkingLevel ?? thinkingCapability.default ?? 'high'
const thinkingConfig: ThinkingConfig = {
includeThoughts: false,
thinkingLevel: mapToThinkingLevel(level),
}
geminiConfig.thinkingConfig = thinkingConfig
}
// Prepare tools
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
let toolConfig: ToolConfig | undefined
if (tools?.length) {
const functionDeclarations: FunctionDeclaration[] = tools.map((t) => ({
name: t.name,
description: t.description,
parameters: t.parameters,
}))
preparedTools = prepareToolsWithUsageControl(
functionDeclarations,
request.tools,
logger,
'google'
)
const { tools: filteredTools, toolConfig: preparedToolConfig } = preparedTools
if (filteredTools?.length) {
geminiConfig.tools = [{ functionDeclarations: filteredTools as FunctionDeclaration[] }]
if (preparedToolConfig) {
toolConfig = {
functionCallingConfig: {
mode:
{
AUTO: FunctionCallingConfigMode.AUTO,
ANY: FunctionCallingConfigMode.ANY,
NONE: FunctionCallingConfigMode.NONE,
}[preparedToolConfig.functionCallingConfig.mode] ?? FunctionCallingConfigMode.AUTO,
allowedFunctionNames: preparedToolConfig.functionCallingConfig.allowedFunctionNames,
},
}
geminiConfig.toolConfig = toolConfig
}
logger.info('Gemini request with tools:', {
toolCount: filteredTools.length,
model,
tools: filteredTools.map((t) => (t as FunctionDeclaration).name),
})
}
}
const initialCallTime = Date.now()
const shouldStream = request.stream && !tools?.length
// Streaming without tools
if (shouldStream) {
logger.info('Handling Gemini streaming response')
const streamGenerator = await ai.models.generateContentStream({
model,
contents,
config: geminiConfig,
})
const firstResponseTime = Date.now() - initialCallTime
const streamingResult = createStreamingResult(
providerStartTime,
providerStartTimeISO,
firstResponseTime,
initialCallTime
)
streamingResult.execution.output.model = model
streamingResult.stream = createReadableStreamFromGeminiStream(
streamGenerator,
(content: string, usage: GeminiUsage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.promptTokenCount,
completion: usage.candidatesTokenCount,
total: usage.totalTokenCount,
}
const costResult = calculateCost(
model,
usage.promptTokenCount,
usage.candidatesTokenCount
)
streamingResult.execution.output.cost = costResult
const streamEndTime = Date.now()
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = new Date(
streamEndTime
).toISOString()
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
const segments = streamingResult.execution.output.providerTiming.timeSegments
if (segments?.[0]) {
segments[0].endTime = streamEndTime
segments[0].duration = streamEndTime - providerStartTime
}
}
}
)
return streamingResult
}
// Non-streaming request
const response = await ai.models.generateContent({ model, contents, config: geminiConfig })
const firstResponseTime = Date.now() - initialCallTime
// Check for UNEXPECTED_TOOL_CALL
const candidate = response.candidates?.[0]
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn('Gemini returned UNEXPECTED_TOOL_CALL - model attempted to call unknown tool')
}
const initialUsage = convertUsageMetadata(response.usageMetadata)
let state = createInitialState(
contents,
initialUsage,
firstResponseTime,
initialCallTime,
model,
toolConfig
)
const forcedTools = preparedTools?.forcedTools ?? []
let currentResponse = response
let content = ''
// Tool execution loop
const functionCalls = response.functionCalls
if (functionCalls?.length) {
logger.info(`Received function call from Gemini: ${functionCalls[0].name}`)
while (state.iterationCount < MAX_TOOL_ITERATIONS) {
const functionCallPart = extractFunctionCallPart(currentResponse.candidates?.[0])
if (!functionCallPart?.functionCall) {
content = extractTextContent(currentResponse.candidates?.[0])
break
}
const functionCall: ParsedFunctionCall = {
name: functionCallPart.functionCall.name ?? '',
args: (functionCallPart.functionCall.args ?? {}) as Record<string, unknown>,
}
logger.info(
`Processing function call: ${functionCall.name} (iteration ${state.iterationCount + 1})`
)
const { success, state: updatedState } = await executeToolCall(
functionCallPart,
functionCall,
request,
state,
forcedTools,
logger
)
if (!success) {
content = extractTextContent(currentResponse.candidates?.[0])
break
}
state = { ...updatedState, iterationCount: updatedState.iterationCount + 1 }
const nextConfig = buildNextConfig(geminiConfig, state, forcedTools, request, logger)
// Stream final response if requested
if (request.stream) {
const checkResponse = await ai.models.generateContent({
model,
contents: state.contents,
config: nextConfig,
})
state = updateStateWithResponse(state, checkResponse, model, Date.now() - 100, Date.now())
if (checkResponse.functionCalls?.length) {
currentResponse = checkResponse
continue
}
logger.info('No more function calls, streaming final response')
if (request.responseFormat) {
nextConfig.tools = undefined
nextConfig.toolConfig = undefined
nextConfig.responseMimeType = 'application/json'
nextConfig.responseSchema = cleanSchemaForGemini(
request.responseFormat.schema
) as Schema
}
// Capture accumulated cost before streaming
const accumulatedCost = {
input: state.cost.input,
output: state.cost.output,
total: state.cost.total,
}
const accumulatedTokens = { ...state.tokens }
const streamGenerator = await ai.models.generateContentStream({
model,
contents: state.contents,
config: nextConfig,
})
const streamingResult = createStreamingResult(
providerStartTime,
providerStartTimeISO,
firstResponseTime,
initialCallTime,
state
)
streamingResult.execution.output.model = model
streamingResult.stream = createReadableStreamFromGeminiStream(
streamGenerator,
(streamContent: string, usage: GeminiUsage) => {
streamingResult.execution.output.content = streamContent
streamingResult.execution.output.tokens = {
prompt: accumulatedTokens.prompt + usage.promptTokenCount,
completion: accumulatedTokens.completion + usage.candidatesTokenCount,
total: accumulatedTokens.total + usage.totalTokenCount,
}
const streamCost = calculateCost(
model,
usage.promptTokenCount,
usage.candidatesTokenCount
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
pricing: streamCost.pricing,
}
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = new Date().toISOString()
streamingResult.execution.output.providerTiming.duration =
Date.now() - providerStartTime
}
}
)
return streamingResult
}
// Non-streaming: get next response
const nextModelStartTime = Date.now()
const nextResponse = await ai.models.generateContent({
model,
contents: state.contents,
config: nextConfig,
})
state = updateStateWithResponse(state, nextResponse, model, nextModelStartTime, Date.now())
currentResponse = nextResponse
}
if (!content) {
content = extractTextContent(currentResponse.candidates?.[0])
}
} else {
content = extractTextContent(candidate)
}
const providerEndTime = Date.now()
return {
content,
model,
tokens: state.tokens,
toolCalls: state.toolCalls.length ? state.toolCalls : undefined,
toolResults: state.toolResults.length ? state.toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: new Date(providerEndTime).toISOString(),
duration: providerEndTime - providerStartTime,
modelTime: state.modelTime,
toolsTime: state.toolsTime,
firstResponseTime,
iterations: state.iterationCount + 1,
timeSegments: state.timeSegments,
},
cost: state.cost,
}
} catch (error) {
const providerEndTime = Date.now()
const duration = providerEndTime - providerStartTime
logger.error('Error in Gemini request:', {
error: error instanceof Error ? error.message : String(error),
stack: error instanceof Error ? error.stack : undefined,
})
const enhancedError = error instanceof Error ? error : new Error(String(error))
Object.assign(enhancedError, {
timing: {
startTime: providerStartTimeISO,
endTime: new Date(providerEndTime).toISOString(),
duration,
},
})
throw enhancedError
}
}
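
A minimal end-to-end sketch wiring the client factory into the shared core; the model ID and request fields are illustrative, and the request object covers only a subset of ProviderRequest (hence the cast).

import { createGeminiClient, executeGeminiRequest } from '@/providers/gemini'

async function run() {
  const ai = createGeminiClient({ apiKey: 'AIza-example-api-key' })
  const result = await executeGeminiRequest({
    ai,
    model: 'gemini-2.5-flash', // illustrative model ID
    providerType: 'google',
    request: {
      model: 'gemini-2.5-flash',
      systemPrompt: 'You are a terse assistant.',
      messages: [{ role: 'user', content: 'Summarize the release notes in one line.' }],
      stream: false,
    } as any, // ProviderRequest carries many more optional fields
  })
  if ('content' in result) {
    console.log(result.content, result.tokens.total) // non-streaming ProviderResponse
  } else {
    console.log('received a StreamingExecution') // stream plus execution scaffold
  }
}

run()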

View File

@@ -0,0 +1,18 @@
/**
* Shared Gemini execution core
*
* This module provides the shared execution logic for both Google Gemini API
* and Vertex AI providers. The only difference between providers is how the
* GoogleGenAI client is configured (API key vs OAuth).
*/
export { createGeminiClient } from './client'
export { executeGeminiRequest, type GeminiExecutionConfig } from './core'
export type {
ExecutionState,
ForcedToolResult,
GeminiClientConfig,
GeminiProviderType,
GeminiUsage,
ParsedFunctionCall,
} from './types'

View File

@@ -0,0 +1,64 @@
import type { Content, ToolConfig } from '@google/genai'
import type { FunctionCallResponse, ModelPricing, TimeSegment } from '@/providers/types'
/**
* Usage metadata from Gemini responses
*/
export interface GeminiUsage {
promptTokenCount: number
candidatesTokenCount: number
totalTokenCount: number
}
/**
* Parsed function call from Gemini response
*/
export interface ParsedFunctionCall {
name: string
args: Record<string, unknown>
}
/**
* Accumulated state during tool execution loop
*/
export interface ExecutionState {
contents: Content[]
tokens: { prompt: number; completion: number; total: number }
cost: { input: number; output: number; total: number; pricing: ModelPricing }
toolCalls: FunctionCallResponse[]
toolResults: Record<string, unknown>[]
iterationCount: number
modelTime: number
toolsTime: number
timeSegments: TimeSegment[]
usedForcedTools: string[]
currentToolConfig: ToolConfig | undefined
}
/**
* Result from forced tool usage check
*/
export interface ForcedToolResult {
hasUsedForcedTool: boolean
usedForcedTools: string[]
nextToolConfig: ToolConfig | undefined
}
/**
* Configuration for creating a Gemini client
*/
export interface GeminiClientConfig {
/** For Google Gemini API */
apiKey?: string
/** For Vertex AI */
vertexai?: boolean
project?: string
location?: string
/** OAuth access token for Vertex AI */
accessToken?: string
}
/**
* Provider type for logging and model lookup
*/
export type GeminiProviderType = 'google' | 'vertex'

File diff suppressed because it is too large

View File

@@ -1,61 +1,89 @@
import type { Candidate } from '@google/genai'
import {
type Candidate,
type Content,
type FunctionCall,
FunctionCallingConfigMode,
type GenerateContentResponse,
type GenerateContentResponseUsageMetadata,
type Part,
type Schema,
type SchemaUnion,
ThinkingLevel,
type ToolConfig,
Type,
} from '@google/genai'
import { createLogger } from '@/lib/logs/console/logger'
import type { ProviderRequest } from '@/providers/types'
import { trackForcedToolUsage } from '@/providers/utils'
const logger = createLogger('GoogleUtils')
/**
* Usage metadata for Google Gemini responses
*/
export interface GeminiUsage {
promptTokenCount: number
candidatesTokenCount: number
totalTokenCount: number
}
/**
* Parsed function call from Gemini response
*/
export interface ParsedFunctionCall {
name: string
args: Record<string, unknown>
}
/**
* Removes additionalProperties from a schema object (not supported by Gemini)
*/
export function cleanSchemaForGemini(schema: any): any {
export function cleanSchemaForGemini(schema: SchemaUnion): SchemaUnion {
if (schema === null || schema === undefined) return schema
if (typeof schema !== 'object') return schema
if (Array.isArray(schema)) {
return schema.map((item) => cleanSchemaForGemini(item))
}
const cleanedSchema: any = {}
const cleanedSchema: Record<string, unknown> = {}
const schemaObj = schema as Record<string, unknown>
for (const key in schema) {
for (const key in schemaObj) {
if (key === 'additionalProperties') continue
cleanedSchema[key] = cleanSchemaForGemini(schema[key])
cleanedSchema[key] = cleanSchemaForGemini(schemaObj[key] as SchemaUnion)
}
return cleanedSchema
}
/**
* Extracts text content from a Gemini response candidate, handling structured output
* Extracts text content from a Gemini response candidate.
* Filters out thought parts (model reasoning) from the output.
*/
export function extractTextContent(candidate: Candidate | undefined): string {
if (!candidate?.content?.parts) return ''
if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) {
const text = candidate.content.parts[0].text
if (text && (text.trim().startsWith('{') || text.trim().startsWith('['))) {
try {
JSON.parse(text)
return text
} catch (_e) {}
}
}
const textParts = candidate.content.parts.filter(
(part): part is Part & { text: string } => Boolean(part.text) && part.thought !== true
)
return candidate.content.parts
.filter((part: any) => part.text)
.map((part: any) => part.text)
.join('\n')
if (textParts.length === 0) return ''
if (textParts.length === 1) return textParts[0].text
return textParts.map((part) => part.text).join('\n')
}
/**
* Extracts a function call from a Gemini response candidate
* Extracts the first function call from a Gemini response candidate
*/
export function extractFunctionCall(
candidate: Candidate | undefined
): { name: string; args: any } | null {
export function extractFunctionCall(candidate: Candidate | undefined): ParsedFunctionCall | null {
if (!candidate?.content?.parts) return null
for (const part of candidate.content.parts) {
if (part.functionCall) {
return {
name: part.functionCall.name ?? '',
args: part.functionCall.args ?? {},
args: (part.functionCall.args ?? {}) as Record<string, unknown>,
}
}
}
@@ -63,16 +91,55 @@ export function extractFunctionCall(
return null
}
/**
* Extracts the full Part containing the function call (preserves thoughtSignature)
*/
export function extractFunctionCallPart(candidate: Candidate | undefined): Part | null {
if (!candidate?.content?.parts) return null
for (const part of candidate.content.parts) {
if (part.functionCall) {
return part
}
}
return null
}
/**
* Converts usage metadata from SDK response to our format
*/
export function convertUsageMetadata(
usageMetadata: GenerateContentResponseUsageMetadata | undefined
): GeminiUsage {
const promptTokenCount = usageMetadata?.promptTokenCount ?? 0
const candidatesTokenCount = usageMetadata?.candidatesTokenCount ?? 0
return {
promptTokenCount,
candidatesTokenCount,
totalTokenCount: usageMetadata?.totalTokenCount ?? promptTokenCount + candidatesTokenCount,
}
}
/**
* Tool definition for Gemini format
*/
export interface GeminiToolDef {
name: string
description: string
parameters: Schema
}
/**
* Converts OpenAI-style request format to Gemini format
*/
export function convertToGeminiFormat(request: ProviderRequest): {
contents: any[]
tools: any[] | undefined
systemInstruction: any | undefined
contents: Content[]
tools: GeminiToolDef[] | undefined
systemInstruction: Content | undefined
} {
const contents: any[] = []
let systemInstruction
const contents: Content[] = []
let systemInstruction: Content | undefined
if (request.systemPrompt) {
systemInstruction = { parts: [{ text: request.systemPrompt }] }
@@ -82,13 +149,13 @@ export function convertToGeminiFormat(request: ProviderRequest): {
contents.push({ role: 'user', parts: [{ text: request.context }] })
}
if (request.messages && request.messages.length > 0) {
if (request.messages?.length) {
for (const message of request.messages) {
if (message.role === 'system') {
if (!systemInstruction) {
systemInstruction = { parts: [{ text: message.content }] }
} else {
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text || ''}\n${message.content}`
systemInstruction = { parts: [{ text: message.content ?? '' }] }
} else if (systemInstruction.parts?.[0] && 'text' in systemInstruction.parts[0]) {
systemInstruction.parts[0].text = `${systemInstruction.parts[0].text}\n${message.content}`
}
} else if (message.role === 'user' || message.role === 'assistant') {
const geminiRole = message.role === 'user' ? 'user' : 'model'
@@ -97,60 +164,200 @@ export function convertToGeminiFormat(request: ProviderRequest): {
contents.push({ role: geminiRole, parts: [{ text: message.content }] })
}
if (message.role === 'assistant' && message.tool_calls && message.tool_calls.length > 0) {
if (message.role === 'assistant' && message.tool_calls?.length) {
const functionCalls = message.tool_calls.map((toolCall) => ({
functionCall: {
name: toolCall.function?.name,
args: JSON.parse(toolCall.function?.arguments || '{}'),
args: JSON.parse(toolCall.function?.arguments || '{}') as Record<string, unknown>,
},
}))
contents.push({ role: 'model', parts: functionCalls })
}
} else if (message.role === 'tool') {
if (!message.name) {
logger.warn('Tool message missing function name, skipping')
continue
}
let responseData: Record<string, unknown>
try {
responseData = JSON.parse(message.content ?? '{}')
} catch {
responseData = { output: message.content }
}
contents.push({
role: 'user',
parts: [{ text: `Function result: ${message.content}` }],
parts: [
{
functionResponse: {
id: message.tool_call_id,
name: message.name,
response: responseData,
},
},
],
})
}
}
}
const tools = request.tools?.map((tool) => {
const tools = request.tools?.map((tool): GeminiToolDef => {
const toolParameters = { ...(tool.parameters || {}) }
if (toolParameters.properties) {
const properties = { ...toolParameters.properties }
const required = toolParameters.required ? [...toolParameters.required] : []
// Remove default values from properties (not supported by Gemini)
for (const key in properties) {
const prop = properties[key] as any
const prop = properties[key] as Record<string, unknown>
if (prop.default !== undefined) {
const { default: _, ...cleanProp } = prop
properties[key] = cleanProp
}
}
const parameters = {
type: toolParameters.type || 'object',
properties,
const parameters: Schema = {
type: (toolParameters.type as Schema['type']) || Type.OBJECT,
properties: properties as Record<string, Schema>,
...(required.length > 0 ? { required } : {}),
}
return {
name: tool.id,
description: tool.description || `Execute the ${tool.id} function`,
parameters: cleanSchemaForGemini(parameters),
parameters: cleanSchemaForGemini(parameters) as Schema,
}
}
return {
name: tool.id,
description: tool.description || `Execute the ${tool.id} function`,
parameters: cleanSchemaForGemini(toolParameters),
parameters: cleanSchemaForGemini(toolParameters) as Schema,
}
})
return { contents, tools, systemInstruction }
}
/**
* Creates a ReadableStream from a Google Gemini streaming response
*/
export function createReadableStreamFromGeminiStream(
stream: AsyncGenerator<GenerateContentResponse>,
onComplete?: (content: string, usage: GeminiUsage) => void
): ReadableStream<Uint8Array> {
let fullContent = ''
let usage: GeminiUsage = { promptTokenCount: 0, candidatesTokenCount: 0, totalTokenCount: 0 }
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of stream) {
if (chunk.usageMetadata) {
usage = convertUsageMetadata(chunk.usageMetadata)
}
const text = chunk.text
if (text) {
fullContent += text
controller.enqueue(new TextEncoder().encode(text))
}
}
onComplete?.(fullContent, usage)
controller.close()
} catch (error) {
logger.error('Error reading Google Gemini stream', {
error: error instanceof Error ? error.message : String(error),
})
controller.error(error)
}
},
})
}
/**
* Maps string mode to FunctionCallingConfigMode enum
*/
function mapToFunctionCallingMode(mode: string): FunctionCallingConfigMode {
switch (mode) {
case 'AUTO':
return FunctionCallingConfigMode.AUTO
case 'ANY':
return FunctionCallingConfigMode.ANY
case 'NONE':
return FunctionCallingConfigMode.NONE
default:
return FunctionCallingConfigMode.AUTO
}
}
/**
* Maps string level to ThinkingLevel enum
*/
export function mapToThinkingLevel(level: string): ThinkingLevel {
switch (level.toLowerCase()) {
case 'minimal':
return ThinkingLevel.MINIMAL
case 'low':
return ThinkingLevel.LOW
case 'medium':
return ThinkingLevel.MEDIUM
case 'high':
return ThinkingLevel.HIGH
default:
return ThinkingLevel.HIGH
}
}
/**
* Result of checking forced tool usage
*/
export interface ForcedToolResult {
hasUsedForcedTool: boolean
usedForcedTools: string[]
nextToolConfig: ToolConfig | undefined
}
/**
* Checks for forced tool usage in Google Gemini responses
*/
export function checkForForcedToolUsage(
functionCalls: FunctionCall[] | undefined,
toolConfig: ToolConfig | undefined,
forcedTools: string[],
usedForcedTools: string[]
): ForcedToolResult | null {
if (!functionCalls?.length) return null
const adaptedToolCalls = functionCalls.map((fc) => ({
name: fc.name ?? '',
arguments: (fc.args ?? {}) as Record<string, unknown>,
}))
const result = trackForcedToolUsage(
adaptedToolCalls,
toolConfig,
logger,
'google',
forcedTools,
usedForcedTools
)
if (!result) return null
const nextToolConfig: ToolConfig | undefined = result.nextToolConfig?.functionCallingConfig?.mode
? {
functionCallingConfig: {
mode: mapToFunctionCallingMode(result.nextToolConfig.functionCallingConfig.mode),
allowedFunctionNames: result.nextToolConfig.functionCallingConfig.allowedFunctionNames,
},
}
: undefined
return {
hasUsedForcedTool: result.hasUsedForcedTool,
usedForcedTools: result.usedForcedTools,
nextToolConfig,
}
}
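
Two small self-contained illustrations of the helpers above: cleanSchemaForGemini strips additionalProperties recursively, and convertUsageMetadata backfills the total when the SDK omits it. The inputs are made up.

import { cleanSchemaForGemini, convertUsageMetadata } from '@/providers/google/utils'

const cleaned = cleanSchemaForGemini({
  type: 'object',
  additionalProperties: false, // dropped: Gemini schemas reject this keyword
  properties: {
    name: { type: 'string', additionalProperties: false }, // dropped recursively too
  },
} as any)
console.log(JSON.stringify(cleaned))
// {"type":"object","properties":{"name":{"type":"string"}}}

console.log(convertUsageMetadata({ promptTokenCount: 10, candidatesTokenCount: 5 } as any))
// { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 }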

View File

@@ -69,10 +69,7 @@ export const groqProvider: ProviderConfig = {
: undefined
const payload: any = {
model: (request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct').replace(
'groq/',
''
),
model: request.model.replace('groq/', ''),
messages: allMessages,
}
@@ -109,7 +106,7 @@ export const groqProvider: ProviderConfig = {
toolChoice: payload.tool_choice,
forcedToolsCount: forcedTools.length,
hasFilteredTools,
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
model: request.model,
})
}
}
@@ -149,7 +146,7 @@ export const groqProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
@@ -393,7 +390,7 @@ export const groqProvider: ProviderConfig = {
const streamingPayload = {
...payload,
messages: currentMessages,
tool_choice: 'auto',
tool_choice: originalToolChoice || 'auto',
stream: true,
}
@@ -425,7 +422,7 @@ export const groqProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,

View File

@@ -1,11 +1,11 @@
import { getCostMultiplier } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import type { ProviderRequest, ProviderResponse } from '@/providers/types'
import { getProviderExecutor } from '@/providers/registry'
import type { ProviderId, ProviderRequest, ProviderResponse } from '@/providers/types'
import {
calculateCost,
generateStructuredOutputInstructions,
getProvider,
shouldBillModelUsage,
supportsTemperature,
} from '@/providers/utils'
@@ -40,7 +40,7 @@ export async function executeProviderRequest(
providerId: string,
request: ProviderRequest
): Promise<ProviderResponse | ReadableStream | StreamingExecution> {
const provider = getProvider(providerId)
const provider = await getProviderExecutor(providerId as ProviderId)
if (!provider) {
throw new Error(`Provider not found: ${providerId}`)
}

View File

@@ -36,7 +36,7 @@ export const mistralProvider: ProviderConfig = {
request: ProviderRequest
): Promise<ProviderResponse | StreamingExecution> => {
logger.info('Preparing Mistral request', {
model: request.model || 'mistral-large-latest',
model: request.model,
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
@@ -86,7 +86,7 @@ export const mistralProvider: ProviderConfig = {
: undefined
const payload: any = {
model: request.model || 'mistral-large-latest',
model: request.model,
messages: allMessages,
}
@@ -126,7 +126,7 @@ export const mistralProvider: ProviderConfig = {
: toolChoice.type === 'any'
? `force:${toolChoice.any?.name || 'unknown'}`
: 'unknown',
model: request.model || 'mistral-large-latest',
model: request.model,
})
}
}

View File

@@ -39,6 +39,10 @@ export interface ModelCapabilities {
verbosity?: {
values: string[]
}
thinking?: {
levels: string[]
default?: string
}
}
export interface ModelDefinition {
@@ -730,6 +734,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 2 },
thinking: {
levels: ['low', 'high'],
default: 'high',
},
},
contextWindow: 1000000,
},
@@ -743,6 +751,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 2 },
thinking: {
levels: ['minimal', 'low', 'medium', 'high'],
default: 'high',
},
},
contextWindow: 1000000,
},
@@ -832,6 +844,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 2 },
thinking: {
levels: ['low', 'high'],
default: 'high',
},
},
contextWindow: 1000000,
},
@@ -845,6 +861,10 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 2 },
thinking: {
levels: ['minimal', 'low', 'medium', 'high'],
default: 'high',
},
},
contextWindow: 1000000,
},
@@ -1864,3 +1884,49 @@ export function supportsNativeStructuredOutputs(modelId: string): boolean {
}
return false
}
/**
* Check if a model supports thinking/reasoning features.
* Returns the thinking capability config if supported, null otherwise.
*/
export function getThinkingCapability(
modelId: string
): { levels: string[]; default?: string } | null {
const normalizedModelId = modelId.toLowerCase()
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
for (const model of provider.models) {
if (model.capabilities.thinking) {
const baseModelId = model.id.toLowerCase()
if (normalizedModelId === baseModelId || normalizedModelId.startsWith(`${baseModelId}-`)) {
return model.capabilities.thinking
}
}
}
}
return null
}
/**
* Get all models that support thinking capability
*/
export function getModelsWithThinking(): string[] {
const models: string[] = []
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
for (const model of provider.models) {
if (model.capabilities.thinking) {
models.push(model.id)
}
}
}
return models
}
/**
* Get the thinking levels for a specific model
* Returns the valid levels for that model, or null if the model doesn't support thinking
*/
export function getThinkingLevelsForModel(modelId: string): string[] | null {
const capability = getThinkingCapability(modelId)
return capability?.levels ?? null
}
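
getThinkingCapability matches either the exact model ID or any ID beginning with '<base-id>-', so dated or preview suffixes inherit the base model's thinking config. The IDs below are hypothetical stand-ins, since the concrete model entries are not visible in this hunk.

import { getThinkingCapability, getThinkingLevelsForModel } from '@/providers/models'

getThinkingCapability('gemini-2.5-pro')                   // e.g. { levels: ['low', 'high'], default: 'high' }
getThinkingCapability('gemini-2.5-pro-exp-0801')          // matched via the '<base-id>-' prefix rule
getThinkingLevelsForModel('some-model-without-thinking')  // null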

View File

@@ -33,7 +33,7 @@ export const openaiProvider: ProviderConfig = {
request: ProviderRequest
): Promise<ProviderResponse | StreamingExecution> => {
logger.info('Preparing OpenAI request', {
model: request.model || 'gpt-4o',
model: request.model,
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
@@ -76,7 +76,7 @@ export const openaiProvider: ProviderConfig = {
: undefined
const payload: any = {
model: request.model || 'gpt-4o',
model: request.model,
messages: allMessages,
}
@@ -121,7 +121,7 @@ export const openaiProvider: ProviderConfig = {
: toolChoice.type === 'any'
? `force:${toolChoice.any?.name || 'unknown'}`
: 'unknown',
model: request.model || 'gpt-4o',
model: request.model,
})
}
}

View File

@@ -78,7 +78,7 @@ export const openRouterProvider: ProviderConfig = {
baseURL: 'https://openrouter.ai/api/v1',
})
const requestedModel = (request.model || '').replace(/^openrouter\//, '')
const requestedModel = request.model.replace(/^openrouter\//, '')
logger.info('Preparing OpenRouter request', {
model: requestedModel,

View File

@@ -0,0 +1,59 @@
import { createLogger } from '@/lib/logs/console/logger'
import { anthropicProvider } from '@/providers/anthropic'
import { azureOpenAIProvider } from '@/providers/azure-openai'
import { cerebrasProvider } from '@/providers/cerebras'
import { deepseekProvider } from '@/providers/deepseek'
import { googleProvider } from '@/providers/google'
import { groqProvider } from '@/providers/groq'
import { mistralProvider } from '@/providers/mistral'
import { ollamaProvider } from '@/providers/ollama'
import { openaiProvider } from '@/providers/openai'
import { openRouterProvider } from '@/providers/openrouter'
import type { ProviderConfig, ProviderId } from '@/providers/types'
import { vertexProvider } from '@/providers/vertex'
import { vllmProvider } from '@/providers/vllm'
import { xAIProvider } from '@/providers/xai'
const logger = createLogger('ProviderRegistry')
const providerRegistry: Record<ProviderId, ProviderConfig> = {
openai: openaiProvider,
anthropic: anthropicProvider,
google: googleProvider,
vertex: vertexProvider,
deepseek: deepseekProvider,
xai: xAIProvider,
cerebras: cerebrasProvider,
groq: groqProvider,
vllm: vllmProvider,
mistral: mistralProvider,
'azure-openai': azureOpenAIProvider,
openrouter: openRouterProvider,
ollama: ollamaProvider,
}
export async function getProviderExecutor(
providerId: ProviderId
): Promise<ProviderConfig | undefined> {
const provider = providerRegistry[providerId]
if (!provider) {
logger.error(`Provider not found: ${providerId}`)
return undefined
}
return provider
}
export async function initializeProviders(): Promise<void> {
for (const [id, provider] of Object.entries(providerRegistry)) {
if (provider.initialize) {
try {
await provider.initialize()
logger.info(`Initialized provider: ${id}`)
} catch (error) {
logger.error(`Failed to initialize ${id} provider`, {
error: error instanceof Error ? error.message : 'Unknown error',
})
}
}
}
}
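
A server-side sketch of resolving a provider through the new registry, mirroring what executeProviderRequest now does; 'anthropic' is just an example ID and the actual executeRequest call is elided.

import { getProviderExecutor, initializeProviders } from '@/providers/registry'
import type { ProviderId } from '@/providers/types'

async function resolveProvider(providerId: ProviderId) {
  await initializeProviders() // typically run once at startup; runs each provider's optional initialize() hook
  const provider = await getProviderExecutor(providerId)
  if (!provider) throw new Error(`Provider not found: ${providerId}`)
  return provider // provider.executeRequest(request) is the next step in executeProviderRequest
}

resolveProvider('anthropic').then((p) => console.log(p.id))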

View File

@@ -164,6 +164,7 @@ export interface ProviderRequest {
vertexLocation?: string
reasoningEffort?: string
verbosity?: string
thinkingLevel?: string
}
export const providers: Record<string, ProviderConfig> = {}

View File

@@ -3,13 +3,6 @@ import type { CompletionUsage } from 'openai/resources/completions'
import { getEnv, isTruthy } from '@/lib/core/config/env'
import { isHosted } from '@/lib/core/config/feature-flags'
import { createLogger, type Logger } from '@/lib/logs/console/logger'
import { anthropicProvider } from '@/providers/anthropic'
import { azureOpenAIProvider } from '@/providers/azure-openai'
import { cerebrasProvider } from '@/providers/cerebras'
import { deepseekProvider } from '@/providers/deepseek'
import { googleProvider } from '@/providers/google'
import { groqProvider } from '@/providers/groq'
import { mistralProvider } from '@/providers/mistral'
import {
getComputerUseModels,
getEmbeddingModelPricing,
@@ -20,117 +13,82 @@ import {
getModelsWithTemperatureSupport,
getModelsWithTempRange01,
getModelsWithTempRange02,
getModelsWithThinking,
getModelsWithVerbosity,
getProviderDefaultModel as getProviderDefaultModelFromDefinitions,
getProviderModels as getProviderModelsFromDefinitions,
getProvidersWithToolUsageControl,
getReasoningEffortValuesForModel as getReasoningEffortValuesForModelFromDefinitions,
getThinkingLevelsForModel as getThinkingLevelsForModelFromDefinitions,
getVerbosityValuesForModel as getVerbosityValuesForModelFromDefinitions,
PROVIDER_DEFINITIONS,
supportsTemperature as supportsTemperatureFromDefinitions,
supportsToolUsageControl as supportsToolUsageControlFromDefinitions,
updateOllamaModels as updateOllamaModelsInDefinitions,
} from '@/providers/models'
import { ollamaProvider } from '@/providers/ollama'
import { openaiProvider } from '@/providers/openai'
import { openRouterProvider } from '@/providers/openrouter'
import type { ProviderConfig, ProviderId, ProviderToolConfig } from '@/providers/types'
import { vertexProvider } from '@/providers/vertex'
import { vllmProvider } from '@/providers/vllm'
import { xAIProvider } from '@/providers/xai'
import type { ProviderId, ProviderToolConfig } from '@/providers/types'
import { useCustomToolsStore } from '@/stores/custom-tools/store'
import { useProvidersStore } from '@/stores/providers/store'
const logger = createLogger('ProviderUtils')
export const providers: Record<
ProviderId,
ProviderConfig & {
models: string[]
computerUseModels?: string[]
modelPatterns?: RegExp[]
/**
* Client-safe provider metadata.
* This object contains only model lists and patterns - no executeRequest implementations.
* For server-side execution, use @/providers/registry.
*/
export interface ProviderMetadata {
id: string
name: string
description: string
version: string
models: string[]
defaultModel: string
computerUseModels?: string[]
modelPatterns?: RegExp[]
}
/**
* Build provider metadata from PROVIDER_DEFINITIONS.
* This is client-safe as it doesn't import any provider implementations.
*/
function buildProviderMetadata(providerId: ProviderId): ProviderMetadata {
const def = PROVIDER_DEFINITIONS[providerId]
return {
id: providerId,
name: def?.name || providerId,
description: def?.description || '',
version: '1.0.0',
models: getProviderModelsFromDefinitions(providerId),
defaultModel: getProviderDefaultModelFromDefinitions(providerId),
modelPatterns: def?.modelPatterns,
}
> = {
}
export const providers: Record<ProviderId, ProviderMetadata> = {
openai: {
...openaiProvider,
models: getProviderModelsFromDefinitions('openai'),
...buildProviderMetadata('openai'),
computerUseModels: ['computer-use-preview'],
modelPatterns: PROVIDER_DEFINITIONS.openai.modelPatterns,
},
anthropic: {
...anthropicProvider,
models: getProviderModelsFromDefinitions('anthropic'),
...buildProviderMetadata('anthropic'),
computerUseModels: getComputerUseModels().filter((model) =>
getProviderModelsFromDefinitions('anthropic').includes(model)
),
modelPatterns: PROVIDER_DEFINITIONS.anthropic.modelPatterns,
},
google: {
...googleProvider,
models: getProviderModelsFromDefinitions('google'),
modelPatterns: PROVIDER_DEFINITIONS.google.modelPatterns,
},
vertex: {
...vertexProvider,
models: getProviderModelsFromDefinitions('vertex'),
modelPatterns: PROVIDER_DEFINITIONS.vertex.modelPatterns,
},
deepseek: {
...deepseekProvider,
models: getProviderModelsFromDefinitions('deepseek'),
modelPatterns: PROVIDER_DEFINITIONS.deepseek.modelPatterns,
},
xai: {
...xAIProvider,
models: getProviderModelsFromDefinitions('xai'),
modelPatterns: PROVIDER_DEFINITIONS.xai.modelPatterns,
},
cerebras: {
...cerebrasProvider,
models: getProviderModelsFromDefinitions('cerebras'),
modelPatterns: PROVIDER_DEFINITIONS.cerebras.modelPatterns,
},
groq: {
...groqProvider,
models: getProviderModelsFromDefinitions('groq'),
modelPatterns: PROVIDER_DEFINITIONS.groq.modelPatterns,
},
vllm: {
...vllmProvider,
models: getProviderModelsFromDefinitions('vllm'),
modelPatterns: PROVIDER_DEFINITIONS.vllm.modelPatterns,
},
mistral: {
...mistralProvider,
models: getProviderModelsFromDefinitions('mistral'),
modelPatterns: PROVIDER_DEFINITIONS.mistral.modelPatterns,
},
'azure-openai': {
...azureOpenAIProvider,
models: getProviderModelsFromDefinitions('azure-openai'),
modelPatterns: PROVIDER_DEFINITIONS['azure-openai'].modelPatterns,
},
openrouter: {
...openRouterProvider,
models: getProviderModelsFromDefinitions('openrouter'),
modelPatterns: PROVIDER_DEFINITIONS.openrouter.modelPatterns,
},
ollama: {
...ollamaProvider,
models: getProviderModelsFromDefinitions('ollama'),
modelPatterns: PROVIDER_DEFINITIONS.ollama.modelPatterns,
},
google: buildProviderMetadata('google'),
vertex: buildProviderMetadata('vertex'),
deepseek: buildProviderMetadata('deepseek'),
xai: buildProviderMetadata('xai'),
cerebras: buildProviderMetadata('cerebras'),
groq: buildProviderMetadata('groq'),
vllm: buildProviderMetadata('vllm'),
mistral: buildProviderMetadata('mistral'),
'azure-openai': buildProviderMetadata('azure-openai'),
openrouter: buildProviderMetadata('openrouter'),
ollama: buildProviderMetadata('ollama'),
}
Object.entries(providers).forEach(([id, provider]) => {
if (provider.initialize) {
provider.initialize().catch((error) => {
logger.error(`Failed to initialize ${id} provider`, {
error: error instanceof Error ? error.message : 'Unknown error',
})
})
}
})
export function updateOllamaProviderModels(models: string[]): void {
updateOllamaModelsInDefinitions(models)
providers.ollama.models = getProviderModelsFromDefinitions('ollama')
@@ -211,12 +169,12 @@ export function getProviderFromModel(model: string): ProviderId {
return 'ollama'
}
export function getProvider(id: string): ProviderConfig | undefined {
export function getProvider(id: string): ProviderMetadata | undefined {
const providerId = id.split('/')[0] as ProviderId
return providers[providerId]
}
export function getProviderConfigFromModel(model: string): ProviderConfig | undefined {
export function getProviderConfigFromModel(model: string): ProviderMetadata | undefined {
const providerId = getProviderFromModel(model)
return providers[providerId]
}
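// Hypothetical lookups against the registry above; the model ids are placeholder examples.
const anthropicMeta = getProvider('anthropic') // direct provider id lookup
const prefixedMeta = getProvider('anthropic/claude-sonnet-4') // only the segment before '/' selects the provider
const fromModel = getProviderConfigFromModel('gpt-4o') // provider resolved via getProviderFromModel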
@@ -929,6 +887,7 @@ export const MODELS_TEMP_RANGE_0_1 = getModelsWithTempRange01()
export const MODELS_WITH_TEMPERATURE_SUPPORT = getModelsWithTemperatureSupport()
export const MODELS_WITH_REASONING_EFFORT = getModelsWithReasoningEffort()
export const MODELS_WITH_VERBOSITY = getModelsWithVerbosity()
export const MODELS_WITH_THINKING = getModelsWithThinking()
export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = getProvidersWithToolUsageControl()
export function supportsTemperature(model: string): boolean {
@@ -963,6 +922,14 @@ export function getVerbosityValuesForModel(model: string): string[] | null {
return getVerbosityValuesForModelFromDefinitions(model)
}
/**
* Get thinking levels for a specific model
* Returns the valid levels for that model, or null if the model doesn't support thinking
*/
export function getThinkingLevelsForModel(model: string): string[] | null {
return getThinkingLevelsForModelFromDefinitions(model)
}
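// Hypothetical capability checks built on the helpers above; the model ids are placeholder examples.
const thinkingLevels = getThinkingLevelsForModel('gemini-2.5-pro') // string[] when supported, null otherwise
const showThinkingSelector = thinkingLevels !== null // only expose a thinking-level control when supported
const showTemperatureSlider = supportsTemperature('gpt-4o') // boolean guard derived from the model definitions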
/**
* Prepare tool execution parameters, separating tool parameters from system parameters
*/

View File

@@ -1,33 +1,23 @@
import { GoogleGenAI } from '@google/genai'
import { OAuth2Client } from 'google-auth-library'
import { env } from '@/lib/core/config/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import {
cleanSchemaForGemini,
convertToGeminiFormat,
extractFunctionCall,
extractTextContent,
} from '@/providers/google/utils'
import { executeGeminiRequest } from '@/providers/gemini/core'
import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import type {
ProviderConfig,
ProviderRequest,
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
} from '@/providers/utils'
import { buildVertexEndpoint, createReadableStreamFromVertexStream } from '@/providers/vertex/utils'
import { executeTool } from '@/tools'
import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types'
const logger = createLogger('VertexProvider')
/**
* Vertex AI provider configuration
* Vertex AI provider
*
* Uses the @google/genai SDK with Vertex AI backend and OAuth authentication.
* Shares core execution logic with Google Gemini provider.
*
* Authentication:
 * - Uses an OAuth access token passed via googleAuthOptions.authClient
* - Token refresh is handled at the OAuth layer before calling this provider
*/
export const vertexProvider: ProviderConfig = {
id: 'vertex',
@@ -55,869 +45,35 @@ export const vertexProvider: ProviderConfig = {
)
}
logger.info('Preparing Vertex AI request', {
model: request.model || 'vertex/gemini-2.5-pro',
hasSystemPrompt: !!request.systemPrompt,
hasMessages: !!request.messages?.length,
hasTools: !!request.tools?.length,
toolCount: request.tools?.length || 0,
hasResponseFormat: !!request.responseFormat,
streaming: !!request.stream,
// Strip 'vertex/' prefix from model name if present
const model = request.model.replace('vertex/', '')
logger.info('Creating Vertex AI client', {
project: vertexProject,
location: vertexLocation,
model,
})
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
const { contents, tools, systemInstruction } = convertToGeminiFormat(request)
const requestedModel = (request.model || 'vertex/gemini-2.5-pro').replace('vertex/', '')
const payload: any = {
contents,
generationConfig: {},
}
if (request.temperature !== undefined && request.temperature !== null) {
payload.generationConfig.temperature = request.temperature
}
if (request.maxTokens !== undefined) {
payload.generationConfig.maxOutputTokens = request.maxTokens
}
if (systemInstruction) {
payload.systemInstruction = systemInstruction
}
if (request.responseFormat && !tools?.length) {
const responseFormatSchema = request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
payload.generationConfig.responseMimeType = 'application/json'
payload.generationConfig.responseSchema = cleanSchema
logger.info('Using Vertex AI native structured output format', {
hasSchema: !!cleanSchema,
mimeType: 'application/json',
})
} else if (request.responseFormat && tools?.length) {
logger.warn(
'Vertex AI does not support structured output (responseFormat) with function calling (tools). Structured output will be ignored.'
)
}
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
if (tools?.length) {
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'google')
const { tools: filteredTools, toolConfig } = preparedTools
if (filteredTools?.length) {
payload.tools = [
{
functionDeclarations: filteredTools,
},
]
if (toolConfig) {
payload.toolConfig = toolConfig
}
logger.info('Vertex AI request with tools:', {
toolCount: filteredTools.length,
model: requestedModel,
tools: filteredTools.map((t) => t.name),
hasToolConfig: !!toolConfig,
toolConfig: toolConfig,
})
}
}
const initialCallTime = Date.now()
const shouldStream = !!(request.stream && !tools?.length)
const endpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
shouldStream
)
if (request.stream && tools?.length) {
logger.info('Streaming disabled for initial request due to tools presence', {
toolCount: tools.length,
willStreamAfterTools: true,
})
}
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(payload),
})
if (!response.ok) {
const responseText = await response.text()
logger.error('Vertex AI API error details:', {
status: response.status,
statusText: response.statusText,
responseBody: responseText,
})
throw new Error(`Vertex AI API error: ${response.status} ${response.statusText}`)
}
const firstResponseTime = Date.now() - initialCallTime
if (shouldStream) {
logger.info('Handling Vertex AI streaming response')
const streamingResult: StreamingExecution = {
stream: null as any,
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: {
prompt: 0,
completion: 0,
total: 0,
},
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: firstResponseTime,
modelTime: firstResponseTime,
toolsTime: 0,
firstResponseTime,
iterations: 1,
timeSegments: [
{
type: 'model',
name: 'Initial streaming response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
],
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: firstResponseTime,
},
isStreaming: true,
},
}
streamingResult.stream = createReadableStreamFromVertexStream(
response,
(content, usage) => {
streamingResult.execution.output.content = content
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
if (streamingResult.execution.output.providerTiming) {
streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
streamEndTime
streamingResult.execution.output.providerTiming.timeSegments[0].duration =
streamEndTime - providerStartTime
}
}
const promptTokens = usage?.promptTokenCount || 0
const completionTokens = usage?.candidatesTokenCount || 0
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
streamingResult.execution.output.tokens = {
prompt: promptTokens,
completion: completionTokens,
total: totalTokens,
}
const costResult = calculateCost(request.model, promptTokens, completionTokens)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
}
)
return streamingResult
}
let geminiResponse = await response.json()
if (payload.generationConfig?.responseSchema) {
const candidate = geminiResponse.candidates?.[0]
if (candidate?.content?.parts?.[0]?.text) {
const text = candidate.content.parts[0].text
try {
JSON.parse(text)
logger.info('Successfully received structured JSON output')
} catch (_e) {
logger.warn('Failed to parse structured output as JSON')
}
}
}
let content = ''
let tokens = {
prompt: 0,
completion: 0,
total: 0,
}
const toolCalls = []
const toolResults = []
let iterationCount = 0
const originalToolConfig = preparedTools?.toolConfig
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
let hasUsedForcedTool = false
let currentToolConfig = originalToolConfig
const checkForForcedToolUsage = (functionCall: { name: string; args: any }) => {
if (currentToolConfig && forcedTools.length > 0) {
const toolCallsForTracking = [{ name: functionCall.name, arguments: functionCall.args }]
const result = trackForcedToolUsage(
toolCallsForTracking,
currentToolConfig,
logger,
'google',
forcedTools,
usedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
usedForcedTools = result.usedForcedTools
if (result.nextToolConfig) {
currentToolConfig = result.nextToolConfig
logger.info('Updated tool config for next iteration', {
hasNextToolConfig: !!currentToolConfig,
usedForcedTools: usedForcedTools,
})
}
}
}
let modelTime = firstResponseTime
let toolsTime = 0
const timeSegments: TimeSegment[] = [
{
type: 'model',
name: 'Initial response',
startTime: initialCallTime,
endTime: initialCallTime + firstResponseTime,
duration: firstResponseTime,
},
]
try {
const candidate = geminiResponse.candidates?.[0]
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn(
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
{
finishReason: candidate.finishReason,
hasContent: !!candidate?.content,
hasParts: !!candidate?.content?.parts,
}
)
content = extractTextContent(candidate)
}
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.info(`Received function call from Vertex AI: ${functionCall.name}`)
while (iterationCount < MAX_TOOL_ITERATIONS) {
const latestResponse = geminiResponse.candidates?.[0]
const latestFunctionCall = extractFunctionCall(latestResponse)
if (!latestFunctionCall) {
content = extractTextContent(latestResponse)
break
}
logger.info(
`Processing function call: ${latestFunctionCall.name} (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
const toolsStartTime = Date.now()
try {
const toolName = latestFunctionCall.name
const toolArgs = latestFunctionCall.args || {}
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) {
logger.warn(`Tool ${toolName} not found in registry, skipping`)
break
}
const toolCallStartTime = Date.now()
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
timeSegments.push({
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
})
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
tool: toolName,
}
}
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
result: resultContent,
success: result.success,
})
const simplifiedMessages = [
...(contents.filter((m) => m.role === 'user').length > 0
? [contents.filter((m) => m.role === 'user')[0]]
: [contents[0]]),
{
role: 'model',
parts: [
{
functionCall: {
name: latestFunctionCall.name,
args: latestFunctionCall.args,
},
},
],
},
{
role: 'user',
parts: [
{
text: `Function ${latestFunctionCall.name} result: ${JSON.stringify(resultContent)}`,
},
],
},
]
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
checkForForcedToolUsage(latestFunctionCall)
const nextModelStartTime = Date.now()
try {
if (request.stream) {
const streamingPayload = {
...payload,
contents: simplifiedMessages,
}
const allForcedToolsUsed =
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
if (allForcedToolsUsed && request.responseFormat) {
streamingPayload.tools = undefined
streamingPayload.toolConfig = undefined
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!streamingPayload.generationConfig) {
streamingPayload.generationConfig = {}
}
streamingPayload.generationConfig.responseMimeType = 'application/json'
streamingPayload.generationConfig.responseSchema = cleanSchema
logger.info('Using structured output for final response after tool execution')
} else {
if (currentToolConfig) {
streamingPayload.toolConfig = currentToolConfig
} else {
streamingPayload.toolConfig = { functionCallingConfig: { mode: 'AUTO' } }
}
}
const checkPayload = {
...streamingPayload,
}
checkPayload.stream = undefined
const checkEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
false
)
const checkResponse = await fetch(checkEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(checkPayload),
})
if (!checkResponse.ok) {
const errorBody = await checkResponse.text()
logger.error('Error in Vertex AI check request:', {
status: checkResponse.status,
statusText: checkResponse.statusText,
responseBody: errorBody,
})
throw new Error(
`Vertex AI API check error: ${checkResponse.status} ${checkResponse.statusText}`
)
}
const checkResult = await checkResponse.json()
const checkCandidate = checkResult.candidates?.[0]
const checkFunctionCall = extractFunctionCall(checkCandidate)
if (checkFunctionCall) {
logger.info(
'Function call detected in follow-up, handling in non-streaming mode',
{
functionName: checkFunctionCall.name,
}
)
geminiResponse = checkResult
if (checkResult.usageMetadata) {
tokens.prompt += checkResult.usageMetadata.promptTokenCount || 0
tokens.completion += checkResult.usageMetadata.candidatesTokenCount || 0
tokens.total +=
(checkResult.usageMetadata.promptTokenCount || 0) +
(checkResult.usageMetadata.candidatesTokenCount || 0)
}
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
modelTime += thisModelTime
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})
iterationCount++
continue
}
logger.info('No function call detected, proceeding with streaming response')
if (request.responseFormat) {
streamingPayload.tools = undefined
streamingPayload.toolConfig = undefined
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!streamingPayload.generationConfig) {
streamingPayload.generationConfig = {}
}
streamingPayload.generationConfig.responseMimeType = 'application/json'
streamingPayload.generationConfig.responseSchema = cleanSchema
logger.info(
'Using structured output for final streaming response after tool execution'
)
}
const streamEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
true
)
const streamingResponse = await fetch(streamEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(streamingPayload),
})
if (!streamingResponse.ok) {
const errorBody = await streamingResponse.text()
logger.error('Error in Vertex AI streaming follow-up request:', {
status: streamingResponse.status,
statusText: streamingResponse.statusText,
responseBody: errorBody,
})
throw new Error(
`Vertex AI API streaming error: ${streamingResponse.status} ${streamingResponse.statusText}`
)
}
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
modelTime += thisModelTime
timeSegments.push({
type: 'model',
name: 'Final streaming response after tool calls',
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})
const streamingExecution: StreamingExecution = {
stream: null as any,
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens,
toolCalls:
toolCalls.length > 0
? {
list: toolCalls,
count: toolCalls.length,
}
: undefined,
toolResults,
providerTiming: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
modelTime,
toolsTime,
firstResponseTime,
iterations: iterationCount + 1,
timeSegments,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
duration: Date.now() - providerStartTime,
},
isStreaming: true,
},
}
streamingExecution.stream = createReadableStreamFromVertexStream(
streamingResponse,
(content, usage) => {
streamingExecution.execution.output.content = content
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
if (streamingExecution.execution.output.providerTiming) {
streamingExecution.execution.output.providerTiming.endTime =
streamEndTimeISO
streamingExecution.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
}
const promptTokens = usage?.promptTokenCount || 0
const completionTokens = usage?.candidatesTokenCount || 0
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
const existingTokens = streamingExecution.execution.output.tokens || {
prompt: 0,
completion: 0,
total: 0,
}
const existingPrompt = existingTokens.prompt || 0
const existingCompletion = existingTokens.completion || 0
const existingTotal = existingTokens.total || 0
streamingExecution.execution.output.tokens = {
prompt: existingPrompt + promptTokens,
completion: existingCompletion + completionTokens,
total: existingTotal + totalTokens,
}
const accumulatedCost = calculateCost(
request.model,
existingPrompt,
existingCompletion
)
const streamCost = calculateCost(
request.model,
promptTokens,
completionTokens
)
streamingExecution.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}
)
return streamingExecution
}
const nextPayload = {
...payload,
contents: simplifiedMessages,
}
const allForcedToolsUsed =
forcedTools.length > 0 && usedForcedTools.length === forcedTools.length
if (allForcedToolsUsed && request.responseFormat) {
nextPayload.tools = undefined
nextPayload.toolConfig = undefined
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!nextPayload.generationConfig) {
nextPayload.generationConfig = {}
}
nextPayload.generationConfig.responseMimeType = 'application/json'
nextPayload.generationConfig.responseSchema = cleanSchema
logger.info(
'Using structured output for final non-streaming response after tool execution'
)
} else {
if (currentToolConfig) {
nextPayload.toolConfig = currentToolConfig
}
}
const nextEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
false
)
const nextResponse = await fetch(nextEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(nextPayload),
})
if (!nextResponse.ok) {
const errorBody = await nextResponse.text()
logger.error('Error in Vertex AI follow-up request:', {
status: nextResponse.status,
statusText: nextResponse.statusText,
responseBody: errorBody,
iterationCount,
})
break
}
geminiResponse = await nextResponse.json()
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
startTime: nextModelStartTime,
endTime: nextModelEndTime,
duration: thisModelTime,
})
modelTime += thisModelTime
const nextCandidate = geminiResponse.candidates?.[0]
const nextFunctionCall = extractFunctionCall(nextCandidate)
if (!nextFunctionCall) {
if (request.responseFormat) {
const finalPayload = {
...payload,
contents: nextPayload.contents,
tools: undefined,
toolConfig: undefined,
}
const responseFormatSchema =
request.responseFormat.schema || request.responseFormat
const cleanSchema = cleanSchemaForGemini(responseFormatSchema)
if (!finalPayload.generationConfig) {
finalPayload.generationConfig = {}
}
finalPayload.generationConfig.responseMimeType = 'application/json'
finalPayload.generationConfig.responseSchema = cleanSchema
logger.info('Making final request with structured output after tool execution')
const finalEndpoint = buildVertexEndpoint(
vertexProject,
vertexLocation,
requestedModel,
false
)
const finalResponse = await fetch(finalEndpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${request.apiKey}`,
},
body: JSON.stringify(finalPayload),
})
if (finalResponse.ok) {
const finalResult = await finalResponse.json()
const finalCandidate = finalResult.candidates?.[0]
content = extractTextContent(finalCandidate)
if (finalResult.usageMetadata) {
tokens.prompt += finalResult.usageMetadata.promptTokenCount || 0
tokens.completion += finalResult.usageMetadata.candidatesTokenCount || 0
tokens.total +=
(finalResult.usageMetadata.promptTokenCount || 0) +
(finalResult.usageMetadata.candidatesTokenCount || 0)
}
} else {
logger.warn(
'Failed to get structured output, falling back to regular response'
)
content = extractTextContent(nextCandidate)
}
} else {
content = extractTextContent(nextCandidate)
}
break
}
iterationCount++
} catch (error) {
logger.error('Error in Vertex AI follow-up request:', {
error: error instanceof Error ? error.message : String(error),
iterationCount,
})
break
}
} catch (error) {
logger.error('Error processing function call:', {
error: error instanceof Error ? error.message : String(error),
functionName: latestFunctionCall?.name || 'unknown',
})
break
}
}
} else {
content = extractTextContent(candidate)
}
} catch (error) {
logger.error('Error processing Vertex AI response:', {
error: error instanceof Error ? error.message : String(error),
iterationCount,
})
if (!content && toolCalls.length > 0) {
content = `Tool call(s) executed: ${toolCalls.map((t) => t.name).join(', ')}. Results are available in the tool results.`
}
}
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
if (geminiResponse.usageMetadata) {
tokens = {
prompt: geminiResponse.usageMetadata.promptTokenCount || 0,
completion: geminiResponse.usageMetadata.candidatesTokenCount || 0,
total:
(geminiResponse.usageMetadata.promptTokenCount || 0) +
(geminiResponse.usageMetadata.candidatesTokenCount || 0),
}
}
return {
content,
model: request.model,
tokens,
toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
toolResults: toolResults.length > 0 ? toolResults : undefined,
timing: {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
modelTime: modelTime,
toolsTime: toolsTime,
firstResponseTime: firstResponseTime,
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
}
} catch (error) {
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
logger.error('Error in Vertex AI request:', {
error: error instanceof Error ? error.message : String(error),
duration: totalDuration,
})
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
duration: totalDuration,
}
throw enhancedError
}
// Create an OAuth2Client and set the access token
// This allows us to use an OAuth access token with the SDK
const authClient = new OAuth2Client()
authClient.setCredentials({ access_token: request.apiKey })
// Create client with Vertex AI configuration
const ai = new GoogleGenAI({
vertexai: true,
project: vertexProject,
location: vertexLocation,
googleAuthOptions: {
authClient,
},
})
return executeGeminiRequest({
ai,
model,
request,
providerType: 'vertex',
})
},
}
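// Condensed sketch of the client wiring used by this provider. The project,
// location, and token values are placeholders, and the access token is assumed
// to have already been refreshed at the OAuth layer before reaching this point.
const sketchAuthClient = new OAuth2Client()
sketchAuthClient.setCredentials({ access_token: 'ya29.placeholder-access-token' })

const sketchClient = new GoogleGenAI({
  vertexai: true,
  project: 'my-gcp-project',
  location: 'us-central1',
  googleAuthOptions: { authClient: sketchAuthClient },
})
// Request handling is then delegated to the shared Gemini core, e.g.:
// executeGeminiRequest({ ai: sketchClient, model: 'gemini-2.5-pro', request, providerType: 'vertex' })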

View File

@@ -1,231 +0,0 @@
import { createLogger } from '@/lib/logs/console/logger'
import { extractFunctionCall, extractTextContent } from '@/providers/google/utils'
const logger = createLogger('VertexUtils')
/**
* Creates a ReadableStream from Vertex AI's Gemini stream response
*/
export function createReadableStreamFromVertexStream(
response: Response,
onComplete?: (
content: string,
usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
) => void
): ReadableStream<Uint8Array> {
const reader = response.body?.getReader()
if (!reader) {
throw new Error('Failed to get reader from response body')
}
return new ReadableStream({
async start(controller) {
try {
let buffer = ''
let fullContent = ''
let usageData: {
promptTokenCount?: number
candidatesTokenCount?: number
totalTokenCount?: number
} | null = null
while (true) {
const { done, value } = await reader.read()
if (done) {
if (buffer.trim()) {
try {
const data = JSON.parse(buffer.trim())
if (data.usageMetadata) {
usageData = data.usageMetadata
}
const candidate = data.candidates?.[0]
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in final buffer, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
} catch (e) {
if (buffer.trim().startsWith('[')) {
try {
const dataArray = JSON.parse(buffer.trim())
if (Array.isArray(dataArray)) {
for (const item of dataArray) {
if (item.usageMetadata) {
usageData = item.usageMetadata
}
const candidate = item.candidates?.[0]
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in array item, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
}
}
} catch (arrayError) {}
}
}
}
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
break
}
const text = new TextDecoder().decode(value)
buffer += text
let searchIndex = 0
while (searchIndex < buffer.length) {
const openBrace = buffer.indexOf('{', searchIndex)
if (openBrace === -1) break
let braceCount = 0
let inString = false
let escaped = false
let closeBrace = -1
for (let i = openBrace; i < buffer.length; i++) {
const char = buffer[i]
if (!inString) {
if (char === '"' && !escaped) {
inString = true
} else if (char === '{') {
braceCount++
} else if (char === '}') {
braceCount--
if (braceCount === 0) {
closeBrace = i
break
}
}
} else {
if (char === '"' && !escaped) {
inString = false
}
}
escaped = char === '\\' && !escaped
}
if (closeBrace !== -1) {
const jsonStr = buffer.substring(openBrace, closeBrace + 1)
try {
const data = JSON.parse(jsonStr)
if (data.usageMetadata) {
usageData = data.usageMetadata
}
const candidate = data.candidates?.[0]
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn(
'Vertex AI returned UNEXPECTED_TOOL_CALL - model attempted to call a tool that was not provided',
{
finishReason: candidate.finishReason,
hasContent: !!candidate?.content,
hasParts: !!candidate?.content?.parts,
}
)
const textContent = extractTextContent(candidate)
if (textContent) {
fullContent += textContent
controller.enqueue(new TextEncoder().encode(textContent))
}
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in stream, ending stream to execute tool',
{
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
controller.close()
return
}
const content = extractTextContent(candidate)
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
} catch (e) {
logger.error('Error parsing JSON from stream', {
error: e instanceof Error ? e.message : String(e),
jsonPreview: jsonStr.substring(0, 200),
})
}
buffer = buffer.substring(closeBrace + 1)
searchIndex = 0
} else {
break
}
}
}
} catch (e) {
logger.error('Error reading Vertex AI stream', {
error: e instanceof Error ? e.message : String(e),
})
controller.error(e)
}
},
async cancel() {
await reader.cancel()
},
})
}
/**
* Build Vertex AI endpoint URL
*/
export function buildVertexEndpoint(
project: string,
location: string,
model: string,
isStreaming: boolean
): string {
const action = isStreaming ? 'streamGenerateContent' : 'generateContent'
if (location === 'global') {
return `https://aiplatform.googleapis.com/v1/projects/${project}/locations/global/publishers/google/models/${model}:${action}`
}
return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/google/models/${model}:${action}`
}
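// Example URLs produced by the helper above; the project and model names are placeholders.
const regionalUrl = buildVertexEndpoint('my-project', 'us-central1', 'gemini-2.5-pro', false)
// -> https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.5-pro:generateContent
const globalStreamUrl = buildVertexEndpoint('my-project', 'global', 'gemini-2.5-pro', true)
// -> https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-2.5-pro:streamGenerateContent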

View File

@@ -130,7 +130,7 @@ export const vllmProvider: ProviderConfig = {
: undefined
const payload: any = {
model: (request.model || getProviderDefaultModel('vllm')).replace(/^vllm\//, ''),
model: request.model.replace(/^vllm\//, ''),
messages: allMessages,
}

View File

@@ -48,7 +48,7 @@ export const xAIProvider: ProviderConfig = {
hasTools: !!request.tools?.length,
toolCount: request.tools?.length || 0,
hasResponseFormat: !!request.responseFormat,
model: request.model || 'grok-3-latest',
model: request.model,
streaming: !!request.stream,
})
@@ -87,7 +87,7 @@ export const xAIProvider: ProviderConfig = {
)
}
const basePayload: any = {
model: request.model || 'grok-3-latest',
model: request.model,
messages: allMessages,
}
@@ -139,7 +139,7 @@ export const xAIProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'grok-3-latest',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
@@ -505,7 +505,7 @@ export const xAIProvider: ProviderConfig = {
success: true,
output: {
content: '',
model: request.model || 'grok-3-latest',
model: request.model,
tokens: {
prompt: tokens.prompt,
completion: tokens.completion,

View File

@@ -41,6 +41,8 @@ import { SetEnvironmentVariablesClientTool } from '@/lib/copilot/tools/client/us
import { CheckDeploymentStatusClientTool } from '@/lib/copilot/tools/client/workflow/check-deployment-status'
import { DeployWorkflowClientTool } from '@/lib/copilot/tools/client/workflow/deploy-workflow'
import { EditWorkflowClientTool } from '@/lib/copilot/tools/client/workflow/edit-workflow'
import { GetBlockOutputsClientTool } from '@/lib/copilot/tools/client/workflow/get-block-outputs'
import { GetBlockUpstreamReferencesClientTool } from '@/lib/copilot/tools/client/workflow/get-block-upstream-references'
import { GetUserWorkflowClientTool } from '@/lib/copilot/tools/client/workflow/get-user-workflow'
import { GetWorkflowConsoleClientTool } from '@/lib/copilot/tools/client/workflow/get-workflow-console'
import { GetWorkflowDataClientTool } from '@/lib/copilot/tools/client/workflow/get-workflow-data'
@@ -110,6 +112,8 @@ const CLIENT_TOOL_INSTANTIATORS: Record<string, (id: string) => any> = {
manage_custom_tool: (id) => new ManageCustomToolClientTool(id),
manage_mcp_tool: (id) => new ManageMcpToolClientTool(id),
sleep: (id) => new SleepClientTool(id),
get_block_outputs: (id) => new GetBlockOutputsClientTool(id),
get_block_upstream_references: (id) => new GetBlockUpstreamReferencesClientTool(id),
}
// Read-only static metadata for class-based tools (no instances)
@@ -150,6 +154,8 @@ export const CLASS_TOOL_METADATA: Record<string, BaseClientToolMetadata | undefi
manage_custom_tool: (ManageCustomToolClientTool as any)?.metadata,
manage_mcp_tool: (ManageMcpToolClientTool as any)?.metadata,
sleep: (SleepClientTool as any)?.metadata,
get_block_outputs: (GetBlockOutputsClientTool as any)?.metadata,
get_block_upstream_references: (GetBlockUpstreamReferencesClientTool as any)?.metadata,
}
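// Hypothetical registration of one more client tool following the same pattern;
// MyNewClientTool and 'my_new_tool' are placeholder names, not real tools.
class MyNewClientTool {
  static metadata = { description: 'placeholder metadata' } as any // placeholder shape only
  constructor(public toolCallId: string) {}
}
CLIENT_TOOL_INSTANTIATORS.my_new_tool = (id) => new MyNewClientTool(id)
CLASS_TOOL_METADATA.my_new_tool = (MyNewClientTool as any)?.metadata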
function ensureClientToolInstance(toolName: string | undefined, toolCallId: string | undefined) {

View File

@@ -144,7 +144,7 @@ describe('workflow store', () => {
expect(state.blocks.parallel1?.data?.count).toBe(5)
expect(state.parallels.parallel1).toBeDefined()
expect(state.parallels.parallel1.distribution).toBe('')
expect(state.parallels.parallel1.distribution).toBeUndefined()
})
it.concurrent('should regenerate parallels when updateParallelCollection is called', () => {

View File

@@ -103,7 +103,7 @@ export function convertParallelBlockToParallel(
: 'collection'
const distribution =
validatedParallelType === 'collection' ? parallelBlock.data?.collection || '' : ''
validatedParallelType === 'collection' ? parallelBlock.data?.collection || '' : undefined
const count = parallelBlock.data?.count || 5

View File

@@ -141,6 +141,10 @@ spec:
ports:
- protocol: TCP
port: 443
# Allow custom egress rules
{{- with .Values.networkPolicy.egress }}
{{- toYaml . | nindent 2 }}
{{- end }}
{{- end }}
{{- if .Values.postgresql.enabled }}