fix(models): memory fixes, provider code typing, cost calculation cleanup (#2515)

* improvement(memory): should not be block scoped

* cleanup provider code

* update other providers

* cleanup fallback code

* remove flaky test

* fix memory

* move streaming fix to right level

* cleanup streaming server

* make memories workspace scoped

* update docs

* fix dedup logic

* fix streaming parsing issue for multiple onStream calls for same block

* fix(providers): support parallel agent tool calls, consolidate utils

* address greptile comments

* remove all comments

* fix response format handling for openrouter, groq & cerebras

* remove duplicate type

---------

Co-authored-by: waleed <walif6@gmail.com>
Vikhyath Mondreti authored on 2025-12-22 15:59:53 -08:00, committed by GitHub
parent 086982c7a3
commit 8c2c49eb14
65 changed files with 12201 additions and 4626 deletions

View File

@@ -1,6 +1,6 @@
---
title: Memory
description: Add memory store
description: Store and retrieve conversation history
---
import { BlockInfoCard } from "@/components/ui/block-info-card"
@@ -10,100 +10,94 @@ import { BlockInfoCard } from "@/components/ui/block-info-card"
color="#F64F9E"
/>
## Usage Instructions
Integrate Memory into the workflow. Can add, get a memory, get all memories, and delete memories.
## Overview
The Memory block stores conversation history for agents. Each memory is identified by a `conversationId` that you provide. Multiple agents can share the same memory by using the same `conversationId`.
Memory stores only user and assistant messages. System messages are not stored—they are configured in the Agent block and prefixed at runtime.
## Tools
### `memory_add`
Add a new memory to the database or append to existing memory with the same ID.
Add a message to memory. Creates a new memory if the `conversationId` doesn't exist, or appends to existing memory.
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `conversationId` | string | No | Conversation identifier \(e.g., user-123, session-abc\). If a memory with this conversationId already exists for this block, the new message will be appended to it. |
| `id` | string | No | Legacy parameter for conversation identifier. Use conversationId instead. Provided for backwards compatibility. |
| `role` | string | Yes | Role for agent memory \(user, assistant, or system\) |
| `content` | string | Yes | Content for agent memory |
| `blockId` | string | No | Optional block ID. If not provided, uses the current block ID from execution context, or defaults to "default". |
| `conversationId` | string | Yes | Unique identifier for the conversation (e.g., `user-123`, `session-abc`) |
| `role` | string | Yes | Message role: `user` or `assistant` |
| `content` | string | Yes | Message content |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `success` | boolean | Whether the memory was added successfully |
| `memories` | array | Array of memory objects including the new or updated memory |
| `error` | string | Error message if operation failed |
| `success` | boolean | Whether the operation succeeded |
| `memories` | array | Updated memory array |
| `error` | string | Error message if failed |
### `memory_get`
Retrieve memory by conversationId, blockId, blockName, or a combination. Returns all matching memories.
Retrieve memory by conversation ID.
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `conversationId` | string | No | Conversation identifier \(e.g., user-123, session-abc\). If provided alone, returns all memories for this conversation across all blocks. |
| `id` | string | No | Legacy parameter for conversation identifier. Use conversationId instead. Provided for backwards compatibility. |
| `blockId` | string | No | Block identifier. If provided alone, returns all memories for this block across all conversations. If provided with conversationId, returns memories for that specific conversation in this block. |
| `blockName` | string | No | Block name. Alternative to blockId. If provided alone, returns all memories for blocks with this name. If provided with conversationId, returns memories for that conversation in blocks with this name. |
| `conversationId` | string | Yes | Conversation identifier |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `success` | boolean | Whether the memory was retrieved successfully |
| `memories` | array | Array of memory objects with conversationId, blockId, blockName, and data fields |
| `message` | string | Success or error message |
| `error` | string | Error message if operation failed |
| `success` | boolean | Whether the operation succeeded |
| `memories` | array | Array of messages with `role` and `content` |
| `error` | string | Error message if failed |
### `memory_get_all`
Retrieve all memories from the database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
Retrieve all memories for the current workspace.
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `success` | boolean | Whether all memories were retrieved successfully |
| `memories` | array | Array of all memory objects with key, conversationId, blockId, blockName, and data fields |
| `message` | string | Success or error message |
| `error` | string | Error message if operation failed |
| `success` | boolean | Whether the operation succeeded |
| `memories` | array | All memory objects with `conversationId` and `data` fields |
| `error` | string | Error message if failed |
### `memory_delete`
Delete memories by conversationId, blockId, blockName, or a combination. Supports bulk deletion.
Delete memory by conversation ID.
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `conversationId` | string | No | Conversation identifier \(e.g., user-123, session-abc\). If provided alone, deletes all memories for this conversation across all blocks. |
| `id` | string | No | Legacy parameter for conversation identifier. Use conversationId instead. Provided for backwards compatibility. |
| `blockId` | string | No | Block identifier. If provided alone, deletes all memories for this block across all conversations. If provided with conversationId, deletes memories for that specific conversation in this block. |
| `blockName` | string | No | Block name. Alternative to blockId. If provided alone, deletes all memories for blocks with this name. If provided with conversationId, deletes memories for that conversation in blocks with this name. |
| `conversationId` | string | Yes | Conversation identifier to delete |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `success` | boolean | Whether the memory was deleted successfully |
| `message` | string | Success or error message |
| `error` | string | Error message if operation failed |
| `success` | boolean | Whether the operation succeeded |
| `message` | string | Confirmation message |
| `error` | string | Error message if failed |
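Below is a minimal, self-contained sketch of the semantics these tools describe (illustrative only; the `Message` type and helper functions are hypothetical stand-ins, not the actual implementation): one message array per `conversationId`, where `memory_add` appends and `memory_get` reads the conversation back.
```ts
// Illustrative sketch of the Memory block semantics: one message array per
// conversationId; add appends (creating the conversation if missing), get reads it back.
type Message = { role: 'user' | 'assistant'; content: string }

const store = new Map<string, Message[]>()

function memoryAdd(conversationId: string, message: Message): Message[] {
  const memories = store.get(conversationId) ?? []
  memories.push(message)
  store.set(conversationId, memories)
  return memories // corresponds to the `memories` field in the memory_add output
}

function memoryGet(conversationId: string): Message[] {
  return store.get(conversationId) ?? []
}

memoryAdd('user-123', { role: 'user', content: 'Hello!' })
memoryAdd('user-123', { role: 'assistant', content: 'Hi there!' })
console.log(memoryGet('user-123')) // both messages, in insertion order
```
Because the store is keyed only by `conversationId`, two agents that use the same id read and write the same history, which is the sharing behavior described in the Overview.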
## Agent Memory Types
When using memory with an Agent block, you can configure how conversation history is managed:
| Type | Description |
| ---- | ----------- |
| **Full Conversation** | Stores all messages, limited by model's context window (uses 90% to leave room for response) |
| **Sliding Window (Messages)** | Keeps the last N messages (default: 10) |
| **Sliding Window (Tokens)** | Keeps messages that fit within a token limit (default: 4000) |
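The two sliding-window strategies in this table can be sketched as follows (a rough illustration; the defaults mirror the `MEMORY` constants added elsewhere in this commit, and `estimateTokens` is a crude stand-in for a real tokenizer):
```ts
// Illustrative sketch of the sliding-window strategies above.
type Message = { role: 'user' | 'assistant'; content: string }

const DEFAULT_SLIDING_WINDOW_SIZE = 10
const DEFAULT_SLIDING_WINDOW_TOKENS = 4000

// Sliding Window (Messages): keep only the last N messages.
function slidingWindowByMessages(messages: Message[], limit = DEFAULT_SLIDING_WINDOW_SIZE) {
  return messages.slice(-limit)
}

// Crude token estimate (~4 characters per token); the real implementation may differ.
function estimateTokens(text: string): number {
  return Math.ceil(text.length / 4)
}

// Sliding Window (Tokens): keep the most recent messages that fit within the budget.
function slidingWindowByTokens(messages: Message[], budget = DEFAULT_SLIDING_WINDOW_TOKENS) {
  const kept: Message[] = []
  let used = 0
  for (let i = messages.length - 1; i >= 0; i--) {
    const cost = estimateTokens(messages[i].content)
    if (used + cost > budget) break
    kept.unshift(messages[i])
    used += cost
  }
  return kept
}
```
Full Conversation applies the same idea with the budget set to roughly 90% of the model's context window, leaving headroom for the response.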
## Notes
- Category: `blocks`
- Type: `memory`
- Memory is scoped per workspace—workflows in the same workspace share the memory store
- Use unique `conversationId` values to keep conversations separate (e.g., session IDs, user IDs, or UUIDs)
- System messages belong in the Agent block configuration, not in memory

View File

@@ -70,19 +70,6 @@ vi.mock('@/lib/core/utils/request', () => ({
generateRequestId: vi.fn().mockReturnValue('test-request-id'),
}))
vi.mock('@/app/api/workflows/[id]/execute/route', () => ({
createFilteredResult: vi.fn().mockImplementation((result: any) => ({
...result,
logs: undefined,
metadata: result.metadata
? {
...result.metadata,
workflowConnections: undefined,
}
: undefined,
})),
}))
describe('Chat Identifier API Route', () => {
const mockAddCorsHeaders = vi.fn().mockImplementation((response) => response)
const mockValidateChatAuth = vi.fn().mockResolvedValue({ authorized: true })

View File

@@ -206,7 +206,6 @@ export async function POST(
const { createStreamingResponse } = await import('@/lib/workflows/streaming/streaming')
const { SSE_HEADERS } = await import('@/lib/core/utils/sse')
const { createFilteredResult } = await import('@/app/api/workflows/[id]/execute/route')
const workflowInput: any = { input, conversationId }
if (files && Array.isArray(files) && files.length > 0) {
@@ -267,7 +266,6 @@ export async function POST(
isSecureMode: true,
workflowTriggerType: 'chat',
},
createFilteredResult,
executionId,
})

View File

@@ -1,54 +1,16 @@
import { db } from '@sim/db'
import { memory, workflowBlocks } from '@sim/db/schema'
import { memory, permissions, workspace } from '@sim/db/schema'
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { checkHybridAuth } from '@/lib/auth/hybrid'
import { generateRequestId } from '@/lib/core/utils/request'
import { createLogger } from '@/lib/logs/console/logger'
import { getWorkflowAccessContext } from '@/lib/workflows/utils'
const logger = createLogger('MemoryByIdAPI')
/**
* Parse memory key into conversationId and blockId
* Key format: conversationId:blockId
*/
function parseMemoryKey(key: string): { conversationId: string; blockId: string } | null {
const parts = key.split(':')
if (parts.length !== 2) {
return null
}
return {
conversationId: parts[0],
blockId: parts[1],
}
}
/**
* Lookup block name from block ID
*/
async function getBlockName(blockId: string, workflowId: string): Promise<string | undefined> {
try {
const result = await db
.select({ name: workflowBlocks.name })
.from(workflowBlocks)
.where(and(eq(workflowBlocks.id, blockId), eq(workflowBlocks.workflowId, workflowId)))
.limit(1)
if (result.length === 0) {
return undefined
}
return result[0].name
} catch (error) {
logger.error('Error looking up block name', { error, blockId, workflowId })
return undefined
}
}
const memoryQuerySchema = z.object({
workflowId: z.string().uuid('Invalid workflow ID format'),
workspaceId: z.string().uuid('Invalid workspace ID format'),
})
const agentMemoryDataSchema = z.object({
@@ -64,26 +26,56 @@ const memoryPutBodySchema = z.object({
data: z.union([agentMemoryDataSchema, genericMemoryDataSchema], {
errorMap: () => ({ message: 'Invalid memory data structure' }),
}),
workflowId: z.string().uuid('Invalid workflow ID format'),
workspaceId: z.string().uuid('Invalid workspace ID format'),
})
/**
* Validates authentication and workflow access for memory operations
* @param request - The incoming request
* @param workflowId - The workflow ID to check access for
* @param requestId - Request ID for logging
* @param action - 'read' for GET, 'write' for PUT/DELETE
* @returns Object with userId if successful, or error response if failed
*/
async function checkWorkspaceAccess(
workspaceId: string,
userId: string
): Promise<{ hasAccess: boolean; canWrite: boolean }> {
const [workspaceRow] = await db
.select({ ownerId: workspace.ownerId })
.from(workspace)
.where(eq(workspace.id, workspaceId))
.limit(1)
if (!workspaceRow) {
return { hasAccess: false, canWrite: false }
}
if (workspaceRow.ownerId === userId) {
return { hasAccess: true, canWrite: true }
}
const [permissionRow] = await db
.select({ permissionType: permissions.permissionType })
.from(permissions)
.where(
and(
eq(permissions.userId, userId),
eq(permissions.entityType, 'workspace'),
eq(permissions.entityId, workspaceId)
)
)
.limit(1)
if (!permissionRow) {
return { hasAccess: false, canWrite: false }
}
return {
hasAccess: true,
canWrite: permissionRow.permissionType === 'write' || permissionRow.permissionType === 'admin',
}
}
async function validateMemoryAccess(
request: NextRequest,
workflowId: string,
workspaceId: string,
requestId: string,
action: 'read' | 'write'
): Promise<{ userId: string } | { error: NextResponse }> {
const authResult = await checkHybridAuth(request, {
requireWorkflowId: false,
})
const authResult = await checkHybridAuth(request, { requireWorkflowId: false })
if (!authResult.success || !authResult.userId) {
logger.warn(`[${requestId}] Unauthorized memory ${action} attempt`)
return {
@@ -94,30 +86,20 @@ async function validateMemoryAccess(
}
}
const accessContext = await getWorkflowAccessContext(workflowId, authResult.userId)
if (!accessContext) {
logger.warn(`[${requestId}] Workflow ${workflowId} not found`)
const { hasAccess, canWrite } = await checkWorkspaceAccess(workspaceId, authResult.userId)
if (!hasAccess) {
return {
error: NextResponse.json(
{ success: false, error: { message: 'Workflow not found' } },
{ success: false, error: { message: 'Workspace not found' } },
{ status: 404 }
),
}
}
const { isOwner, workspacePermission } = accessContext
const hasAccess =
action === 'read'
? isOwner || workspacePermission !== null
: isOwner || workspacePermission === 'write' || workspacePermission === 'admin'
if (!hasAccess) {
logger.warn(
`[${requestId}] User ${authResult.userId} denied ${action} access to workflow ${workflowId}`
)
if (action === 'write' && !canWrite) {
return {
error: NextResponse.json(
{ success: false, error: { message: 'Access denied' } },
{ success: false, error: { message: 'Write access denied' } },
{ status: 403 }
),
}
@@ -129,40 +111,28 @@ async function validateMemoryAccess(
export const dynamic = 'force-dynamic'
export const runtime = 'nodejs'
/**
* GET handler for retrieving a specific memory by ID
*/
export async function GET(request: NextRequest, { params }: { params: Promise<{ id: string }> }) {
const requestId = generateRequestId()
const { id } = await params
try {
logger.info(`[${requestId}] Processing memory get request for ID: ${id}`)
const url = new URL(request.url)
const workflowId = url.searchParams.get('workflowId')
const validation = memoryQuerySchema.safeParse({ workflowId })
const workspaceId = url.searchParams.get('workspaceId')
const validation = memoryQuerySchema.safeParse({ workspaceId })
if (!validation.success) {
const errorMessage = validation.error.errors
.map((err) => `${err.path.join('.')}: ${err.message}`)
.join(', ')
logger.warn(`[${requestId}] Validation error: ${errorMessage}`)
return NextResponse.json(
{
success: false,
error: {
message: errorMessage,
},
},
{ success: false, error: { message: errorMessage } },
{ status: 400 }
)
}
const { workflowId: validatedWorkflowId } = validation.data
const { workspaceId: validatedWorkspaceId } = validation.data
const accessCheck = await validateMemoryAccess(request, validatedWorkflowId, requestId, 'read')
const accessCheck = await validateMemoryAccess(request, validatedWorkspaceId, requestId, 'read')
if ('error' in accessCheck) {
return accessCheck.error
}
@@ -170,72 +140,33 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
const memories = await db
.select()
.from(memory)
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
.orderBy(memory.createdAt)
.limit(1)
if (memories.length === 0) {
logger.warn(`[${requestId}] Memory not found: ${id} for workflow: ${validatedWorkflowId}`)
return NextResponse.json(
{
success: false,
error: {
message: 'Memory not found',
},
},
{ success: false, error: { message: 'Memory not found' } },
{ status: 404 }
)
}
const mem = memories[0]
const parsed = parseMemoryKey(mem.key)
let enrichedMemory
if (!parsed) {
enrichedMemory = {
conversationId: mem.key,
blockId: 'unknown',
blockName: 'unknown',
data: mem.data,
}
} else {
const { conversationId, blockId } = parsed
const blockName = (await getBlockName(blockId, validatedWorkflowId)) || 'unknown'
enrichedMemory = {
conversationId,
blockId,
blockName,
data: mem.data,
}
}
logger.info(
`[${requestId}] Memory retrieved successfully: ${id} for workflow: ${validatedWorkflowId}`
)
logger.info(`[${requestId}] Memory retrieved: ${id} for workspace: ${validatedWorkspaceId}`)
return NextResponse.json(
{
success: true,
data: enrichedMemory,
},
{ success: true, data: { conversationId: mem.key, data: mem.data } },
{ status: 200 }
)
} catch (error: any) {
logger.error(`[${requestId}] Error retrieving memory`, { error })
return NextResponse.json(
{
success: false,
error: {
message: error.message || 'Failed to retrieve memory',
},
},
{ success: false, error: { message: error.message || 'Failed to retrieve memory' } },
{ status: 500 }
)
}
}
/**
* DELETE handler for removing a specific memory
*/
export async function DELETE(
request: NextRequest,
{ params }: { params: Promise<{ id: string }> }
@@ -244,32 +175,28 @@ export async function DELETE(
const { id } = await params
try {
logger.info(`[${requestId}] Processing memory delete request for ID: ${id}`)
const url = new URL(request.url)
const workflowId = url.searchParams.get('workflowId')
const validation = memoryQuerySchema.safeParse({ workflowId })
const workspaceId = url.searchParams.get('workspaceId')
const validation = memoryQuerySchema.safeParse({ workspaceId })
if (!validation.success) {
const errorMessage = validation.error.errors
.map((err) => `${err.path.join('.')}: ${err.message}`)
.join(', ')
logger.warn(`[${requestId}] Validation error: ${errorMessage}`)
return NextResponse.json(
{
success: false,
error: {
message: errorMessage,
},
},
{ success: false, error: { message: errorMessage } },
{ status: 400 }
)
}
const { workflowId: validatedWorkflowId } = validation.data
const { workspaceId: validatedWorkspaceId } = validation.data
const accessCheck = await validateMemoryAccess(request, validatedWorkflowId, requestId, 'write')
const accessCheck = await validateMemoryAccess(
request,
validatedWorkspaceId,
requestId,
'write'
)
if ('error' in accessCheck) {
return accessCheck.error
}
@@ -277,61 +204,41 @@ export async function DELETE(
const existingMemory = await db
.select({ id: memory.id })
.from(memory)
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
.limit(1)
if (existingMemory.length === 0) {
logger.warn(`[${requestId}] Memory not found: ${id} for workflow: ${validatedWorkflowId}`)
return NextResponse.json(
{
success: false,
error: {
message: 'Memory not found',
},
},
{ success: false, error: { message: 'Memory not found' } },
{ status: 404 }
)
}
await db
.delete(memory)
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
logger.info(
`[${requestId}] Memory deleted successfully: ${id} for workflow: ${validatedWorkflowId}`
)
logger.info(`[${requestId}] Memory deleted: ${id} for workspace: ${validatedWorkspaceId}`)
return NextResponse.json(
{
success: true,
data: { message: 'Memory deleted successfully' },
},
{ success: true, data: { message: 'Memory deleted successfully' } },
{ status: 200 }
)
} catch (error: any) {
logger.error(`[${requestId}] Error deleting memory`, { error })
return NextResponse.json(
{
success: false,
error: {
message: error.message || 'Failed to delete memory',
},
},
{ success: false, error: { message: error.message || 'Failed to delete memory' } },
{ status: 500 }
)
}
}
/**
* PUT handler for updating a specific memory
*/
export async function PUT(request: NextRequest, { params }: { params: Promise<{ id: string }> }) {
const requestId = generateRequestId()
const { id } = await params
try {
logger.info(`[${requestId}] Processing memory update request for ID: ${id}`)
let validatedData
let validatedWorkflowId
let validatedWorkspaceId
try {
const body = await request.json()
const validation = memoryPutBodySchema.safeParse(body)
@@ -340,34 +247,27 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{
const errorMessage = validation.error.errors
.map((err) => `${err.path.join('.')}: ${err.message}`)
.join(', ')
logger.warn(`[${requestId}] Validation error: ${errorMessage}`)
return NextResponse.json(
{
success: false,
error: {
message: `Invalid request body: ${errorMessage}`,
},
},
{ success: false, error: { message: `Invalid request body: ${errorMessage}` } },
{ status: 400 }
)
}
validatedData = validation.data.data
validatedWorkflowId = validation.data.workflowId
} catch (error: any) {
logger.warn(`[${requestId}] Failed to parse request body: ${error.message}`)
validatedWorkspaceId = validation.data.workspaceId
} catch {
return NextResponse.json(
{
success: false,
error: {
message: 'Invalid JSON in request body',
},
},
{ success: false, error: { message: 'Invalid JSON in request body' } },
{ status: 400 }
)
}
const accessCheck = await validateMemoryAccess(request, validatedWorkflowId, requestId, 'write')
const accessCheck = await validateMemoryAccess(
request,
validatedWorkspaceId,
requestId,
'write'
)
if ('error' in accessCheck) {
return accessCheck.error
}
@@ -375,18 +275,12 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{
const existingMemories = await db
.select()
.from(memory)
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
.limit(1)
if (existingMemories.length === 0) {
logger.warn(`[${requestId}] Memory not found: ${id} for workflow: ${validatedWorkflowId}`)
return NextResponse.json(
{
success: false,
error: {
message: 'Memory not found',
},
},
{ success: false, error: { message: 'Memory not found' } },
{ status: 404 }
)
}
@@ -396,14 +290,8 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{
const errorMessage = agentValidation.error.errors
.map((err) => `${err.path.join('.')}: ${err.message}`)
.join(', ')
logger.warn(`[${requestId}] Agent memory validation error: ${errorMessage}`)
return NextResponse.json(
{
success: false,
error: {
message: `Invalid agent memory data: ${errorMessage}`,
},
},
{ success: false, error: { message: `Invalid agent memory data: ${errorMessage}` } },
{ status: 400 }
)
}
@@ -411,59 +299,26 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{
const now = new Date()
await db
.update(memory)
.set({
data: validatedData,
updatedAt: now,
})
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
.set({ data: validatedData, updatedAt: now })
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
const updatedMemories = await db
.select()
.from(memory)
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
.limit(1)
const mem = updatedMemories[0]
const parsed = parseMemoryKey(mem.key)
let enrichedMemory
if (!parsed) {
enrichedMemory = {
conversationId: mem.key,
blockId: 'unknown',
blockName: 'unknown',
data: mem.data,
}
} else {
const { conversationId, blockId } = parsed
const blockName = (await getBlockName(blockId, validatedWorkflowId)) || 'unknown'
enrichedMemory = {
conversationId,
blockId,
blockName,
data: mem.data,
}
}
logger.info(
`[${requestId}] Memory updated successfully: ${id} for workflow: ${validatedWorkflowId}`
)
logger.info(`[${requestId}] Memory updated: ${id} for workspace: ${validatedWorkspaceId}`)
return NextResponse.json(
{
success: true,
data: enrichedMemory,
},
{ success: true, data: { conversationId: mem.key, data: mem.data } },
{ status: 200 }
)
} catch (error: any) {
logger.error(`[${requestId}] Error updating memory`, { error })
return NextResponse.json(
{
success: false,
error: {
message: error.message || 'Failed to update memory',
},
},
{ success: false, error: { message: error.message || 'Failed to update memory' } },
{ status: 500 }
)
}

View File

@@ -1,42 +1,56 @@
import { db } from '@sim/db'
import { memory, workflowBlocks } from '@sim/db/schema'
import { and, eq, inArray, isNull, like } from 'drizzle-orm'
import { memory, permissions, workspace } from '@sim/db/schema'
import { and, eq, isNull, like } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { checkHybridAuth } from '@/lib/auth/hybrid'
import { generateRequestId } from '@/lib/core/utils/request'
import { createLogger } from '@/lib/logs/console/logger'
import { getWorkflowAccessContext } from '@/lib/workflows/utils'
const logger = createLogger('MemoryAPI')
export const dynamic = 'force-dynamic'
export const runtime = 'nodejs'
/**
* Parse memory key into conversationId and blockId
* Key format: conversationId:blockId
* @param key The memory key to parse
* @returns Object with conversationId and blockId, or null if invalid
*/
function parseMemoryKey(key: string): { conversationId: string; blockId: string } | null {
const parts = key.split(':')
if (parts.length !== 2) {
return null
async function checkWorkspaceAccess(
workspaceId: string,
userId: string
): Promise<{ hasAccess: boolean; canWrite: boolean }> {
const [workspaceRow] = await db
.select({ ownerId: workspace.ownerId })
.from(workspace)
.where(eq(workspace.id, workspaceId))
.limit(1)
if (!workspaceRow) {
return { hasAccess: false, canWrite: false }
}
if (workspaceRow.ownerId === userId) {
return { hasAccess: true, canWrite: true }
}
const [permissionRow] = await db
.select({ permissionType: permissions.permissionType })
.from(permissions)
.where(
and(
eq(permissions.userId, userId),
eq(permissions.entityType, 'workspace'),
eq(permissions.entityId, workspaceId)
)
)
.limit(1)
if (!permissionRow) {
return { hasAccess: false, canWrite: false }
}
return {
conversationId: parts[0],
blockId: parts[1],
hasAccess: true,
canWrite: permissionRow.permissionType === 'write' || permissionRow.permissionType === 'admin',
}
}
/**
* GET handler for searching and retrieving memories
* Supports query parameters:
* - query: Search string for memory keys
* - type: Filter by memory type
* - limit: Maximum number of results (default: 50)
* - workflowId: Filter by workflow ID (required)
*/
export async function GET(request: NextRequest) {
const requestId = generateRequestId()
@@ -45,102 +59,32 @@ export async function GET(request: NextRequest) {
if (!authResult.success || !authResult.userId) {
logger.warn(`[${requestId}] Unauthorized memory access attempt`)
return NextResponse.json(
{
success: false,
error: {
message: authResult.error || 'Authentication required',
},
},
{ success: false, error: { message: authResult.error || 'Authentication required' } },
{ status: 401 }
)
}
logger.info(`[${requestId}] Processing memory search request`)
const url = new URL(request.url)
const workflowId = url.searchParams.get('workflowId')
const workspaceId = url.searchParams.get('workspaceId')
const searchQuery = url.searchParams.get('query')
const blockNameFilter = url.searchParams.get('blockName')
const limit = Number.parseInt(url.searchParams.get('limit') || '50')
if (!workflowId) {
logger.warn(`[${requestId}] Missing required parameter: workflowId`)
if (!workspaceId) {
return NextResponse.json(
{
success: false,
error: {
message: 'workflowId parameter is required',
},
},
{ success: false, error: { message: 'workspaceId parameter is required' } },
{ status: 400 }
)
}
const accessContext = await getWorkflowAccessContext(workflowId, authResult.userId)
if (!accessContext) {
logger.warn(`[${requestId}] Workflow ${workflowId} not found for user ${authResult.userId}`)
const { hasAccess } = await checkWorkspaceAccess(workspaceId, authResult.userId)
if (!hasAccess) {
return NextResponse.json(
{
success: false,
error: {
message: 'Workflow not found',
},
},
{ status: 404 }
)
}
const { workspacePermission, isOwner } = accessContext
if (!isOwner && !workspacePermission) {
logger.warn(
`[${requestId}] User ${authResult.userId} denied access to workflow ${workflowId}`
)
return NextResponse.json(
{
success: false,
error: {
message: 'Access denied to this workflow',
},
},
{ success: false, error: { message: 'Access denied to this workspace' } },
{ status: 403 }
)
}
logger.info(
`[${requestId}] User ${authResult.userId} (${authResult.authType}) accessing memories for workflow ${workflowId}`
)
const conditions = []
conditions.push(isNull(memory.deletedAt))
conditions.push(eq(memory.workflowId, workflowId))
let blockIdsToFilter: string[] | null = null
if (blockNameFilter) {
const blocks = await db
.select({ id: workflowBlocks.id })
.from(workflowBlocks)
.where(
and(eq(workflowBlocks.workflowId, workflowId), eq(workflowBlocks.name, blockNameFilter))
)
if (blocks.length === 0) {
logger.info(
`[${requestId}] No blocks found with name "${blockNameFilter}" for workflow: ${workflowId}`
)
return NextResponse.json(
{
success: true,
data: { memories: [] },
},
{ status: 200 }
)
}
blockIdsToFilter = blocks.map((b) => b.id)
}
const conditions = [isNull(memory.deletedAt), eq(memory.workspaceId, workspaceId)]
if (searchQuery) {
conditions.push(like(memory.key, `%${searchQuery}%`))
@@ -153,95 +97,27 @@ export async function GET(request: NextRequest) {
.orderBy(memory.createdAt)
.limit(limit)
const filteredMemories = blockIdsToFilter
? rawMemories.filter((mem) => {
const parsed = parseMemoryKey(mem.key)
return parsed && blockIdsToFilter.includes(parsed.blockId)
})
: rawMemories
const blockIds = new Set<string>()
const parsedKeys = new Map<string, { conversationId: string; blockId: string }>()
for (const mem of filteredMemories) {
const parsed = parseMemoryKey(mem.key)
if (parsed) {
blockIds.add(parsed.blockId)
parsedKeys.set(mem.key, parsed)
}
}
const blockNameMap = new Map<string, string>()
if (blockIds.size > 0) {
const blocks = await db
.select({ id: workflowBlocks.id, name: workflowBlocks.name })
.from(workflowBlocks)
.where(
and(
eq(workflowBlocks.workflowId, workflowId),
inArray(workflowBlocks.id, Array.from(blockIds))
)
)
for (const block of blocks) {
blockNameMap.set(block.id, block.name)
}
}
const enrichedMemories = filteredMemories.map((mem) => {
const parsed = parsedKeys.get(mem.key)
if (!parsed) {
return {
const enrichedMemories = rawMemories.map((mem) => ({
conversationId: mem.key,
blockId: 'unknown',
blockName: 'unknown',
data: mem.data,
}
}
const { conversationId, blockId } = parsed
const blockName = blockNameMap.get(blockId) || 'unknown'
return {
conversationId,
blockId,
blockName,
data: mem.data,
}
})
}))
logger.info(
`[${requestId}] Found ${enrichedMemories.length} memories for workflow: ${workflowId}`
`[${requestId}] Found ${enrichedMemories.length} memories for workspace: ${workspaceId}`
)
return NextResponse.json(
{
success: true,
data: { memories: enrichedMemories },
},
{ success: true, data: { memories: enrichedMemories } },
{ status: 200 }
)
} catch (error: any) {
logger.error(`[${requestId}] Error searching memories`, { error })
return NextResponse.json(
{
success: false,
error: {
message: error.message || 'Failed to search memories',
},
},
{ success: false, error: { message: error.message || 'Failed to search memories' } },
{ status: 500 }
)
}
}
/**
* POST handler for creating new memories
* Requires:
* - key: Unique identifier for the memory (within workflow scope)
* - type: Memory type ('agent')
* - data: Memory content (agent message with role and content)
* - workflowId: ID of the workflow this memory belongs to
*/
export async function POST(request: NextRequest) {
const requestId = generateRequestId()
@@ -250,123 +126,63 @@ export async function POST(request: NextRequest) {
if (!authResult.success || !authResult.userId) {
logger.warn(`[${requestId}] Unauthorized memory creation attempt`)
return NextResponse.json(
{
success: false,
error: {
message: authResult.error || 'Authentication required',
},
},
{ success: false, error: { message: authResult.error || 'Authentication required' } },
{ status: 401 }
)
}
logger.info(`[${requestId}] Processing memory creation request`)
const body = await request.json()
const { key, data, workflowId } = body
const { key, data, workspaceId } = body
if (!key) {
logger.warn(`[${requestId}] Missing required field: key`)
return NextResponse.json(
{
success: false,
error: {
message: 'Memory key is required',
},
},
{ success: false, error: { message: 'Memory key is required' } },
{ status: 400 }
)
}
if (!data) {
logger.warn(`[${requestId}] Missing required field: data`)
return NextResponse.json(
{
success: false,
error: {
message: 'Memory data is required',
},
},
{ success: false, error: { message: 'Memory data is required' } },
{ status: 400 }
)
}
if (!workflowId) {
logger.warn(`[${requestId}] Missing required field: workflowId`)
if (!workspaceId) {
return NextResponse.json(
{
success: false,
error: {
message: 'workflowId is required',
},
},
{ success: false, error: { message: 'workspaceId is required' } },
{ status: 400 }
)
}
const accessContext = await getWorkflowAccessContext(workflowId, authResult.userId)
if (!accessContext) {
logger.warn(`[${requestId}] Workflow ${workflowId} not found for user ${authResult.userId}`)
const { hasAccess, canWrite } = await checkWorkspaceAccess(workspaceId, authResult.userId)
if (!hasAccess) {
return NextResponse.json(
{
success: false,
error: {
message: 'Workflow not found',
},
},
{ success: false, error: { message: 'Workspace not found' } },
{ status: 404 }
)
}
const { workspacePermission, isOwner } = accessContext
const hasWritePermission =
isOwner || workspacePermission === 'write' || workspacePermission === 'admin'
if (!hasWritePermission) {
logger.warn(
`[${requestId}] User ${authResult.userId} denied write access to workflow ${workflowId}`
)
if (!canWrite) {
return NextResponse.json(
{
success: false,
error: {
message: 'Write access denied to this workflow',
},
},
{ success: false, error: { message: 'Write access denied to this workspace' } },
{ status: 403 }
)
}
logger.info(
`[${requestId}] User ${authResult.userId} (${authResult.authType}) creating memory for workflow ${workflowId}`
)
const dataToValidate = Array.isArray(data) ? data : [data]
for (const msg of dataToValidate) {
if (!msg || typeof msg !== 'object' || !msg.role || !msg.content) {
logger.warn(`[${requestId}] Missing required message fields`)
return NextResponse.json(
{
success: false,
error: {
message: 'Memory requires messages with role and content',
},
},
{ success: false, error: { message: 'Memory requires messages with role and content' } },
{ status: 400 }
)
}
if (!['user', 'assistant', 'system'].includes(msg.role)) {
logger.warn(`[${requestId}] Invalid message role: ${msg.role}`)
return NextResponse.json(
{
success: false,
error: {
message: 'Message role must be user, assistant, or system',
},
},
{ success: false, error: { message: 'Message role must be user, assistant, or system' } },
{ status: 400 }
)
}
@@ -382,114 +198,59 @@ export async function POST(request: NextRequest) {
.insert(memory)
.values({
id,
workflowId,
workspaceId,
key,
data: initialData,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: [memory.workflowId, memory.key],
target: [memory.workspaceId, memory.key],
set: {
data: sql`${memory.data} || ${JSON.stringify(initialData)}::jsonb`,
updatedAt: now,
},
})
logger.info(
`[${requestId}] Memory operation successful (atomic): ${key} for workflow: ${workflowId}`
)
logger.info(`[${requestId}] Memory operation successful: ${key} for workspace: ${workspaceId}`)
const allMemories = await db
.select()
.from(memory)
.where(and(eq(memory.key, key), eq(memory.workflowId, workflowId), isNull(memory.deletedAt)))
.where(
and(eq(memory.key, key), eq(memory.workspaceId, workspaceId), isNull(memory.deletedAt))
)
.orderBy(memory.createdAt)
if (allMemories.length === 0) {
logger.warn(`[${requestId}] No memories found after creating/updating memory: ${key}`)
return NextResponse.json(
{
success: false,
error: {
message: 'Failed to retrieve memory after creation/update',
},
},
{ success: false, error: { message: 'Failed to retrieve memory after creation/update' } },
{ status: 500 }
)
}
const memoryRecord = allMemories[0]
const parsed = parseMemoryKey(memoryRecord.key)
let enrichedMemory
if (!parsed) {
enrichedMemory = {
conversationId: memoryRecord.key,
blockId: 'unknown',
blockName: 'unknown',
data: memoryRecord.data,
}
} else {
const { conversationId, blockId } = parsed
const blockName = await (async () => {
const blocks = await db
.select({ name: workflowBlocks.name })
.from(workflowBlocks)
.where(and(eq(workflowBlocks.id, blockId), eq(workflowBlocks.workflowId, workflowId)))
.limit(1)
return blocks.length > 0 ? blocks[0].name : 'unknown'
})()
enrichedMemory = {
conversationId,
blockId,
blockName,
data: memoryRecord.data,
}
}
return NextResponse.json(
{
success: true,
data: enrichedMemory,
},
{ success: true, data: { conversationId: memoryRecord.key, data: memoryRecord.data } },
{ status: 200 }
)
} catch (error: any) {
if (error.code === '23505') {
logger.warn(`[${requestId}] Duplicate key violation`)
return NextResponse.json(
{
success: false,
error: {
message: 'Memory with this key already exists',
},
},
{ success: false, error: { message: 'Memory with this key already exists' } },
{ status: 409 }
)
}
logger.error(`[${requestId}] Error creating memory`, { error })
return NextResponse.json(
{
success: false,
error: {
message: error.message || 'Failed to create memory',
},
},
{ success: false, error: { message: error.message || 'Failed to create memory' } },
{ status: 500 }
)
}
}
/**
* DELETE handler for pattern-based memory deletion
* Supports query parameters:
* - workflowId: Required
* - conversationId: Optional - delete all memories for this conversation
* - blockId: Optional - delete all memories for this block
* - blockName: Optional - delete all memories for blocks with this name
*/
export async function DELETE(request: NextRequest) {
const requestId = generateRequestId()
@@ -498,175 +259,52 @@ export async function DELETE(request: NextRequest) {
if (!authResult.success || !authResult.userId) {
logger.warn(`[${requestId}] Unauthorized memory deletion attempt`)
return NextResponse.json(
{
success: false,
error: {
message: authResult.error || 'Authentication required',
},
},
{ success: false, error: { message: authResult.error || 'Authentication required' } },
{ status: 401 }
)
}
logger.info(`[${requestId}] Processing memory deletion request`)
const url = new URL(request.url)
const workflowId = url.searchParams.get('workflowId')
const workspaceId = url.searchParams.get('workspaceId')
const conversationId = url.searchParams.get('conversationId')
const blockId = url.searchParams.get('blockId')
const blockName = url.searchParams.get('blockName')
if (!workflowId) {
logger.warn(`[${requestId}] Missing required parameter: workflowId`)
if (!workspaceId) {
return NextResponse.json(
{
success: false,
error: {
message: 'workflowId parameter is required',
},
},
{ success: false, error: { message: 'workspaceId parameter is required' } },
{ status: 400 }
)
}
if (!conversationId && !blockId && !blockName) {
logger.warn(`[${requestId}] No filter parameters provided`)
if (!conversationId) {
return NextResponse.json(
{
success: false,
error: {
message: 'At least one of conversationId, blockId, or blockName must be provided',
},
},
{ success: false, error: { message: 'conversationId must be provided' } },
{ status: 400 }
)
}
const accessContext = await getWorkflowAccessContext(workflowId, authResult.userId)
if (!accessContext) {
logger.warn(`[${requestId}] Workflow ${workflowId} not found for user ${authResult.userId}`)
const { hasAccess, canWrite } = await checkWorkspaceAccess(workspaceId, authResult.userId)
if (!hasAccess) {
return NextResponse.json(
{
success: false,
error: {
message: 'Workflow not found',
},
},
{ success: false, error: { message: 'Workspace not found' } },
{ status: 404 }
)
}
const { workspacePermission, isOwner } = accessContext
const hasWritePermission =
isOwner || workspacePermission === 'write' || workspacePermission === 'admin'
if (!hasWritePermission) {
logger.warn(
`[${requestId}] User ${authResult.userId} denied delete access to workflow ${workflowId}`
)
if (!canWrite) {
return NextResponse.json(
{
success: false,
error: {
message: 'Write access denied to this workflow',
},
},
{ success: false, error: { message: 'Write access denied to this workspace' } },
{ status: 403 }
)
}
logger.info(
`[${requestId}] User ${authResult.userId} (${authResult.authType}) deleting memories for workflow ${workflowId}`
)
let deletedCount = 0
if (conversationId && blockId) {
const key = `${conversationId}:${blockId}`
const result = await db
.delete(memory)
.where(and(eq(memory.key, key), eq(memory.workflowId, workflowId)))
.where(and(eq(memory.key, conversationId), eq(memory.workspaceId, workspaceId)))
.returning({ id: memory.id })
deletedCount = result.length
} else if (conversationId && blockName) {
const blocks = await db
.select({ id: workflowBlocks.id })
.from(workflowBlocks)
.where(and(eq(workflowBlocks.workflowId, workflowId), eq(workflowBlocks.name, blockName)))
const deletedCount = result.length
if (blocks.length === 0) {
return NextResponse.json(
{
success: true,
data: {
message: `No blocks found with name "${blockName}"`,
deletedCount: 0,
},
},
{ status: 200 }
)
}
for (const block of blocks) {
const key = `${conversationId}:${block.id}`
const result = await db
.delete(memory)
.where(and(eq(memory.key, key), eq(memory.workflowId, workflowId)))
.returning({ id: memory.id })
deletedCount += result.length
}
} else if (conversationId) {
const pattern = `${conversationId}:%`
const result = await db
.delete(memory)
.where(and(like(memory.key, pattern), eq(memory.workflowId, workflowId)))
.returning({ id: memory.id })
deletedCount = result.length
} else if (blockId) {
const pattern = `%:${blockId}`
const result = await db
.delete(memory)
.where(and(like(memory.key, pattern), eq(memory.workflowId, workflowId)))
.returning({ id: memory.id })
deletedCount = result.length
} else if (blockName) {
const blocks = await db
.select({ id: workflowBlocks.id })
.from(workflowBlocks)
.where(and(eq(workflowBlocks.workflowId, workflowId), eq(workflowBlocks.name, blockName)))
if (blocks.length === 0) {
return NextResponse.json(
{
success: true,
data: {
message: `No blocks found with name "${blockName}"`,
deletedCount: 0,
},
},
{ status: 200 }
)
}
for (const block of blocks) {
const pattern = `%:${block.id}`
const result = await db
.delete(memory)
.where(and(like(memory.key, pattern), eq(memory.workflowId, workflowId)))
.returning({ id: memory.id })
deletedCount += result.length
}
}
logger.info(
`[${requestId}] Successfully deleted ${deletedCount} memories for workflow: ${workflowId}`
)
logger.info(`[${requestId}] Deleted ${deletedCount} memories for workspace: ${workspaceId}`)
return NextResponse.json(
{
success: true,
@@ -683,12 +321,7 @@ export async function DELETE(request: NextRequest) {
} catch (error: any) {
logger.error(`[${requestId}] Error deleting memories`, { error })
return NextResponse.json(
{
success: false,
error: {
message: error.message || 'Failed to delete memories',
},
},
{ success: false, error: { message: error.message || 'Failed to delete memories' } },
{ status: 500 }
)
}

View File

@@ -6,12 +6,29 @@ const logger = createLogger('OpenRouterModelsAPI')
interface OpenRouterModel {
id: string
context_length?: number
supported_parameters?: string[]
pricing?: {
prompt?: string
completion?: string
}
}
interface OpenRouterResponse {
data: OpenRouterModel[]
}
export interface OpenRouterModelInfo {
id: string
contextLength?: number
supportsStructuredOutputs?: boolean
supportsTools?: boolean
pricing?: {
input: number
output: number
}
}
export async function GET(_request: NextRequest) {
try {
const response = await fetch('https://openrouter.ai/api/v1/models', {
@@ -24,23 +41,51 @@ export async function GET(_request: NextRequest) {
status: response.status,
statusText: response.statusText,
})
return NextResponse.json({ models: [] })
return NextResponse.json({ models: [], modelInfo: {} })
}
const data = (await response.json()) as OpenRouterResponse
const allModels = Array.from(new Set(data.data?.map((model) => `openrouter/${model.id}`) ?? []))
const models = filterBlacklistedModels(allModels)
const modelInfo: Record<string, OpenRouterModelInfo> = {}
const allModels: string[] = []
for (const model of data.data ?? []) {
const modelId = `openrouter/${model.id}`
allModels.push(modelId)
const supportedParams = model.supported_parameters ?? []
modelInfo[modelId] = {
id: modelId,
contextLength: model.context_length,
supportsStructuredOutputs: supportedParams.includes('structured_outputs'),
supportsTools: supportedParams.includes('tools'),
pricing: model.pricing
? {
input: Number.parseFloat(model.pricing.prompt ?? '0') * 1000000,
output: Number.parseFloat(model.pricing.completion ?? '0') * 1000000,
}
: undefined,
}
}
const uniqueModels = Array.from(new Set(allModels))
const models = filterBlacklistedModels(uniqueModels)
const structuredOutputCount = Object.values(modelInfo).filter(
(m) => m.supportsStructuredOutputs
).length
logger.info('Successfully fetched OpenRouter models', {
count: models.length,
filtered: allModels.length - models.length,
filtered: uniqueModels.length - models.length,
withStructuredOutputs: structuredOutputCount,
})
return NextResponse.json({ models })
return NextResponse.json({ models, modelInfo })
} catch (error) {
logger.error('Error fetching OpenRouter models', {
error: error instanceof Error ? error.message : 'Unknown error',
})
return NextResponse.json({ models: [] })
return NextResponse.json({ models: [], modelInfo: {} })
}
}
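As a worked example of the pricing conversion above (a sketch with made-up prices; OpenRouter's `pricing` fields are decimal strings, presumably USD per token given the scaling):
```ts
// Hypothetical pricing strings in the shape returned by the models endpoint.
const pricing = { prompt: '0.000003', completion: '0.000015' }

// Multiplying a per-token price by 1,000,000 gives the per-million-token figure
// stored in OpenRouterModelInfo.pricing.
const input = Number.parseFloat(pricing.prompt ?? '0') * 1000000    // ≈ 3 USD per 1M input tokens
const output = Number.parseFloat(pricing.completion ?? '0') * 1000000 // ≈ 15 USD per 1M output tokens

console.log({ input, output })
```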

View File

@@ -49,126 +49,6 @@ const ExecuteWorkflowSchema = z.object({
export const runtime = 'nodejs'
export const dynamic = 'force-dynamic'
/**
* Execute workflow with streaming support - used by chat and other streaming endpoints
*
* This function assumes preprocessing has already been completed.
* Callers must run preprocessExecution() first to validate workflow, check usage limits,
* and resolve actor before calling this function.
*
* This is a wrapper function that:
* - Supports streaming callbacks (onStream, onBlockComplete)
* - Returns ExecutionResult instead of NextResponse
* - Handles pause/resume logic
*
* Used by:
* - Chat execution (/api/chat/[identifier]/route.ts)
* - Streaming responses (lib/workflows/streaming.ts)
*/
export async function executeWorkflow(
workflow: any,
requestId: string,
input: any | undefined,
actorUserId: string,
streamConfig?: {
enabled: boolean
selectedOutputs?: string[]
isSecureMode?: boolean
workflowTriggerType?: 'api' | 'chat'
onStream?: (streamingExec: any) => Promise<void>
onBlockComplete?: (blockId: string, output: any) => Promise<void>
skipLoggingComplete?: boolean
},
providedExecutionId?: string
): Promise<any> {
const workflowId = workflow.id
const executionId = providedExecutionId || uuidv4()
const triggerType = streamConfig?.workflowTriggerType || 'api'
const loggingSession = new LoggingSession(workflowId, executionId, triggerType, requestId)
try {
const metadata: ExecutionMetadata = {
requestId,
executionId,
workflowId,
workspaceId: workflow.workspaceId,
userId: actorUserId,
workflowUserId: workflow.userId,
triggerType,
useDraftState: false,
startTime: new Date().toISOString(),
isClientSession: false,
}
const snapshot = new ExecutionSnapshot(
metadata,
workflow,
input,
workflow.variables || {},
streamConfig?.selectedOutputs || []
)
const result = await executeWorkflowCore({
snapshot,
callbacks: {
onStream: streamConfig?.onStream,
onBlockComplete: streamConfig?.onBlockComplete
? async (blockId: string, _blockName: string, _blockType: string, output: any) => {
await streamConfig.onBlockComplete!(blockId, output)
}
: undefined,
},
loggingSession,
})
if (result.status === 'paused') {
if (!result.snapshotSeed) {
logger.error(`[${requestId}] Missing snapshot seed for paused execution`, {
executionId,
})
} else {
await PauseResumeManager.persistPauseResult({
workflowId,
executionId,
pausePoints: result.pausePoints || [],
snapshotSeed: result.snapshotSeed,
executorUserId: result.metadata?.userId,
})
}
} else {
await PauseResumeManager.processQueuedResumes(executionId)
}
if (streamConfig?.skipLoggingComplete) {
return {
...result,
_streamingMetadata: {
loggingSession,
processedInput: input,
},
}
}
return result
} catch (error: any) {
logger.error(`[${requestId}] Workflow execution failed:`, error)
throw error
}
}
export function createFilteredResult(result: any) {
return {
...result,
logs: undefined,
metadata: result.metadata
? {
...result.metadata,
workflowConnections: undefined,
}
: undefined,
}
}
function resolveOutputIds(
selectedOutputs: string[] | undefined,
blocks: Record<string, any>
@@ -606,7 +486,6 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
isSecureMode: false,
workflowTriggerType: triggerType === 'chat' ? 'chat' : 'api',
},
createFilteredResult,
executionId,
})
@@ -743,14 +622,17 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
const onStream = async (streamingExec: StreamingExecution) => {
const blockId = (streamingExec.execution as any).blockId
const reader = streamingExec.stream.getReader()
const decoder = new TextDecoder()
let chunkCount = 0
try {
while (true) {
const { done, value } = await reader.read()
if (done) break
chunkCount++
const chunk = decoder.decode(value, { stream: true })
sendEvent({
type: 'stream:chunk',

View File

@@ -314,22 +314,9 @@ export function useChatStreaming() {
let finalContent = accumulatedText
if (formattedOutputs.length > 0) {
const trimmedStreamingContent = accumulatedText.trim()
const uniqueOutputs = formattedOutputs.filter((output) => {
const trimmedOutput = output.trim()
if (!trimmedOutput) return false
// Skip outputs that exactly match the streamed content to avoid duplication
if (trimmedStreamingContent && trimmedOutput === trimmedStreamingContent) {
return false
}
return true
})
if (uniqueOutputs.length > 0) {
const combinedOutputs = uniqueOutputs.join('\n\n')
const nonEmptyOutputs = formattedOutputs.filter((output) => output.trim())
if (nonEmptyOutputs.length > 0) {
const combinedOutputs = nonEmptyOutputs.join('\n\n')
finalContent = finalContent
? `${finalContent.trim()}\n\n${combinedOutputs}`
: combinedOutputs

View File

@@ -16,6 +16,7 @@ const logger = createLogger('ProviderModelsLoader')
function useSyncProvider(provider: ProviderName) {
const setProviderModels = useProvidersStore((state) => state.setProviderModels)
const setProviderLoading = useProvidersStore((state) => state.setProviderLoading)
const setOpenRouterModelInfo = useProvidersStore((state) => state.setOpenRouterModelInfo)
const { data, isLoading, isFetching, error } = useProviderModels(provider)
useEffect(() => {
@@ -27,18 +28,21 @@ function useSyncProvider(provider: ProviderName) {
try {
if (provider === 'ollama') {
updateOllamaProviderModels(data)
updateOllamaProviderModels(data.models)
} else if (provider === 'vllm') {
updateVLLMProviderModels(data)
updateVLLMProviderModels(data.models)
} else if (provider === 'openrouter') {
void updateOpenRouterProviderModels(data)
void updateOpenRouterProviderModels(data.models)
if (data.modelInfo) {
setOpenRouterModelInfo(data.modelInfo)
}
}
} catch (syncError) {
logger.warn(`Failed to sync provider definitions for ${provider}`, syncError as Error)
}
setProviderModels(provider, data)
}, [provider, data, setProviderModels])
setProviderModels(provider, data.models)
}, [provider, data, setProviderModels, setOpenRouterModelInfo])
useEffect(() => {
if (error) {

View File

@@ -398,6 +398,7 @@ export function useWorkflowExecution() {
}
const streamCompletionTimes = new Map<string, number>()
const processedFirstChunk = new Set<string>()
const onStream = async (streamingExecution: StreamingExecution) => {
const promise = (async () => {
@@ -405,16 +406,14 @@ export function useWorkflowExecution() {
const reader = streamingExecution.stream.getReader()
const blockId = (streamingExecution.execution as any)?.blockId
let isFirstChunk = true
if (blockId) {
if (blockId && !streamedContent.has(blockId)) {
streamedContent.set(blockId, '')
}
try {
while (true) {
const { done, value } = await reader.read()
if (done) {
// Record when this stream completed
if (blockId) {
streamCompletionTimes.set(blockId, Date.now())
}
@@ -425,13 +424,12 @@ export function useWorkflowExecution() {
streamedContent.set(blockId, (streamedContent.get(blockId) || '') + chunk)
}
// Add separator before first chunk if this isn't the first block
let chunkToSend = chunk
if (isFirstChunk && streamedContent.size > 1) {
if (blockId && !processedFirstChunk.has(blockId)) {
processedFirstChunk.add(blockId)
if (streamedContent.size > 1) {
chunkToSend = `\n\n${chunk}`
isFirstChunk = false
} else if (isFirstChunk) {
isFirstChunk = false
}
}
controller.enqueue(encodeSSE({ blockId, chunk: chunkToSend }))

View File

@@ -41,17 +41,6 @@ export const MemoryBlock: BlockConfig = {
},
required: true,
},
{
id: 'blockId',
title: 'Block ID',
type: 'short-input',
placeholder: 'Enter block ID (optional, defaults to current block)',
condition: {
field: 'operation',
value: 'add',
},
required: false,
},
{
id: 'id',
title: 'Conversation ID',
@@ -61,29 +50,7 @@ export const MemoryBlock: BlockConfig = {
field: 'operation',
value: 'get',
},
required: false,
},
{
id: 'blockId',
title: 'Block ID',
type: 'short-input',
placeholder: 'Enter block ID (optional)',
condition: {
field: 'operation',
value: 'get',
},
required: false,
},
{
id: 'blockName',
title: 'Block Name',
type: 'short-input',
placeholder: 'Enter block name (optional)',
condition: {
field: 'operation',
value: 'get',
},
required: false,
required: true,
},
{
id: 'id',
@@ -94,29 +61,7 @@ export const MemoryBlock: BlockConfig = {
field: 'operation',
value: 'delete',
},
required: false,
},
{
id: 'blockId',
title: 'Block ID',
type: 'short-input',
placeholder: 'Enter block ID (optional)',
condition: {
field: 'operation',
value: 'delete',
},
required: false,
},
{
id: 'blockName',
title: 'Block Name',
type: 'short-input',
placeholder: 'Enter block name (optional)',
condition: {
field: 'operation',
value: 'delete',
},
required: false,
required: true,
},
{
id: 'role',
@@ -186,10 +131,8 @@ export const MemoryBlock: BlockConfig = {
}
if (params.operation === 'get' || params.operation === 'delete') {
if (!conversationId && !params.blockId && !params.blockName) {
errors.push(
`At least one of ID, blockId, or blockName is required for ${params.operation} operation`
)
if (!conversationId) {
errors.push(`Conversation ID is required for ${params.operation} operation`)
}
}
@@ -200,33 +143,26 @@ export const MemoryBlock: BlockConfig = {
const baseResult: Record<string, any> = {}
if (params.operation === 'add') {
const result: Record<string, any> = {
return {
...baseResult,
conversationId: conversationId,
role: params.role,
content: params.content,
}
if (params.blockId) {
result.blockId = params.blockId
}
return result
}
if (params.operation === 'get') {
const result: Record<string, any> = { ...baseResult }
if (conversationId) result.conversationId = conversationId
if (params.blockId) result.blockId = params.blockId
if (params.blockName) result.blockName = params.blockName
return result
return {
...baseResult,
conversationId: conversationId,
}
}
if (params.operation === 'delete') {
const result: Record<string, any> = { ...baseResult }
if (conversationId) result.conversationId = conversationId
if (params.blockId) result.blockId = params.blockId
if (params.blockName) result.blockName = params.blockName
return result
return {
...baseResult,
conversationId: conversationId,
}
}
return baseResult
@@ -235,10 +171,8 @@ export const MemoryBlock: BlockConfig = {
},
inputs: {
operation: { type: 'string', description: 'Operation to perform' },
id: { type: 'string', description: 'Memory identifier (for add operation)' },
id: { type: 'string', description: 'Memory identifier (conversation ID)' },
conversationId: { type: 'string', description: 'Conversation identifier' },
blockId: { type: 'string', description: 'Block identifier' },
blockName: { type: 'string', description: 'Block name' },
role: { type: 'string', description: 'Agent role' },
content: { type: 'string', description: 'Memory content' },
},

View File

@@ -158,11 +158,19 @@ export const HTTP = {
export const AGENT = {
DEFAULT_MODEL: 'claude-sonnet-4-5',
DEFAULT_FUNCTION_TIMEOUT: 600000, // 10 minutes for custom tool code execution
REQUEST_TIMEOUT: 600000, // 10 minutes for LLM API requests
DEFAULT_FUNCTION_TIMEOUT: 600000,
REQUEST_TIMEOUT: 600000,
CUSTOM_TOOL_PREFIX: 'custom_',
} as const
export const MEMORY = {
DEFAULT_SLIDING_WINDOW_SIZE: 10,
DEFAULT_SLIDING_WINDOW_TOKENS: 4000,
CONTEXT_WINDOW_UTILIZATION: 0.9,
MAX_CONVERSATION_ID_LENGTH: 255,
MAX_MESSAGE_CONTENT_BYTES: 100 * 1024,
} as const
export const ROUTER = {
DEFAULT_MODEL: 'gpt-4o',
DEFAULT_TEMPERATURE: 0,
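The new MEMORY constants centralize the fallbacks and limits the memory service enforces. A rough sketch of how they end up being consumed; the two helpers below are illustrative stand-ins for the service's `parsePositiveInt` and `validateContent`, not the exact implementation:

```ts
import { MEMORY } from '@/executor/constants'

// Fall back to the shared default when the block leaves the window size blank
// or sets it to something unparsable.
function resolveWindowSize(raw?: string): number {
  const parsed = Number.parseInt(raw ?? '', 10)
  return Number.isNaN(parsed) || parsed <= 0 ? MEMORY.DEFAULT_SLIDING_WINDOW_SIZE : parsed
}

// Reject single messages over the 100 KB MAX_MESSAGE_CONTENT_BYTES cap.
function assertContentSize(content: string): void {
  if (Buffer.byteLength(content, 'utf8') > MEMORY.MAX_MESSAGE_CONTENT_BYTES) {
    throw new Error('Message content too large')
  }
}
```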

View File

@@ -1140,7 +1140,7 @@ describe('AgentBlockHandler', () => {
expect(systemMessages[0].content).toBe('You are a helpful assistant.')
})
it('should prioritize messages array system message over system messages in memories', async () => {
it('should prefix agent system message before legacy memories', async () => {
const inputs = {
model: 'gpt-4o',
messages: [
@@ -1163,25 +1163,26 @@ describe('AgentBlockHandler', () => {
const requestBody = JSON.parse(fetchCall[1].body)
// Verify messages were built correctly
// Agent system (1) + legacy memories (3) + user from messages (1) = 5
expect(requestBody.messages).toBeDefined()
expect(requestBody.messages.length).toBe(5) // memory system + 2 non-system memories + 2 from messages array
expect(requestBody.messages.length).toBe(5)
// All messages should be present (memories first, then messages array)
// Memory messages come first
// Agent's system message is prefixed first
expect(requestBody.messages[0].role).toBe('system')
expect(requestBody.messages[0].content).toBe('Old system message from memories.')
expect(requestBody.messages[1].role).toBe('user')
expect(requestBody.messages[1].content).toBe('Hello!')
expect(requestBody.messages[2].role).toBe('assistant')
expect(requestBody.messages[2].content).toBe('Hi there!')
// Then messages array
expect(requestBody.messages[3].role).toBe('system')
expect(requestBody.messages[3].content).toBe('You are a helpful assistant.')
expect(requestBody.messages[0].content).toBe('You are a helpful assistant.')
// Then legacy memories (with their system message preserved)
expect(requestBody.messages[1].role).toBe('system')
expect(requestBody.messages[1].content).toBe('Old system message from memories.')
expect(requestBody.messages[2].role).toBe('user')
expect(requestBody.messages[2].content).toBe('Hello!')
expect(requestBody.messages[3].role).toBe('assistant')
expect(requestBody.messages[3].content).toBe('Hi there!')
// Then user message from messages array
expect(requestBody.messages[4].role).toBe('user')
expect(requestBody.messages[4].content).toBe('What should I do?')
})
it('should handle multiple system messages in memories with messages array', async () => {
it('should prefix agent system message and preserve legacy memory system messages', async () => {
const inputs = {
model: 'gpt-4o',
messages: [
@@ -1207,21 +1208,23 @@ describe('AgentBlockHandler', () => {
// Verify messages were built correctly
expect(requestBody.messages).toBeDefined()
expect(requestBody.messages.length).toBe(7) // 5 memory messages (3 system + 2 conversation) + 2 from messages array
expect(requestBody.messages.length).toBe(7)
// All messages should be present in order
// Agent's system message prefixed first
expect(requestBody.messages[0].role).toBe('system')
expect(requestBody.messages[0].content).toBe('First system message.')
expect(requestBody.messages[1].role).toBe('user')
expect(requestBody.messages[1].content).toBe('Hello!')
expect(requestBody.messages[2].role).toBe('system')
expect(requestBody.messages[2].content).toBe('Second system message.')
expect(requestBody.messages[3].role).toBe('assistant')
expect(requestBody.messages[3].content).toBe('Hi there!')
expect(requestBody.messages[4].role).toBe('system')
expect(requestBody.messages[4].content).toBe('Third system message.')
expect(requestBody.messages[0].content).toBe('You are a helpful assistant.')
// Then legacy memories with their system messages preserved in order
expect(requestBody.messages[1].role).toBe('system')
expect(requestBody.messages[1].content).toBe('First system message.')
expect(requestBody.messages[2].role).toBe('user')
expect(requestBody.messages[2].content).toBe('Hello!')
expect(requestBody.messages[3].role).toBe('system')
expect(requestBody.messages[3].content).toBe('Second system message.')
expect(requestBody.messages[4].role).toBe('assistant')
expect(requestBody.messages[4].content).toBe('Hi there!')
expect(requestBody.messages[5].role).toBe('system')
expect(requestBody.messages[5].content).toBe('You are a helpful assistant.')
expect(requestBody.messages[5].content).toBe('Third system message.')
// Then user message from messages array
expect(requestBody.messages[6].role).toBe('user')
expect(requestBody.messages[6].content).toBe('Continue our conversation.')
})
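Both rewritten tests encode the same ordering rule: the agent's configured system message is prefixed first, legacy memories keep their own system messages in their original positions, and the new user input from the messages array comes last. Summarized as data (an illustration, not test code):

```ts
// Roles only, content elided - the ordering the assertions above encode.
const expectedOrdering: Array<{ role: 'system' | 'user' | 'assistant'; source: string }> = [
  { role: 'system', source: 'agent config (prefixed at runtime)' },
  { role: 'system', source: 'legacy memories (preserved in place)' },
  { role: 'user', source: 'legacy memories' },
  { role: 'assistant', source: 'legacy memories' },
  { role: 'user', source: 'messages array (new input, appended last)' },
]
```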

View File

@@ -47,7 +47,7 @@ export class AgentBlockHandler implements BlockHandler {
const providerId = getProviderFromModel(model)
const formattedTools = await this.formatTools(ctx, filteredInputs.tools || [])
const streamingConfig = this.getStreamingConfig(ctx, block)
const messages = await this.buildMessages(ctx, filteredInputs, block.id)
const messages = await this.buildMessages(ctx, filteredInputs)
const providerRequest = this.buildProviderRequest({
ctx,
@@ -68,7 +68,20 @@ export class AgentBlockHandler implements BlockHandler {
filteredInputs
)
await this.persistResponseToMemory(ctx, filteredInputs, result, block.id)
if (this.isStreamingExecution(result)) {
if (filteredInputs.memoryType && filteredInputs.memoryType !== 'none') {
return this.wrapStreamForMemoryPersistence(
ctx,
filteredInputs,
result as StreamingExecution
)
}
return result
}
if (filteredInputs.memoryType && filteredInputs.memoryType !== 'none') {
await this.persistResponseToMemory(ctx, filteredInputs, result as BlockOutput)
}
return result
}
@@ -686,79 +699,100 @@ export class AgentBlockHandler implements BlockHandler {
private async buildMessages(
ctx: ExecutionContext,
inputs: AgentInputs,
blockId: string
inputs: AgentInputs
): Promise<Message[] | undefined> {
const messages: Message[] = []
const memoryEnabled = inputs.memoryType && inputs.memoryType !== 'none'
// 1. Fetch memory history if configured (industry standard: chronological order)
if (inputs.memoryType && inputs.memoryType !== 'none') {
const memoryMessages = await memoryService.fetchMemoryMessages(ctx, inputs, blockId)
// 1. Extract and validate messages from messages-input subblock
const inputMessages = this.extractValidMessages(inputs.messages)
const systemMessages = inputMessages.filter((m) => m.role === 'system')
const conversationMessages = inputMessages.filter((m) => m.role !== 'system')
// 2. Handle native memory: seed on first run, then fetch and append new user input
if (memoryEnabled && ctx.workspaceId) {
const memoryMessages = await memoryService.fetchMemoryMessages(ctx, inputs)
const hasExisting = memoryMessages.length > 0
if (!hasExisting && conversationMessages.length > 0) {
const taggedMessages = conversationMessages.map((m) =>
m.role === 'user' ? { ...m, executionId: ctx.executionId } : m
)
await memoryService.seedMemory(ctx, inputs, taggedMessages)
messages.push(...taggedMessages)
} else {
messages.push(...memoryMessages)
if (hasExisting && conversationMessages.length > 0) {
const latestUserFromInput = conversationMessages.filter((m) => m.role === 'user').pop()
if (latestUserFromInput) {
const userMessageInThisRun = memoryMessages.some(
(m) => m.role === 'user' && m.executionId === ctx.executionId
)
if (!userMessageInThisRun) {
const taggedMessage = { ...latestUserFromInput, executionId: ctx.executionId }
messages.push(taggedMessage)
await memoryService.appendToMemory(ctx, inputs, taggedMessage)
}
}
}
}
}
// 2. Process legacy memories (backward compatibility - from Memory block)
// 3. Process legacy memories (backward compatibility - from Memory block)
// These may include system messages which are preserved in their position
if (inputs.memories) {
messages.push(...this.processMemories(inputs.memories))
}
// 3. Add messages array (new approach - from messages-input subblock)
if (inputs.messages && Array.isArray(inputs.messages)) {
const validMessages = inputs.messages.filter(
(msg) =>
// 4. Add conversation messages from inputs.messages (if not using native memory)
// When memory is enabled, these are already seeded/fetched above
if (!memoryEnabled && conversationMessages.length > 0) {
messages.push(...conversationMessages)
}
// 5. Handle legacy systemPrompt (backward compatibility)
// Only add if no system message exists from any source
if (inputs.systemPrompt) {
const hasSystem = systemMessages.length > 0 || messages.some((m) => m.role === 'system')
if (!hasSystem) {
this.addSystemPrompt(messages, inputs.systemPrompt)
}
}
// 6. Handle legacy userPrompt - this is NEW input each run
if (inputs.userPrompt) {
this.addUserPrompt(messages, inputs.userPrompt)
if (memoryEnabled) {
const userMessages = messages.filter((m) => m.role === 'user')
const lastUserMessage = userMessages[userMessages.length - 1]
if (lastUserMessage) {
await memoryService.appendToMemory(ctx, inputs, lastUserMessage)
}
}
}
// 7. Prefix system messages from inputs.messages at the start (runtime only)
// These are the agent's configured system prompts
if (systemMessages.length > 0) {
messages.unshift(...systemMessages)
}
return messages.length > 0 ? messages : undefined
}
private extractValidMessages(messages?: Message[]): Message[] {
if (!messages || !Array.isArray(messages)) return []
return messages.filter(
(msg): msg is Message =>
msg &&
typeof msg === 'object' &&
'role' in msg &&
'content' in msg &&
['system', 'user', 'assistant'].includes(msg.role)
)
messages.push(...validMessages)
}
// Warn if using both new and legacy input formats
if (
inputs.messages &&
inputs.messages.length > 0 &&
(inputs.systemPrompt || inputs.userPrompt)
) {
logger.warn('Agent block using both messages array and legacy prompts', {
hasMessages: true,
hasSystemPrompt: !!inputs.systemPrompt,
hasUserPrompt: !!inputs.userPrompt,
})
}
// 4. Handle legacy systemPrompt (backward compatibility)
// Only add if no system message exists yet
if (inputs.systemPrompt && !messages.some((m) => m.role === 'system')) {
this.addSystemPrompt(messages, inputs.systemPrompt)
}
// 5. Handle legacy userPrompt (backward compatibility)
if (inputs.userPrompt) {
this.addUserPrompt(messages, inputs.userPrompt)
}
// 6. Persist user message(s) to memory if configured
// This ensures conversation history is complete before agent execution
if (inputs.memoryType && inputs.memoryType !== 'none' && messages.length > 0) {
// Find new user messages that need to be persisted
// (messages added via messages array or userPrompt)
const userMessages = messages.filter((m) => m.role === 'user')
const lastUserMessage = userMessages[userMessages.length - 1]
// Only persist if there's a user message AND it's from userPrompt or messages input
// (not from memory history which was already persisted)
if (
lastUserMessage &&
(inputs.userPrompt || (inputs.messages && inputs.messages.length > 0))
) {
await memoryService.persistUserMessage(ctx, inputs, lastUserMessage, blockId)
}
}
// Return messages or undefined if empty (maintains API compatibility)
return messages.length > 0 ? messages : undefined
}
private processMemories(memories: any): Message[] {
@@ -1036,29 +1070,14 @@ export class AgentBlockHandler implements BlockHandler {
private async handleStreamingResponse(
response: Response,
block: SerializedBlock,
ctx?: ExecutionContext,
inputs?: AgentInputs
_ctx?: ExecutionContext,
_inputs?: AgentInputs
): Promise<StreamingExecution> {
const executionDataHeader = response.headers.get('X-Execution-Data')
if (executionDataHeader) {
try {
const executionData = JSON.parse(executionDataHeader)
// If execution data contains full content, persist to memory
if (ctx && inputs && executionData.output?.content) {
const assistantMessage: Message = {
role: 'assistant',
content: executionData.output.content,
}
// Fire and forget - don't await
memoryService
.persistMemoryMessage(ctx, inputs, assistantMessage, block.id)
.catch((error) =>
logger.error('Failed to persist streaming response to memory:', error)
)
}
return {
stream: response.body!,
execution: {
@@ -1158,46 +1177,35 @@ export class AgentBlockHandler implements BlockHandler {
}
}
private wrapStreamForMemoryPersistence(
ctx: ExecutionContext,
inputs: AgentInputs,
streamingExec: StreamingExecution
): StreamingExecution {
return {
stream: memoryService.wrapStreamForPersistence(streamingExec.stream, ctx, inputs),
execution: streamingExec.execution,
}
}
private async persistResponseToMemory(
ctx: ExecutionContext,
inputs: AgentInputs,
result: BlockOutput | StreamingExecution,
blockId: string
result: BlockOutput
): Promise<void> {
// Only persist if memoryType is configured
if (!inputs.memoryType || inputs.memoryType === 'none') {
return
}
try {
// Don't persist streaming responses here - they're handled separately
if (this.isStreamingExecution(result)) {
return
}
// Extract content from regular response
const blockOutput = result as any
const content = blockOutput?.content
const content = (result as any)?.content
if (!content || typeof content !== 'string') {
return
}
const assistantMessage: Message = {
role: 'assistant',
content,
}
await memoryService.persistMemoryMessage(ctx, inputs, assistantMessage, blockId)
try {
await memoryService.appendToMemory(ctx, inputs, { role: 'assistant', content })
logger.debug('Persisted assistant response to memory', {
workflowId: ctx.workflowId,
memoryType: inputs.memoryType,
conversationId: inputs.conversationId,
})
} catch (error) {
logger.error('Failed to persist response to memory:', error)
// Don't throw - memory persistence failure shouldn't break workflow execution
}
}
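The core of the native-memory change in `buildMessages` is a first-run/subsequent-run split: on the first run the incoming conversation is seeded into memory; on later runs only the newest user message is appended, tagged with the current `executionId` so a retry within the same execution does not write it twice. A self-contained sketch of that decision, with a hypothetical `planMemoryWrite` helper standing in for the inline logic:

```ts
type Role = 'system' | 'user' | 'assistant'
interface Msg {
  role: Role
  content: string
  executionId?: string
}

// Decide what (if anything) should be written to memory for this run.
function planMemoryWrite(existing: Msg[], incoming: Msg[], executionId: string) {
  const conversation = incoming.filter((m) => m.role !== 'system')

  // First run: seed memory with the whole conversation, tagging user messages.
  if (existing.length === 0) {
    return {
      action: 'seed' as const,
      messages: conversation.map((m) => (m.role === 'user' ? { ...m, executionId } : m)),
    }
  }

  // Later runs: append only the latest user message, and only once per execution.
  const latestUser = conversation.filter((m) => m.role === 'user').pop()
  const alreadyWritten = existing.some((m) => m.role === 'user' && m.executionId === executionId)
  if (latestUser && !alreadyWritten) {
    return { action: 'append' as const, messages: [{ ...latestUser, executionId }] }
  }

  return { action: 'none' as const, messages: [] as Msg[] }
}
```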

View File

@@ -1,7 +1,7 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { MEMORY } from '@/executor/constants'
import { Memory } from '@/executor/handlers/agent/memory'
import type { AgentInputs, Message } from '@/executor/handlers/agent/types'
import type { ExecutionContext } from '@/executor/types'
import type { Message } from '@/executor/handlers/agent/types'
vi.mock('@/lib/logs/console/logger', () => ({
createLogger: () => ({
@@ -20,21 +20,14 @@ vi.mock('@/lib/tokenization/estimators', () => ({
describe('Memory', () => {
let memoryService: Memory
let mockContext: ExecutionContext
beforeEach(() => {
memoryService = new Memory()
mockContext = {
workflowId: 'test-workflow-id',
executionId: 'test-execution-id',
workspaceId: 'test-workspace-id',
} as ExecutionContext
})
describe('applySlidingWindow (message-based)', () => {
it('should keep last N conversation messages', () => {
describe('applyWindow (message-based)', () => {
it('should keep last N messages', () => {
const messages: Message[] = [
{ role: 'system', content: 'System prompt' },
{ role: 'user', content: 'Message 1' },
{ role: 'assistant', content: 'Response 1' },
{ role: 'user', content: 'Message 2' },
@@ -43,55 +36,51 @@ describe('Memory', () => {
{ role: 'assistant', content: 'Response 3' },
]
const result = (memoryService as any).applySlidingWindow(messages, '4')
const result = (memoryService as any).applyWindow(messages, 4)
expect(result.length).toBe(5)
expect(result[0].role).toBe('system')
expect(result[0].content).toBe('System prompt')
expect(result[1].content).toBe('Message 2')
expect(result[4].content).toBe('Response 3')
expect(result.length).toBe(4)
expect(result[0].content).toBe('Message 2')
expect(result[3].content).toBe('Response 3')
})
it('should preserve only first system message', () => {
it('should return all messages if limit exceeds array length', () => {
const messages: Message[] = [
{ role: 'system', content: 'First system' },
{ role: 'user', content: 'User message' },
{ role: 'system', content: 'Second system' },
{ role: 'assistant', content: 'Assistant message' },
{ role: 'user', content: 'Test' },
{ role: 'assistant', content: 'Response' },
]
const result = (memoryService as any).applySlidingWindow(messages, '10')
const systemMessages = result.filter((m: Message) => m.role === 'system')
expect(systemMessages.length).toBe(1)
expect(systemMessages[0].content).toBe('First system')
const result = (memoryService as any).applyWindow(messages, 10)
expect(result.length).toBe(2)
})
it('should handle invalid window size', () => {
const messages: Message[] = [{ role: 'user', content: 'Test' }]
const result = (memoryService as any).applySlidingWindow(messages, 'invalid')
const result = (memoryService as any).applyWindow(messages, Number.NaN)
expect(result).toEqual(messages)
})
it('should handle zero limit', () => {
const messages: Message[] = [{ role: 'user', content: 'Test' }]
const result = (memoryService as any).applyWindow(messages, 0)
expect(result).toEqual(messages)
})
})
describe('applySlidingWindowByTokens (token-based)', () => {
describe('applyTokenWindow (token-based)', () => {
it('should keep messages within token limit', () => {
const messages: Message[] = [
{ role: 'system', content: 'This is a system message' }, // ~6 tokens
{ role: 'user', content: 'Short' }, // ~2 tokens
{ role: 'assistant', content: 'This is a longer response message' }, // ~8 tokens
{ role: 'user', content: 'Another user message here' }, // ~6 tokens
{ role: 'assistant', content: 'Final response' }, // ~3 tokens
{ role: 'user', content: 'Short' },
{ role: 'assistant', content: 'This is a longer response message' },
{ role: 'user', content: 'Another user message here' },
{ role: 'assistant', content: 'Final response' },
]
// Set limit to ~15 tokens - should include last 2-3 messages
const result = (memoryService as any).applySlidingWindowByTokens(messages, '15', 'gpt-4o')
const result = (memoryService as any).applyTokenWindow(messages, 15, 'gpt-4o')
expect(result.length).toBeGreaterThan(0)
expect(result.length).toBeLessThan(messages.length)
// Should include newest messages
expect(result[result.length - 1].content).toBe('Final response')
})
@@ -104,30 +93,12 @@ describe('Memory', () => {
},
]
const result = (memoryService as any).applySlidingWindowByTokens(messages, '5', 'gpt-4o')
const result = (memoryService as any).applyTokenWindow(messages, 5, 'gpt-4o')
expect(result.length).toBe(1)
expect(result[0].content).toBe(messages[0].content)
})
it('should preserve first system message and exclude it from token count', () => {
const messages: Message[] = [
{ role: 'system', content: 'A' }, // System message - always preserved
{ role: 'user', content: 'B' }, // ~1 token
{ role: 'assistant', content: 'C' }, // ~1 token
{ role: 'user', content: 'D' }, // ~1 token
]
// Limit to 2 tokens - should fit system message + last 2 conversation messages (D, C)
const result = (memoryService as any).applySlidingWindowByTokens(messages, '2', 'gpt-4o')
// Should have: system message + 2 conversation messages = 3 total
expect(result.length).toBe(3)
expect(result[0].role).toBe('system') // First system message preserved
expect(result[1].content).toBe('C') // Second most recent conversation message
expect(result[2].content).toBe('D') // Most recent conversation message
})
it('should process messages from newest to oldest', () => {
const messages: Message[] = [
{ role: 'user', content: 'Old message' },
@@ -136,141 +107,101 @@ describe('Memory', () => {
{ role: 'assistant', content: 'New response' },
]
const result = (memoryService as any).applySlidingWindowByTokens(messages, '10', 'gpt-4o')
const result = (memoryService as any).applyTokenWindow(messages, 10, 'gpt-4o')
// Should prioritize newer messages
expect(result[result.length - 1].content).toBe('New response')
})
it('should handle invalid token limit', () => {
const messages: Message[] = [{ role: 'user', content: 'Test' }]
const result = (memoryService as any).applySlidingWindowByTokens(
messages,
'invalid',
'gpt-4o'
)
expect(result).toEqual(messages) // Should return all messages
const result = (memoryService as any).applyTokenWindow(messages, Number.NaN, 'gpt-4o')
expect(result).toEqual(messages)
})
it('should handle zero or negative token limit', () => {
const messages: Message[] = [{ role: 'user', content: 'Test' }]
const result1 = (memoryService as any).applySlidingWindowByTokens(messages, '0', 'gpt-4o')
const result1 = (memoryService as any).applyTokenWindow(messages, 0, 'gpt-4o')
expect(result1).toEqual(messages)
const result2 = (memoryService as any).applySlidingWindowByTokens(messages, '-5', 'gpt-4o')
const result2 = (memoryService as any).applyTokenWindow(messages, -5, 'gpt-4o')
expect(result2).toEqual(messages)
})
it('should work with different model names', () => {
it('should work without model specified', () => {
const messages: Message[] = [{ role: 'user', content: 'Test message' }]
const result1 = (memoryService as any).applySlidingWindowByTokens(messages, '100', 'gpt-4o')
expect(result1.length).toBe(1)
const result2 = (memoryService as any).applySlidingWindowByTokens(
messages,
'100',
'claude-3-5-sonnet-20241022'
)
expect(result2.length).toBe(1)
const result3 = (memoryService as any).applySlidingWindowByTokens(messages, '100', undefined)
expect(result3.length).toBe(1)
const result = (memoryService as any).applyTokenWindow(messages, 100, undefined)
expect(result.length).toBe(1)
})
it('should handle empty messages array', () => {
const messages: Message[] = []
const result = (memoryService as any).applySlidingWindowByTokens(messages, '100', 'gpt-4o')
const result = (memoryService as any).applyTokenWindow(messages, 100, 'gpt-4o')
expect(result).toEqual([])
})
})
describe('buildMemoryKey', () => {
it('should build correct key with conversationId:blockId format', () => {
const inputs: AgentInputs = {
memoryType: 'conversation',
conversationId: 'emir',
}
const key = (memoryService as any).buildMemoryKey(mockContext, inputs, 'test-block-id')
expect(key).toBe('emir:test-block-id')
})
it('should use same key format regardless of memory type', () => {
const conversationId = 'user-123'
const blockId = 'block-abc'
const conversationKey = (memoryService as any).buildMemoryKey(
mockContext,
{ memoryType: 'conversation', conversationId },
blockId
)
const slidingWindowKey = (memoryService as any).buildMemoryKey(
mockContext,
{ memoryType: 'sliding_window', conversationId },
blockId
)
const slidingTokensKey = (memoryService as any).buildMemoryKey(
mockContext,
{ memoryType: 'sliding_window_tokens', conversationId },
blockId
)
// All should produce the same key - memory type only affects processing
expect(conversationKey).toBe('user-123:block-abc')
expect(slidingWindowKey).toBe('user-123:block-abc')
expect(slidingTokensKey).toBe('user-123:block-abc')
})
describe('validateConversationId', () => {
it('should throw error for missing conversationId', () => {
const inputs: AgentInputs = {
memoryType: 'conversation',
// conversationId missing
}
expect(() => {
;(memoryService as any).buildMemoryKey(mockContext, inputs, 'test-block-id')
}).toThrow('Conversation ID is required for all memory types')
;(memoryService as any).validateConversationId(undefined)
}).toThrow('Conversation ID is required')
})
it('should throw error for empty conversationId', () => {
const inputs: AgentInputs = {
memoryType: 'conversation',
conversationId: ' ', // Only whitespace
}
expect(() => {
;(memoryService as any).buildMemoryKey(mockContext, inputs, 'test-block-id')
}).toThrow('Conversation ID is required for all memory types')
;(memoryService as any).validateConversationId(' ')
}).toThrow('Conversation ID is required')
})
it('should throw error for too long conversationId', () => {
const longId = 'a'.repeat(MEMORY.MAX_CONVERSATION_ID_LENGTH + 1)
expect(() => {
;(memoryService as any).validateConversationId(longId)
}).toThrow('Conversation ID too long')
})
it('should accept valid conversationId', () => {
expect(() => {
;(memoryService as any).validateConversationId('user-123')
}).not.toThrow()
})
})
describe('validateContent', () => {
it('should throw error for content exceeding max size', () => {
const largeContent = 'x'.repeat(MEMORY.MAX_MESSAGE_CONTENT_BYTES + 1)
expect(() => {
;(memoryService as any).validateContent(largeContent)
}).toThrow('Message content too large')
})
it('should accept content within limit', () => {
const content = 'Normal sized content'
expect(() => {
;(memoryService as any).validateContent(content)
}).not.toThrow()
})
})
describe('Token-based vs Message-based comparison', () => {
it('should produce different results for same message count limit', () => {
it('should produce different results for same limit concept', () => {
const messages: Message[] = [
{ role: 'user', content: 'A' }, // Short message (~1 token)
{ role: 'user', content: 'A' },
{
role: 'assistant',
content: 'This is a much longer response that takes many more tokens',
}, // Long message (~15 tokens)
{ role: 'user', content: 'B' }, // Short message (~1 token)
},
{ role: 'user', content: 'B' },
]
// Message-based: last 2 messages
const messageResult = (memoryService as any).applySlidingWindow(messages, '2')
const messageResult = (memoryService as any).applyWindow(messages, 2)
expect(messageResult.length).toBe(2)
// Token-based: with limit of 10 tokens, might fit all 3 messages or just last 2
const tokenResult = (memoryService as any).applySlidingWindowByTokens(
messages,
'10',
'gpt-4o'
)
// The long message should affect what fits
const tokenResult = (memoryService as any).applyTokenWindow(messages, 10, 'gpt-4o')
expect(tokenResult.length).toBeGreaterThanOrEqual(1)
})
})

View File

@@ -1,661 +1,281 @@
import { randomUUID } from 'node:crypto'
import { db } from '@sim/db'
import { memory } from '@sim/db/schema'
import { and, eq, sql } from 'drizzle-orm'
import { createLogger } from '@/lib/logs/console/logger'
import { getAccurateTokenCount } from '@/lib/tokenization/estimators'
import { MEMORY } from '@/executor/constants'
import type { AgentInputs, Message } from '@/executor/handlers/agent/types'
import type { ExecutionContext } from '@/executor/types'
import { buildAPIUrl, buildAuthHeaders } from '@/executor/utils/http'
import { stringifyJSON } from '@/executor/utils/json'
import { PROVIDER_DEFINITIONS } from '@/providers/models'
const logger = createLogger('Memory')
/**
* Class for managing agent conversation memory
* Handles fetching and persisting messages to the memory table
*/
export class Memory {
/**
* Fetch messages from memory based on memoryType configuration
*/
async fetchMemoryMessages(
ctx: ExecutionContext,
inputs: AgentInputs,
blockId: string
): Promise<Message[]> {
async fetchMemoryMessages(ctx: ExecutionContext, inputs: AgentInputs): Promise<Message[]> {
if (!inputs.memoryType || inputs.memoryType === 'none') {
return []
}
if (!ctx.workflowId) {
logger.warn('Cannot fetch memory without workflowId')
return []
}
const workspaceId = this.requireWorkspaceId(ctx)
this.validateConversationId(inputs.conversationId)
try {
this.validateInputs(inputs.conversationId)
const memoryKey = this.buildMemoryKey(ctx, inputs, blockId)
let messages = await this.fetchFromMemoryAPI(ctx.workflowId, memoryKey)
const messages = await this.fetchMemory(workspaceId, inputs.conversationId!)
switch (inputs.memoryType) {
case 'conversation':
messages = this.applyContextWindowLimit(messages, inputs.model)
break
return this.applyContextWindowLimit(messages, inputs.model)
case 'sliding_window': {
// Default to 10 messages if not specified (matches agent block default)
const windowSize = inputs.slidingWindowSize || '10'
messages = this.applySlidingWindow(messages, windowSize)
break
const limit = this.parsePositiveInt(
inputs.slidingWindowSize,
MEMORY.DEFAULT_SLIDING_WINDOW_SIZE
)
return this.applyWindow(messages, limit)
}
case 'sliding_window_tokens': {
// Default to 4000 tokens if not specified (matches agent block default)
const maxTokens = inputs.slidingWindowTokens || '4000'
messages = this.applySlidingWindowByTokens(messages, maxTokens, inputs.model)
break
}
}
return messages
} catch (error) {
logger.error('Failed to fetch memory messages:', error)
return []
}
}
/**
* Persist assistant response to memory
* Uses atomic append operations to prevent race conditions
*/
async persistMemoryMessage(
ctx: ExecutionContext,
inputs: AgentInputs,
assistantMessage: Message,
blockId: string
): Promise<void> {
if (!inputs.memoryType || inputs.memoryType === 'none') {
return
}
if (!ctx.workflowId) {
logger.warn('Cannot persist memory without workflowId')
return
}
try {
this.validateInputs(inputs.conversationId, assistantMessage.content)
const memoryKey = this.buildMemoryKey(ctx, inputs, blockId)
if (inputs.memoryType === 'sliding_window') {
// Default to 10 messages if not specified (matches agent block default)
const windowSize = inputs.slidingWindowSize || '10'
const existingMessages = await this.fetchFromMemoryAPI(ctx.workflowId, memoryKey)
const updatedMessages = [...existingMessages, assistantMessage]
const messagesToPersist = this.applySlidingWindow(updatedMessages, windowSize)
await this.persistToMemoryAPI(ctx.workflowId, memoryKey, messagesToPersist)
} else if (inputs.memoryType === 'sliding_window_tokens') {
// Default to 4000 tokens if not specified (matches agent block default)
const maxTokens = inputs.slidingWindowTokens || '4000'
const existingMessages = await this.fetchFromMemoryAPI(ctx.workflowId, memoryKey)
const updatedMessages = [...existingMessages, assistantMessage]
const messagesToPersist = this.applySlidingWindowByTokens(
updatedMessages,
maxTokens,
inputs.model
)
await this.persistToMemoryAPI(ctx.workflowId, memoryKey, messagesToPersist)
} else {
// Conversation mode: use atomic append for better concurrency
await this.atomicAppendToMemory(ctx.workflowId, memoryKey, assistantMessage)
}
logger.debug('Successfully persisted memory message', {
workflowId: ctx.workflowId,
key: memoryKey,
})
} catch (error) {
logger.error('Failed to persist memory message:', error)
}
}
/**
* Persist user message to memory before agent execution
*/
async persistUserMessage(
ctx: ExecutionContext,
inputs: AgentInputs,
userMessage: Message,
blockId: string
): Promise<void> {
if (!inputs.memoryType || inputs.memoryType === 'none') {
return
}
if (!ctx.workflowId) {
logger.warn('Cannot persist user message without workflowId')
return
}
try {
const memoryKey = this.buildMemoryKey(ctx, inputs, blockId)
if (inputs.slidingWindowSize && inputs.memoryType === 'sliding_window') {
const existingMessages = await this.fetchFromMemoryAPI(ctx.workflowId, memoryKey)
const updatedMessages = [...existingMessages, userMessage]
const messagesToPersist = this.applySlidingWindow(updatedMessages, inputs.slidingWindowSize)
await this.persistToMemoryAPI(ctx.workflowId, memoryKey, messagesToPersist)
} else if (inputs.slidingWindowTokens && inputs.memoryType === 'sliding_window_tokens') {
const existingMessages = await this.fetchFromMemoryAPI(ctx.workflowId, memoryKey)
const updatedMessages = [...existingMessages, userMessage]
const messagesToPersist = this.applySlidingWindowByTokens(
updatedMessages,
const maxTokens = this.parsePositiveInt(
inputs.slidingWindowTokens,
inputs.model
MEMORY.DEFAULT_SLIDING_WINDOW_TOKENS
)
await this.persistToMemoryAPI(ctx.workflowId, memoryKey, messagesToPersist)
} else {
await this.atomicAppendToMemory(ctx.workflowId, memoryKey, userMessage)
}
} catch (error) {
logger.error('Failed to persist user message:', error)
}
return this.applyTokenWindow(messages, maxTokens, inputs.model)
}
/**
* Build memory key based on conversationId and blockId
* BlockId provides block-level memory isolation
*/
private buildMemoryKey(_ctx: ExecutionContext, inputs: AgentInputs, blockId: string): string {
const { conversationId } = inputs
if (!conversationId || conversationId.trim() === '') {
throw new Error(
'Conversation ID is required for all memory types. ' +
'Please provide a unique identifier (e.g., user-123, session-abc, customer-456).'
)
}
return `${conversationId}:${blockId}`
}
/**
* Apply sliding window to limit number of conversation messages
*
* System message handling:
* - System messages are excluded from the sliding window count
* - Only the first system message is preserved and placed at the start
* - This ensures system prompts remain available while limiting conversation history
*/
private applySlidingWindow(messages: Message[], windowSize: string): Message[] {
const limit = Number.parseInt(windowSize, 10)
if (Number.isNaN(limit) || limit <= 0) {
logger.warn('Invalid sliding window size, returning all messages', { windowSize })
default:
return messages
}
const systemMessages = messages.filter((msg) => msg.role === 'system')
const conversationMessages = messages.filter((msg) => msg.role !== 'system')
const recentMessages = conversationMessages.slice(-limit)
const firstSystemMessage = systemMessages.length > 0 ? [systemMessages[0]] : []
return [...firstSystemMessage, ...recentMessages]
}
/**
* Apply token-based sliding window to limit conversation by token count
*
* System message handling:
* - For consistency with message-based sliding window, the first system message is preserved
* - System messages are excluded from the token count
* - This ensures system prompts are always available while limiting conversation history
*/
private applySlidingWindowByTokens(
messages: Message[],
maxTokens: string,
model?: string
): Message[] {
const tokenLimit = Number.parseInt(maxTokens, 10)
if (Number.isNaN(tokenLimit) || tokenLimit <= 0) {
logger.warn('Invalid token limit, returning all messages', { maxTokens })
return messages
async appendToMemory(
ctx: ExecutionContext,
inputs: AgentInputs,
message: Message
): Promise<void> {
if (!inputs.memoryType || inputs.memoryType === 'none') {
return
}
// Separate system messages from conversation messages for consistent handling
const systemMessages = messages.filter((msg) => msg.role === 'system')
const conversationMessages = messages.filter((msg) => msg.role !== 'system')
const workspaceId = this.requireWorkspaceId(ctx)
this.validateConversationId(inputs.conversationId)
this.validateContent(message.content)
const key = inputs.conversationId!
await this.appendMessage(workspaceId, key, message)
logger.debug('Appended message to memory', {
workspaceId,
key,
role: message.role,
})
}
async seedMemory(ctx: ExecutionContext, inputs: AgentInputs, messages: Message[]): Promise<void> {
if (!inputs.memoryType || inputs.memoryType === 'none') {
return
}
const workspaceId = this.requireWorkspaceId(ctx)
const conversationMessages = messages.filter((m) => m.role !== 'system')
if (conversationMessages.length === 0) {
return
}
this.validateConversationId(inputs.conversationId)
const key = inputs.conversationId!
let messagesToStore = conversationMessages
if (inputs.memoryType === 'sliding_window') {
const limit = this.parsePositiveInt(
inputs.slidingWindowSize,
MEMORY.DEFAULT_SLIDING_WINDOW_SIZE
)
messagesToStore = this.applyWindow(conversationMessages, limit)
} else if (inputs.memoryType === 'sliding_window_tokens') {
const maxTokens = this.parsePositiveInt(
inputs.slidingWindowTokens,
MEMORY.DEFAULT_SLIDING_WINDOW_TOKENS
)
messagesToStore = this.applyTokenWindow(conversationMessages, maxTokens, inputs.model)
}
await this.seedMemoryRecord(workspaceId, key, messagesToStore)
logger.debug('Seeded memory', {
workspaceId,
key,
count: messagesToStore.length,
})
}
wrapStreamForPersistence(
stream: ReadableStream<Uint8Array>,
ctx: ExecutionContext,
inputs: AgentInputs
): ReadableStream<Uint8Array> {
let accumulatedContent = ''
const decoder = new TextDecoder()
const transformStream = new TransformStream<Uint8Array, Uint8Array>({
transform: (chunk, controller) => {
controller.enqueue(chunk)
const decoded = decoder.decode(chunk, { stream: true })
accumulatedContent += decoded
},
flush: () => {
if (accumulatedContent.trim()) {
this.appendToMemory(ctx, inputs, {
role: 'assistant',
content: accumulatedContent,
}).catch((error) => logger.error('Failed to persist streaming response:', error))
}
},
})
return stream.pipeThrough(transformStream)
}
private requireWorkspaceId(ctx: ExecutionContext): string {
if (!ctx.workspaceId) {
throw new Error('workspaceId is required for memory operations')
}
return ctx.workspaceId
}
private applyWindow(messages: Message[], limit: number): Message[] {
return messages.slice(-limit)
}
private applyTokenWindow(messages: Message[], maxTokens: number, model?: string): Message[] {
const result: Message[] = []
let currentTokenCount = 0
let tokenCount = 0
// Add conversation messages from most recent backwards
for (let i = conversationMessages.length - 1; i >= 0; i--) {
const message = conversationMessages[i]
const messageTokens = getAccurateTokenCount(message.content, model)
for (let i = messages.length - 1; i >= 0; i--) {
const msg = messages[i]
const msgTokens = getAccurateTokenCount(msg.content, model)
if (currentTokenCount + messageTokens <= tokenLimit) {
result.unshift(message)
currentTokenCount += messageTokens
if (tokenCount + msgTokens <= maxTokens) {
result.unshift(msg)
tokenCount += msgTokens
} else if (result.length === 0) {
logger.warn('Single message exceeds token limit, including anyway', {
messageTokens,
tokenLimit,
messageRole: message.role,
})
result.unshift(message)
currentTokenCount += messageTokens
result.unshift(msg)
break
} else {
// Token limit reached, stop processing
break
}
}
logger.debug('Applied token-based sliding window', {
totalMessages: messages.length,
conversationMessages: conversationMessages.length,
includedMessages: result.length,
totalTokens: currentTokenCount,
tokenLimit,
})
// Preserve first system message and prepend to results (consistent with message-based window)
const firstSystemMessage = systemMessages.length > 0 ? [systemMessages[0]] : []
return [...firstSystemMessage, ...result]
return result
}
/**
* Apply context window limit based on model's maximum context window
* Auto-trims oldest conversation messages when approaching the model's context limit
* Uses 90% of context window (10% buffer for response)
* Only applies if model has contextWindow defined and contextInformationAvailable !== false
*/
private applyContextWindowLimit(messages: Message[], model?: string): Message[] {
if (!model) {
return messages
}
let contextWindow: number | undefined
if (!model) return messages
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
if (provider.contextInformationAvailable === false) {
continue
}
if (provider.contextInformationAvailable === false) continue
const matchesPattern = provider.modelPatterns?.some((pattern) => pattern.test(model))
const matchesPattern = provider.modelPatterns?.some((p) => p.test(model))
const matchesModel = provider.models.some((m) => m.id === model)
if (matchesPattern || matchesModel) {
const modelDef = provider.models.find((m) => m.id === model)
if (modelDef?.contextWindow) {
contextWindow = modelDef.contextWindow
break
const maxTokens = Math.floor(modelDef.contextWindow * MEMORY.CONTEXT_WINDOW_UTILIZATION)
return this.applyTokenWindow(messages, maxTokens, model)
}
}
}
if (!contextWindow) {
logger.debug('No context window information available for model, skipping auto-trim', {
model,
})
return messages
}
const maxTokens = Math.floor(contextWindow * 0.9)
logger.debug('Applying context window limit', {
model,
contextWindow,
maxTokens,
totalMessages: messages.length,
})
const systemMessages = messages.filter((msg) => msg.role === 'system')
const conversationMessages = messages.filter((msg) => msg.role !== 'system')
// Count tokens used by system messages first
let systemTokenCount = 0
for (const msg of systemMessages) {
systemTokenCount += getAccurateTokenCount(msg.content, model)
}
// Calculate remaining tokens available for conversation messages
const remainingTokens = Math.max(0, maxTokens - systemTokenCount)
if (systemTokenCount >= maxTokens) {
logger.warn('System messages exceed context window limit, including anyway', {
systemTokenCount,
maxTokens,
systemMessageCount: systemMessages.length,
})
return systemMessages
}
const result: Message[] = []
let currentTokenCount = 0
for (let i = conversationMessages.length - 1; i >= 0; i--) {
const message = conversationMessages[i]
const messageTokens = getAccurateTokenCount(message.content, model)
if (currentTokenCount + messageTokens <= remainingTokens) {
result.unshift(message)
currentTokenCount += messageTokens
} else if (result.length === 0) {
logger.warn('Single message exceeds remaining context window, including anyway', {
messageTokens,
remainingTokens,
systemTokenCount,
messageRole: message.role,
})
result.unshift(message)
currentTokenCount += messageTokens
break
} else {
logger.info('Auto-trimmed conversation history to fit context window', {
originalMessages: conversationMessages.length,
trimmedMessages: result.length,
conversationTokens: currentTokenCount,
systemTokens: systemTokenCount,
totalTokens: currentTokenCount + systemTokenCount,
maxTokens,
})
break
}
}
return [...systemMessages, ...result]
}
/**
* Fetch messages from memory API
*/
private async fetchFromMemoryAPI(workflowId: string, key: string): Promise<Message[]> {
try {
const isBrowser = typeof window !== 'undefined'
if (!isBrowser) {
return await this.fetchFromMemoryDirect(workflowId, key)
}
const headers = await buildAuthHeaders()
const url = buildAPIUrl(`/api/memory/${encodeURIComponent(key)}`, { workflowId })
const response = await fetch(url.toString(), {
method: 'GET',
headers,
})
if (!response.ok) {
if (response.status === 404) {
return []
}
throw new Error(`Failed to fetch memory: ${response.status} ${response.statusText}`)
}
const result = await response.json()
if (!result.success) {
throw new Error(result.error || 'Failed to fetch memory')
}
const memoryData = result.data?.data || result.data
if (Array.isArray(memoryData)) {
return memoryData.filter(
(msg) => msg && typeof msg === 'object' && 'role' in msg && 'content' in msg
)
}
return []
} catch (error) {
logger.error('Error fetching from memory API:', error)
return []
}
}
/**
* Direct database access
*/
private async fetchFromMemoryDirect(workflowId: string, key: string): Promise<Message[]> {
try {
const { db } = await import('@sim/db')
const { memory } = await import('@sim/db/schema')
const { and, eq } = await import('drizzle-orm')
private async fetchMemory(workspaceId: string, key: string): Promise<Message[]> {
const result = await db
.select({
data: memory.data,
})
.select({ data: memory.data })
.from(memory)
.where(and(eq(memory.workflowId, workflowId), eq(memory.key, key)))
.where(and(eq(memory.workspaceId, workspaceId), eq(memory.key, key)))
.limit(1)
if (result.length === 0) {
return []
}
if (result.length === 0) return []
const memoryData = result[0].data as any
if (Array.isArray(memoryData)) {
return memoryData.filter(
(msg) => msg && typeof msg === 'object' && 'role' in msg && 'content' in msg
const data = result[0].data
if (!Array.isArray(data)) return []
return data.filter(
(msg): msg is Message => msg && typeof msg === 'object' && 'role' in msg && 'content' in msg
)
}
return []
} catch (error) {
logger.error('Error fetching from memory database:', error)
return []
}
}
/**
* Persist messages to memory API
*/
private async persistToMemoryAPI(
workflowId: string,
private async seedMemoryRecord(
workspaceId: string,
key: string,
messages: Message[]
): Promise<void> {
try {
const isBrowser = typeof window !== 'undefined'
if (!isBrowser) {
await this.persistToMemoryDirect(workflowId, key, messages)
return
}
const headers = await buildAuthHeaders()
const url = buildAPIUrl('/api/memory')
const response = await fetch(url.toString(), {
method: 'POST',
headers: {
...headers,
'Content-Type': 'application/json',
},
body: stringifyJSON({
workflowId,
key,
data: messages,
}),
})
if (!response.ok) {
throw new Error(`Failed to persist memory: ${response.status} ${response.statusText}`)
}
const result = await response.json()
if (!result.success) {
throw new Error(result.error || 'Failed to persist memory')
}
} catch (error) {
logger.error('Error persisting to memory API:', error)
throw error
}
}
/**
* Atomically append a message to memory
*/
private async atomicAppendToMemory(
workflowId: string,
key: string,
message: Message
): Promise<void> {
try {
const isBrowser = typeof window !== 'undefined'
if (!isBrowser) {
await this.atomicAppendToMemoryDirect(workflowId, key, message)
} else {
const headers = await buildAuthHeaders()
const url = buildAPIUrl('/api/memory')
const response = await fetch(url.toString(), {
method: 'POST',
headers: {
...headers,
'Content-Type': 'application/json',
},
body: stringifyJSON({
workflowId,
key,
data: message,
}),
})
if (!response.ok) {
throw new Error(`Failed to append memory: ${response.status} ${response.statusText}`)
}
const result = await response.json()
if (!result.success) {
throw new Error(result.error || 'Failed to append memory')
}
}
} catch (error) {
logger.error('Error appending to memory:', error)
throw error
}
}
/**
* Direct database atomic append for server-side
* Uses PostgreSQL JSONB concatenation operator for atomic operations
*/
private async atomicAppendToMemoryDirect(
workflowId: string,
key: string,
message: Message
): Promise<void> {
try {
const { db } = await import('@sim/db')
const { memory } = await import('@sim/db/schema')
const { sql } = await import('drizzle-orm')
const { randomUUID } = await import('node:crypto')
const now = new Date()
const id = randomUUID()
await db
.insert(memory)
.values({
id,
workflowId,
id: randomUUID(),
workspaceId,
key,
data: messages,
createdAt: now,
updatedAt: now,
})
.onConflictDoNothing()
}
private async appendMessage(workspaceId: string, key: string, message: Message): Promise<void> {
const now = new Date()
await db
.insert(memory)
.values({
id: randomUUID(),
workspaceId,
key,
data: [message],
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: [memory.workflowId, memory.key],
target: [memory.workspaceId, memory.key],
set: {
data: sql`${memory.data} || ${JSON.stringify([message])}::jsonb`,
updatedAt: now,
},
})
}
logger.debug('Atomically appended message to memory', {
workflowId,
key,
})
} catch (error) {
logger.error('Error in atomic append to memory database:', error)
throw error
private parsePositiveInt(value: string | undefined, defaultValue: number): number {
if (!value) return defaultValue
const parsed = Number.parseInt(value, 10)
if (Number.isNaN(parsed) || parsed <= 0) return defaultValue
return parsed
}
private validateConversationId(conversationId?: string): void {
if (!conversationId || conversationId.trim() === '') {
throw new Error('Conversation ID is required')
}
if (conversationId.length > MEMORY.MAX_CONVERSATION_ID_LENGTH) {
throw new Error(
`Conversation ID too long (max ${MEMORY.MAX_CONVERSATION_ID_LENGTH} characters)`
)
}
}
/**
* Direct database access for server-side persistence
* Uses UPSERT to handle race conditions atomically
*/
private async persistToMemoryDirect(
workflowId: string,
key: string,
messages: Message[]
): Promise<void> {
try {
const { db } = await import('@sim/db')
const { memory } = await import('@sim/db/schema')
const { randomUUID } = await import('node:crypto')
const now = new Date()
const id = randomUUID()
await db
.insert(memory)
.values({
id,
workflowId,
key,
data: messages,
createdAt: now,
updatedAt: now,
})
.onConflictDoUpdate({
target: [memory.workflowId, memory.key],
set: {
data: messages,
updatedAt: now,
},
})
} catch (error) {
logger.error('Error persisting to memory database:', error)
throw error
}
}
/**
* Validate inputs to prevent malicious data or performance issues
*/
private validateInputs(conversationId?: string, content?: string): void {
if (conversationId) {
if (conversationId.length > 255) {
throw new Error('Conversation ID too long (max 255 characters)')
}
if (!/^[a-zA-Z0-9_\-:.@]+$/.test(conversationId)) {
logger.warn('Conversation ID contains special characters', { conversationId })
}
}
if (content) {
const contentSize = Buffer.byteLength(content, 'utf8')
const MAX_CONTENT_SIZE = 100 * 1024 // 100KB
if (contentSize > MAX_CONTENT_SIZE) {
throw new Error(`Message content too large (${contentSize} bytes, max ${MAX_CONTENT_SIZE})`)
}
private validateContent(content: string): void {
const size = Buffer.byteLength(content, 'utf8')
if (size > MEMORY.MAX_MESSAGE_CONTENT_BYTES) {
throw new Error(
`Message content too large (${size} bytes, max ${MEMORY.MAX_MESSAGE_CONTENT_BYTES})`
)
}
}
}
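`wrapStreamForPersistence` is the piece that lets streaming responses reach memory: chunks are forwarded to the client untouched, and the accumulated text is appended only when the stream flushes. The same pattern in isolation; the `persist` callback stands in for `appendToMemory`:

```ts
function teeForPersistence(
  stream: ReadableStream<Uint8Array>,
  persist: (content: string) => Promise<void>
): ReadableStream<Uint8Array> {
  const decoder = new TextDecoder()
  let accumulated = ''

  const transform = new TransformStream<Uint8Array, Uint8Array>({
    transform(chunk, controller) {
      controller.enqueue(chunk) // client sees every chunk immediately
      accumulated += decoder.decode(chunk, { stream: true })
    },
    flush() {
      if (accumulated.trim()) {
        // Fire-and-forget: a persistence failure must not break the stream.
        persist(accumulated).catch(() => {})
      }
    },
  })

  return stream.pipeThrough(transform)
}
```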

View File

@@ -41,6 +41,7 @@ export interface ToolInput {
export interface Message {
role: 'system' | 'user' | 'assistant'
content: string
executionId?: string
function_call?: any
tool_calls?: any[]
}
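The only change to `Message` is the optional `executionId`, used purely for dedup: a user message persisted during one run can be recognized if the same execution re-enters `buildMessages`. For example (the id below is illustrative):

```ts
// Uses the Message interface defined above.
const taggedUserMessage: Message = {
  role: 'user',
  content: 'What should I do?',
  executionId: 'exec-1234', // hypothetical id; lets the handler skip re-appending this message
}
```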

View File

@@ -1,6 +1,6 @@
import { useQuery } from '@tanstack/react-query'
import { createLogger } from '@/lib/logs/console/logger'
import type { ProviderName } from '@/stores/providers/types'
import type { OpenRouterModelInfo, ProviderName } from '@/stores/providers/types'
const logger = createLogger('ProviderModelsQuery')
@@ -11,7 +11,12 @@ const providerEndpoints: Record<ProviderName, string> = {
openrouter: '/api/providers/openrouter/models',
}
async function fetchProviderModels(provider: ProviderName): Promise<string[]> {
interface ProviderModelsResponse {
models: string[]
modelInfo?: Record<string, OpenRouterModelInfo>
}
async function fetchProviderModels(provider: ProviderName): Promise<ProviderModelsResponse> {
const response = await fetch(providerEndpoints[provider])
if (!response.ok) {
@@ -24,8 +29,12 @@ async function fetchProviderModels(provider: ProviderName): Promise<string[]> {
const data = await response.json()
const models: string[] = Array.isArray(data.models) ? data.models : []
const uniqueModels = provider === 'openrouter' ? Array.from(new Set(models)) : models
return provider === 'openrouter' ? Array.from(new Set(models)) : models
return {
models: uniqueModels,
modelInfo: data.modelInfo,
}
}
export function useProviderModels(provider: ProviderName) {

View File

@@ -24,20 +24,6 @@ describe('Email Validation', () => {
expect(result.checks.disposable).toBe(false)
})
it.concurrent('should accept legitimate business emails', async () => {
const legitimateEmails = [
'test@gmail.com',
'no-reply@yahoo.com',
'user12345@outlook.com',
'longusernamehere@gmail.com',
]
for (const email of legitimateEmails) {
const result = await validateEmail(email)
expect(result.isValid).toBe(true)
}
})
it.concurrent('should reject consecutive dots (RFC violation)', async () => {
const result = await validateEmail('user..name@example.com')
expect(result.isValid).toBe(false)

View File

@@ -0,0 +1,114 @@
import { v4 as uuidv4 } from 'uuid'
import { createLogger } from '@/lib/logs/console/logger'
import { LoggingSession } from '@/lib/logs/execution/logging-session'
import { executeWorkflowCore } from '@/lib/workflows/executor/execution-core'
import { PauseResumeManager } from '@/lib/workflows/executor/human-in-the-loop-manager'
import { type ExecutionMetadata, ExecutionSnapshot } from '@/executor/execution/snapshot'
const logger = createLogger('WorkflowExecution')
export interface ExecuteWorkflowOptions {
enabled: boolean
selectedOutputs?: string[]
isSecureMode?: boolean
workflowTriggerType?: 'api' | 'chat'
onStream?: (streamingExec: any) => Promise<void>
onBlockComplete?: (blockId: string, output: any) => Promise<void>
skipLoggingComplete?: boolean
}
export interface WorkflowInfo {
id: string
userId: string
workspaceId?: string | null
isDeployed?: boolean
variables?: Record<string, any>
}
export async function executeWorkflow(
workflow: WorkflowInfo,
requestId: string,
input: any | undefined,
actorUserId: string,
streamConfig?: ExecuteWorkflowOptions,
providedExecutionId?: string
): Promise<any> {
if (!workflow.workspaceId) {
throw new Error(`Workflow ${workflow.id} has no workspaceId`)
}
const workflowId = workflow.id
const workspaceId = workflow.workspaceId
const executionId = providedExecutionId || uuidv4()
const triggerType = streamConfig?.workflowTriggerType || 'api'
const loggingSession = new LoggingSession(workflowId, executionId, triggerType, requestId)
try {
const metadata: ExecutionMetadata = {
requestId,
executionId,
workflowId,
workspaceId,
userId: actorUserId,
workflowUserId: workflow.userId,
triggerType,
useDraftState: false,
startTime: new Date().toISOString(),
isClientSession: false,
}
const snapshot = new ExecutionSnapshot(
metadata,
workflow,
input,
workflow.variables || {},
streamConfig?.selectedOutputs || []
)
const result = await executeWorkflowCore({
snapshot,
callbacks: {
onStream: streamConfig?.onStream,
onBlockComplete: streamConfig?.onBlockComplete
? async (blockId: string, _blockName: string, _blockType: string, output: any) => {
await streamConfig.onBlockComplete!(blockId, output)
}
: undefined,
},
loggingSession,
})
if (result.status === 'paused') {
if (!result.snapshotSeed) {
logger.error(`[${requestId}] Missing snapshot seed for paused execution`, {
executionId,
})
} else {
await PauseResumeManager.persistPauseResult({
workflowId,
executionId,
pausePoints: result.pausePoints || [],
snapshotSeed: result.snapshotSeed,
executorUserId: result.metadata?.userId,
})
}
} else {
await PauseResumeManager.processQueuedResumes(executionId)
}
if (streamConfig?.skipLoggingComplete) {
return {
...result,
_streamingMetadata: {
loggingSession,
processedInput: input,
},
}
}
return result
} catch (error: any) {
logger.error(`[${requestId}] Workflow execution failed:`, error)
throw error
}
}
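A hedged usage sketch of the new entry point. The ids, input, and selected callback below are placeholders, and the stream handling is simplified; real callers wire these values from the API route handlers:

```ts
import { executeWorkflow } from '@/lib/workflows/executor/execute-workflow'

const result = await executeWorkflow(
  {
    id: 'wf_123',           // placeholder workflow id
    userId: 'user_1',
    workspaceId: 'ws_1',    // required - executeWorkflow throws without it
    variables: {},
  },
  'req_abc',                // requestId, used as the log prefix
  { message: 'hello' },     // input
  'user_1',                 // actorUserId
  {
    enabled: true,
    workflowTriggerType: 'chat',
    onStream: async (streamingExec) => {
      // forward streamingExec.stream to the client (elided)
    },
  }
)
```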

View File

@@ -1,17 +1,23 @@
import {
extractBlockIdFromOutputId,
extractPathFromOutputId,
traverseObjectPath,
} from '@/lib/core/utils/response-format'
import { encodeSSE } from '@/lib/core/utils/sse'
import { createLogger } from '@/lib/logs/console/logger'
import { buildTraceSpans } from '@/lib/logs/execution/trace-spans/trace-spans'
import { processStreamingBlockLogs } from '@/lib/tokenization'
import { executeWorkflow } from '@/lib/workflows/executor/execute-workflow'
import type { ExecutionResult } from '@/executor/types'
const logger = createLogger('WorkflowStreaming')
const DANGEROUS_KEYS = ['__proto__', 'constructor', 'prototype']
export interface StreamingConfig {
selectedOutputs?: string[]
isSecureMode?: boolean
workflowTriggerType?: 'api' | 'chat'
onStream?: (streamingExec: {
stream: ReadableStream
execution?: { blockId?: string }
}) => Promise<void>
}
export interface StreamingResponseOptions {
@@ -26,46 +32,156 @@ export interface StreamingResponseOptions {
input: any
executingUserId: string
streamConfig: StreamingConfig
createFilteredResult: (result: ExecutionResult) => any
executionId?: string
}
interface StreamingState {
streamedContent: Map<string, string>
processedOutputs: Set<string>
streamCompletionTimes: Map<string, number>
}
function extractOutputValue(output: any, path: string): any {
let value = traverseObjectPath(output, path)
if (value === undefined && output?.response) {
value = traverseObjectPath(output.response, path)
}
return value
}
function isDangerousKey(key: string): boolean {
return DANGEROUS_KEYS.includes(key)
}
function buildMinimalResult(
result: ExecutionResult,
selectedOutputs: string[] | undefined,
streamedContent: Map<string, string>,
requestId: string
): { success: boolean; error?: string; output: Record<string, any> } {
const minimalResult = {
success: result.success,
error: result.error,
output: {} as Record<string, any>,
}
if (!selectedOutputs?.length) {
minimalResult.output = result.output || {}
return minimalResult
}
if (!result.output || !result.logs) {
return minimalResult
}
for (const outputId of selectedOutputs) {
const blockId = extractBlockIdFromOutputId(outputId)
if (streamedContent.has(blockId)) {
continue
}
if (isDangerousKey(blockId)) {
logger.warn(`[${requestId}] Blocked dangerous blockId: ${blockId}`)
continue
}
const path = extractPathFromOutputId(outputId, blockId)
if (isDangerousKey(path)) {
logger.warn(`[${requestId}] Blocked dangerous path: ${path}`)
continue
}
const blockLog = result.logs.find((log: any) => log.blockId === blockId)
if (!blockLog?.output) {
continue
}
const value = extractOutputValue(blockLog.output, path)
if (value === undefined) {
continue
}
if (!minimalResult.output[blockId]) {
minimalResult.output[blockId] = Object.create(null)
}
minimalResult.output[blockId][path] = value
}
return minimalResult
}
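`buildMinimalResult` guards against prototype pollution twice: output ids whose block id or path is `__proto__`, `constructor`, or `prototype` are skipped, and the per-block buckets are created with a null prototype. A standalone illustration of why both guards matter:

```ts
const DANGEROUS = ['__proto__', 'constructor', 'prototype']

function safeAssign(target: Record<string, unknown>, key: string, value: unknown): void {
  if (DANGEROUS.includes(key)) return // drop hostile keys outright
  target[key] = value
}

// A bucket with no prototype chain cannot be used to pollute Object.prototype.
const bucket: Record<string, unknown> = Object.create(null)
safeAssign(bucket, '__proto__', { polluted: true }) // silently ignored
safeAssign(bucket, 'content', 'streamed text')      // normal path
```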
function updateLogsWithStreamedContent(logs: any[], state: StreamingState): any[] {
return logs.map((log: any) => {
if (!state.streamedContent.has(log.blockId)) {
return log
}
const content = state.streamedContent.get(log.blockId)
const updatedLog = { ...log }
if (state.streamCompletionTimes.has(log.blockId)) {
const completionTime = state.streamCompletionTimes.get(log.blockId)!
const startTime = new Date(log.startedAt).getTime()
updatedLog.endedAt = new Date(completionTime).toISOString()
updatedLog.durationMs = completionTime - startTime
}
if (log.output && content) {
updatedLog.output = { ...log.output, content }
}
return updatedLog
})
}
async function completeLoggingSession(result: ExecutionResult): Promise<void> {
if (!result._streamingMetadata?.loggingSession) {
return
}
const { traceSpans, totalDuration } = buildTraceSpans(result)
await result._streamingMetadata.loggingSession.safeComplete({
endedAt: new Date().toISOString(),
totalDurationMs: totalDuration || 0,
finalOutput: result.output || {},
traceSpans: (traceSpans || []) as any,
workflowInput: result._streamingMetadata.processedInput,
})
result._streamingMetadata = undefined
}
export async function createStreamingResponse(
options: StreamingResponseOptions
): Promise<ReadableStream> {
const { requestId, workflow, input, executingUserId, streamConfig, executionId } = options
const { executeWorkflow, createFilteredResult: defaultFilteredResult } = await import(
'@/app/api/workflows/[id]/execute/route'
)
return new ReadableStream({
async start(controller) {
try {
const state: StreamingState = {
streamedContent: new Map(),
processedOutputs: new Set(),
streamCompletionTimes: new Map(),
}
const sendChunk = (blockId: string, content: string) => {
const separator = state.processedOutputs.size > 0 ? '\n\n' : ''
controller.enqueue(encodeSSE({ blockId, chunk: separator + content }))
state.processedOutputs.add(blockId)
}
const onStreamCallback = async (streamingExec: {
stream: ReadableStream
execution?: { blockId?: string }
}) => {
const blockId = streamingExec.execution?.blockId
if (!blockId) {
logger.warn(`[${requestId}] Streaming execution missing blockId`)
return
}
const reader = streamingExec.stream.getReader()
const decoder = new TextDecoder()
let isFirstChunk = true
@@ -74,13 +190,15 @@ export async function createStreamingResponse(
while (true) {
const { done, value } = await reader.read()
if (done) {
state.streamCompletionTimes.set(blockId, Date.now())
break
}
const textChunk = decoder.decode(value, { stream: true })
state.streamedContent.set(
blockId,
(state.streamedContent.get(blockId) || '') + textChunk
)
if (isFirstChunk) {
sendChunk(blockId, textChunk)
@@ -89,37 +207,34 @@ export async function createStreamingResponse(
controller.enqueue(encodeSSE({ blockId, chunk: textChunk }))
}
}
} catch (error) {
logger.error(`[${requestId}] Error reading stream for block ${blockId}:`, error)
controller.enqueue(
encodeSSE({
event: 'stream_error',
blockId,
error: error instanceof Error ? error.message : 'Stream reading error',
})
)
}
}
const onBlockCompleteCallback = async (blockId: string, output: any) => {
if (!streamConfig.selectedOutputs?.length) {
return
}
if (state.streamedContent.has(blockId)) {
return
}
const matchingOutputs = streamConfig.selectedOutputs.filter(
(outputId) => extractBlockIdFromOutputId(outputId) === blockId
)
if (!matchingOutputs.length) return
for (const outputId of matchingOutputs) {
const path = extractPathFromOutputId(outputId, blockId)
const outputValue = extractOutputValue(output, path)
if (outputValue !== undefined) {
const formattedOutput =
@@ -129,6 +244,7 @@ export async function createStreamingResponse(
}
}
try {
const result = await executeWorkflow(
workflow,
requestId,
@@ -141,97 +257,24 @@ export async function createStreamingResponse(
workflowTriggerType: streamConfig.workflowTriggerType,
onStream: onStreamCallback,
onBlockComplete: onBlockCompleteCallback,
skipLoggingComplete: true,
},
executionId
)
if (result.logs && state.streamedContent.size > 0) {
result.logs = updateLogsWithStreamedContent(result.logs, state)
processStreamingBlockLogs(result.logs, state.streamedContent)
}
await completeLoggingSession(result)
const minimalResult = buildMinimalResult(
result,
streamConfig.selectedOutputs,
state.streamedContent,
requestId
)
controller.enqueue(encodeSSE({ event: 'final', data: minimalResult }))
controller.enqueue(encodeSSE('[DONE]'))
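// A hedged sketch of consuming the SSE stream produced above from a client. It assumes
// encodeSSE emits `data: <json>` frames separated by blank lines, that chunk events carry
// { blockId, chunk }, and that '[DONE]' is the end-of-stream sentinel; the exact wire
// format depends on encodeSSE, so treat the parsing details as assumptions.
async function consumeWorkflowStream(response: Response): Promise<Map<string, string>> {
  const reader = response.body!.getReader()
  const decoder = new TextDecoder()
  const byBlock = new Map<string, string>()
  let buffer = ''
  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    buffer += decoder.decode(value, { stream: true })
    const frames = buffer.split('\n\n')
    buffer = frames.pop() || '' // keep any partial frame for the next read
    for (const frame of frames) {
      if (!frame.startsWith('data: ')) continue
      const data = frame.slice('data: '.length)
      if (data === '[DONE]' || data === '"[DONE]"') return byBlock
      try {
        const event = JSON.parse(data)
        if (event.blockId && typeof event.chunk === 'string') {
          byBlock.set(event.blockId, (byBlock.get(event.blockId) || '') + event.chunk)
        }
      } catch {
        // ignore frames that are not JSON payloads
      }
    }
  }
  return byBlock
}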

View File

@@ -37,6 +37,7 @@
"@browserbasehq/stagehand": "^3.0.5",
"@cerebras/cerebras_cloud_sdk": "^1.23.0",
"@e2b/code-interpreter": "^2.0.0",
"@google/genai": "1.34.0",
"@hookform/resolvers": "^4.1.3",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/exporter-jaeger": "2.1.0",

File diff suppressed because it is too large

View File

@@ -1,22 +1,49 @@
import type {
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStreamEvent,
Usage,
} from '@anthropic-ai/sdk/resources'
import { createLogger } from '@/lib/logs/console/logger'
import { trackForcedToolUsage } from '@/providers/utils'
const logger = createLogger('AnthropicUtils')
/**
* Helper to wrap Anthropic streaming into a browser-friendly ReadableStream
*/
export interface AnthropicStreamUsage {
input_tokens: number
output_tokens: number
}
export function createReadableStreamFromAnthropicStream(
anthropicStream: AsyncIterable<RawMessageStreamEvent>,
onComplete?: (content: string, usage: AnthropicStreamUsage) => void
): ReadableStream<Uint8Array> {
let fullContent = ''
let inputTokens = 0
let outputTokens = 0
return new ReadableStream({
async start(controller) {
try {
for await (const event of anthropicStream) {
if (event.type === 'message_start') {
const startEvent = event as RawMessageStartEvent
const usage: Usage = startEvent.message.usage
inputTokens = usage.input_tokens
} else if (event.type === 'message_delta') {
const deltaEvent = event as RawMessageDeltaEvent
outputTokens = deltaEvent.usage.output_tokens
} else if (event.type === 'content_block_delta' && event.delta.type === 'text_delta') {
const text = event.delta.text
fullContent += text
controller.enqueue(new TextEncoder().encode(text))
}
}
if (onComplete) {
onComplete(fullContent, { input_tokens: inputTokens, output_tokens: outputTokens })
}
controller.close()
} catch (err) {
controller.error(err)
@@ -25,16 +52,10 @@ export function createReadableStreamFromAnthropicStream(
})
}
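// Brief usage sketch for the wrapper above: hand it any AsyncIterable of
// RawMessageStreamEvent (e.g. the Anthropic SDK's message stream) and capture the final
// text plus token usage through onComplete. How the iterable is obtained is illustrative.
function wrapAnthropicStreamExample(events: AsyncIterable<RawMessageStreamEvent>): {
  stream: ReadableStream<Uint8Array>
  getUsage: () => AnthropicStreamUsage | undefined
} {
  let usage: AnthropicStreamUsage | undefined
  const stream = createReadableStreamFromAnthropicStream(events, (content, finalUsage) => {
    usage = finalUsage
    logger.info(`Anthropic stream finished with ${content.length} characters`)
  })
  return { stream, getUsage: () => usage }
}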
/**
* Helper function to generate a simple unique ID for tool uses
*/
export function generateToolUseId(toolName: string): string {
return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}`
}
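// e.g. generateToolUseId('search_web') might produce 'search_web-1734912000000-k3x9a' (timestamp plus random suffix).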
/**
* Helper function to check for forced tool usage in Anthropic responses
*/
export function checkForForcedToolUsage(
response: any,
toolChoice: any,
@@ -45,16 +66,11 @@ export function checkForForcedToolUsage(
const toolUses = response.content.filter((item: any) => item.type === 'tool_use')
if (toolUses.length > 0) {
const adaptedToolCalls = toolUses.map((tool: any) => ({ name: tool.name }))
const adaptedToolChoice =
toolChoice.type === 'tool' ? { function: { name: toolChoice.name } } : toolChoice
return trackForcedToolUsage(
adaptedToolCalls,
adaptedToolChoice,
logger,
@@ -62,8 +78,6 @@ export function checkForForcedToolUsage(
forcedTools,
usedForcedTools
)
}
}
return null

View File

@@ -1,4 +1,5 @@
import { AzureOpenAI } from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import { env } from '@/lib/core/config/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
@@ -14,7 +15,11 @@ import type {
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
} from '@/providers/utils'
import { executeTool } from '@/tools'
const logger = createLogger('AzureOpenAIProvider')
@@ -43,8 +48,6 @@ export const azureOpenAIProvider: ProviderConfig = {
stream: !!request.stream,
})
// Extract Azure-specific configuration from request or environment
// Priority: request parameters > environment variables
const azureEndpoint = request.azureEndpoint || env.AZURE_OPENAI_ENDPOINT
const azureApiVersion =
request.azureApiVersion || env.AZURE_OPENAI_API_VERSION || '2024-07-01-preview'
@@ -55,17 +58,14 @@ export const azureOpenAIProvider: ProviderConfig = {
)
}
// API key is now handled server-side before this function is called
const azureOpenAI = new AzureOpenAI({
apiKey: request.apiKey,
apiVersion: azureApiVersion,
endpoint: azureEndpoint,
})
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
@@ -73,7 +73,6 @@ export const azureOpenAIProvider: ProviderConfig = {
})
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
@@ -81,12 +80,10 @@ export const azureOpenAIProvider: ProviderConfig = {
})
}
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Transform tools to Azure OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
@@ -98,24 +95,19 @@ export const azureOpenAIProvider: ProviderConfig = {
}))
: undefined
// Build the request payload - use deployment name instead of model name
const deploymentName = (request.model || 'azure/gpt-4o').replace('azure/', '')
const payload: any = {
model: deploymentName,
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add GPT-5 specific parameters
if (request.reasoningEffort !== undefined) payload.reasoning_effort = request.reasoningEffort
if (request.verbosity !== undefined) payload.verbosity = request.verbosity
// Add response format for structured output if specified
if (request.responseFormat) {
// Use Azure OpenAI's JSON schema format
payload.response_format = {
type: 'json_schema',
json_schema: {
@@ -128,7 +120,6 @@ export const azureOpenAIProvider: ProviderConfig = {
logger.info('Added JSON schema response format to Azure OpenAI request')
}
// Handle tools and tool usage control
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
if (tools?.length) {
@@ -156,39 +147,40 @@ export const azureOpenAIProvider: ProviderConfig = {
}
}
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
// Check if we can stream directly (no tools required)
if (request.stream && (!tools || tools.length === 0)) {
logger.info('Using streaming response for Azure OpenAI request')
const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
stream: true,
stream_options: { include_usage: true },
}
const streamResponse = await azureOpenAI.chat.completions.create(streamingParams)
let _streamContent = ''
// Create a StreamingExecution response with a callback to update content and tokens
const streamingResult = {
stream: createReadableStreamFromAzureOpenAIStream(streamResponse, (content, usage) => {
// Update the execution data with the final content and token usage
_streamContent = content
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.prompt_tokens,
completion: usage.completion_tokens,
total: usage.total_tokens,
}
const costResult = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
// Update the timing information with the actual completion time
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
@@ -197,7 +189,6 @@ export const azureOpenAIProvider: ProviderConfig = {
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
// Update the time segment as well
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
streamEndTime
@@ -205,25 +196,13 @@ export const azureOpenAIProvider: ProviderConfig = {
streamEndTime - providerStartTime
}
}
// Update token usage if available from the stream
if (usage) {
const newTokens = {
prompt: usage.prompt_tokens || tokenUsage.prompt,
completion: usage.completion_tokens || tokenUsage.completion,
total: usage.total_tokens || tokenUsage.total,
}
streamingResult.execution.output.tokens = newTokens
}
// We don't need to estimate tokens here as logger.ts will handle that
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
@@ -239,9 +218,9 @@ export const azureOpenAIProvider: ProviderConfig = {
},
],
},
// Cost will be calculated in logger
cost: { input: 0, output: 0, total: 0 },
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -250,17 +229,11 @@ export const azureOpenAIProvider: ProviderConfig = {
},
} as StreamingExecution
// Return the streaming execution object with explicit casting
return streamingResult as StreamingExecution
}
// Make the initial API request
const initialCallTime = Date.now()
// Track the original tool_choice for forced tool tracking
const originalToolChoice = payload.tool_choice
// Track forced tools and their usage
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
@@ -268,7 +241,6 @@ export const azureOpenAIProvider: ProviderConfig = {
const firstResponseTime = Date.now() - initialCallTime
let content = currentResponse.choices[0]?.message?.content || ''
// Collect token information but don't calculate costs - that will be done in logger.ts
const tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
@@ -278,15 +250,10 @@ export const azureOpenAIProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
// Track time spent in model vs tools
let modelTime = firstResponseTime
let toolsTime = 0
// Track if a forced tool has been used
let hasUsedForcedTool = false
// Track each model and tool call segment with timestamps
const timeSegments: TimeSegment[] = [
{
type: 'model',
@@ -297,7 +264,6 @@ export const azureOpenAIProvider: ProviderConfig = {
},
]
// Check if a forced tool was used in the first response
const firstCheckResult = checkForForcedToolUsage(
currentResponse,
originalToolChoice,
@@ -309,7 +275,10 @@ export const azureOpenAIProvider: ProviderConfig = {
usedForcedTools = firstCheckResult.usedForcedTools
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
@@ -319,44 +288,85 @@ export const azureOpenAIProvider: ProviderConfig = {
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
// Track time for tool calls in this batch
const toolsStartTime = Date.now()
// Process each tool call
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
// Execute the tool
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.function.name
try {
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) return null
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
// Add to time segments for both success and failure
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call:', { error, toolName })
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
startTime: startTime,
endTime: endTime,
duration: duration,
})
// Prepare result content for the LLM
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
// Include error information so LLM can respond appropriately
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
@@ -367,78 +377,46 @@ export const azureOpenAIProvider: ProviderConfig = {
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
// Add the tool call and result to messages (both success and failure)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
} catch (error) {
logger.error('Error processing tool call:', {
error,
toolName: toolCall?.function?.name,
})
}
}
// Calculate tool call time for this iteration
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
// Update tool_choice based on which forced tools have been used
if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) {
// If we have remaining forced tools, get the next one to force
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
if (remainingTools.length > 0) {
// Force the next tool
nextPayload.tool_choice = {
type: 'function',
function: { name: remainingTools[0] },
}
logger.info(`Forcing next tool: ${remainingTools[0]}`)
} else {
// All forced tools have been used, switch to auto
nextPayload.tool_choice = 'auto'
logger.info('All forced tools have been used, switching to auto tool_choice')
}
}
// Time the next model call
const nextModelStartTime = Date.now()
// Make the next request
currentResponse = await azureOpenAI.chat.completions.create(nextPayload)
// Check if any forced tools were used in this response
const nextCheckResult = checkForForcedToolUsage(
currentResponse,
nextPayload.tool_choice,
@@ -452,7 +430,6 @@ export const azureOpenAIProvider: ProviderConfig = {
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
// Add to time segments
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
@@ -461,15 +438,12 @@ export const azureOpenAIProvider: ProviderConfig = {
duration: thisModelTime,
})
// Add to model time
modelTime += thisModelTime
// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
@@ -479,46 +453,44 @@ export const azureOpenAIProvider: ProviderConfig = {
iterationCount++
}
// After all tool processing complete, if streaming was requested, use streaming for the final response
if (request.stream) {
logger.info('Using streaming for final response after tool processing')
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
// This prevents Azure OpenAI API from trying to force tool usage again in the final streaming response
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
messages: currentMessages,
tool_choice: 'auto',
stream: true,
stream_options: { include_usage: true },
}
let _streamContent = ''
const streamResponse = await azureOpenAI.chat.completions.create(streamingParams)
const streamingResult = {
stream: createReadableStreamFromAzureOpenAIStream(streamResponse, (content, usage) => {
_streamContent = content
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.prompt_tokens,
completion: tokens.completion + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}
const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: {
prompt: tokens.prompt,
@@ -542,9 +514,13 @@ export const azureOpenAIProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
// Cost will be calculated in logger
cost: {
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
logs: [], // No block logs at provider level
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -553,11 +529,9 @@ export const azureOpenAIProvider: ProviderConfig = {
},
} as StreamingExecution
// Return the streaming execution object with explicit casting
return streamingResult as StreamingExecution
}
// Calculate overall timing
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -578,10 +552,8 @@ export const azureOpenAIProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
// We're not calculating cost here as it will be handled in logger.ts
}
} catch (error) {
// Include timing information even for errors
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -591,9 +563,8 @@ export const azureOpenAIProvider: ProviderConfig = {
duration: totalDuration,
})
// Create a new error with timing information
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
// @ts-ignore
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,

View File

@@ -1,70 +1,37 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import type { Stream } from 'openai/streaming'
import type { Logger } from '@/lib/logs/console/logger'
import { trackForcedToolUsage } from '@/providers/utils'
import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils'
/**
* Creates a ReadableStream from an Azure OpenAI streaming response.
* Uses the shared OpenAI-compatible streaming utility.
*/
export function createReadableStreamFromAzureOpenAIStream(
azureOpenAIStream: Stream<ChatCompletionChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream {
return createOpenAICompatibleStream(azureOpenAIStream, 'Azure OpenAI', onComplete)
}
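// createOpenAICompatibleStream is defined in the shared provider utils; its implementation
// is outside this file. A hedged sketch of the general shape such an adapter could take:
// forward text deltas as bytes and hand the accumulated content plus the final usage chunk
// to onComplete. The real utility may differ in signature and details.
function openAICompatibleStreamSketch(
  source: AsyncIterable<{ choices?: Array<{ delta?: { content?: string } }>; usage?: CompletionUsage }>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  let content = ''
  let usage: CompletionUsage | undefined
  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of source) {
          if (chunk.usage) usage = chunk.usage
          const delta = chunk.choices?.[0]?.delta?.content || ''
          if (delta) {
            content += delta
            controller.enqueue(new TextEncoder().encode(delta))
          }
        }
        onComplete?.(content, usage ?? { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 })
        controller.close()
      } catch (err) {
        controller.error(err)
      }
    },
  })
}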
/**
* Checks if a forced tool was used in an Azure OpenAI response.
* Uses the shared OpenAI-compatible forced tool usage helper.
*/
export function checkForForcedToolUsage(
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
_logger: Logger,
forcedTools: string[],
usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
return checkForForcedToolUsageOpenAI(
response,
toolChoice,
'Azure OpenAI',
forcedTools,
usedForcedTools,
_logger
)
}

View File

@@ -12,6 +12,7 @@ import type {
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
@@ -35,7 +36,6 @@ export const cerebrasProvider: ProviderConfig = {
throw new Error('API key is required for Cerebras')
}
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
@@ -44,31 +44,23 @@ export const cerebrasProvider: ProviderConfig = {
apiKey: request.apiKey,
})
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
content: request.systemPrompt,
})
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
content: request.context,
})
}
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Transform tools to Cerebras format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
@@ -80,26 +72,23 @@ export const cerebrasProvider: ProviderConfig = {
}))
: undefined
// Build the request payload
const payload: any = {
model: (request.model || 'cerebras/llama-3.3-70b').replace('cerebras/', ''),
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add response format for structured output if specified
if (request.responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: request.responseFormat.name || 'response_schema',
schema: request.responseFormat.schema || request.responseFormat,
strict: request.responseFormat.strict !== false,
},
}
}
// Handle tools and tool usage control
// Cerebras supports full OpenAI-compatible tool_choice including forcing specific tools
let originalToolChoice: any
let forcedTools: string[] = []
let hasFilteredTools = false
@@ -124,30 +113,40 @@ export const cerebrasProvider: ProviderConfig = {
}
}
// EARLY STREAMING: if streaming requested and no tools to execute, stream directly
if (request.stream && (!tools || tools.length === 0)) {
logger.info('Using streaming response for Cerebras request (no tools)')
const streamResponse: any = await client.chat.completions.create({
...payload,
stream: true,
})
const streamingResult = {
stream: createReadableStreamFromCerebrasStream(streamResponse, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.prompt_tokens,
completion: usage.completion_tokens,
total: usage.total_tokens,
}
const costResult = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model || 'cerebras/llama-3.3-70b',
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
@@ -163,14 +162,9 @@ export const cerebrasProvider: ProviderConfig = {
},
],
},
// Estimate token cost
cost: {
total: 0.0,
input: 0.0,
output: 0.0,
cost: { input: 0, output: 0, total: 0 },
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -180,11 +174,8 @@ export const cerebrasProvider: ProviderConfig = {
},
}
// Return the streaming execution object
return streamingResult as StreamingExecution
}
// Make the initial API request
const initialCallTime = Date.now()
let currentResponse = (await client.chat.completions.create(payload)) as CerebrasResponse
@@ -201,11 +192,8 @@ export const cerebrasProvider: ProviderConfig = {
const currentMessages = [...allMessages]
let iterationCount = 0
// Track time spent in model vs tools
let modelTime = firstResponseTime
let toolsTime = 0
// Track each model and tool call segment with timestamps
const timeSegments: TimeSegment[] = [
{
type: 'model',
@@ -216,17 +204,12 @@ export const cerebrasProvider: ProviderConfig = {
},
]
// Keep track of processed tool calls to avoid duplicates
const processedToolCallIds = new Set()
// Keep track of tool call signatures to detect repeats
const toolCallSignatures = new Set()
try {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
// Break if no tool calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
@@ -234,63 +217,99 @@ export const cerebrasProvider: ProviderConfig = {
break
}
// Track time for tool calls in this batch
const toolsStartTime = Date.now()
// Process each tool call
let processedAnyToolCall = false
let hasRepeatedToolCalls = false
for (const toolCall of toolCallsInResponse) {
// Skip if we've already processed this tool call
const filteredToolCalls = toolCallsInResponse.filter((toolCall) => {
if (processedToolCallIds.has(toolCall.id)) {
continue
return false
}
// Create a signature for this tool call to detect repeats
const toolCallSignature = `${toolCall.function.name}-${toolCall.function.arguments}`
if (toolCallSignatures.has(toolCallSignature)) {
hasRepeatedToolCalls = true
continue
return false
}
try {
processedToolCallIds.add(toolCall.id)
toolCallSignatures.add(toolCallSignature)
processedAnyToolCall = true
return true
})
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
// Execute the tool
const processedAnyToolCall = filteredToolCalls.length > 0
const toolExecutionPromises = filteredToolCalls.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.function.name
try {
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) return null
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
// Add to time segments for both success and failure
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call (Cerebras):', {
error: error instanceof Error ? error.message : String(error),
toolName,
})
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: filteredToolCalls.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
startTime: startTime,
endTime: endTime,
duration: duration,
})
// Prepare result content for the LLM
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
// Include error information so LLM can respond appropriately
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
@@ -301,44 +320,21 @@ export const cerebrasProvider: ProviderConfig = {
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
// Add the tool call and result to messages (both success and failure)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
} catch (error) {
logger.error('Error processing tool call:', { error })
}
}
// Calculate tool call time for this iteration
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
// Check if we used any forced tools and update tool_choice for the next iteration
let usedForcedTools: string[] = []
if (typeof originalToolChoice === 'object' && forcedTools.length > 0) {
const toolTracking = trackForcedToolUsage(
@@ -351,28 +347,20 @@ export const cerebrasProvider: ProviderConfig = {
)
usedForcedTools = toolTracking.usedForcedTools
const nextToolChoice = toolTracking.nextToolChoice
// Update tool_choice for next iteration if we're still forcing tools
if (nextToolChoice && typeof nextToolChoice === 'object') {
payload.tool_choice = nextToolChoice
} else if (nextToolChoice === 'auto' || !nextToolChoice) {
// All forced tools have been used, switch to auto
payload.tool_choice = 'auto'
}
}
// After processing tool calls, get a final response
if (processedAnyToolCall || hasRepeatedToolCalls) {
// Time the next model call
const nextModelStartTime = Date.now()
// Make the final request
const finalPayload = {
...payload,
messages: currentMessages,
}
// Use tool_choice: 'none' for the final response to avoid an infinite loop
finalPayload.tool_choice = 'none'
const finalResponse = (await client.chat.completions.create(
@@ -382,7 +370,6 @@ export const cerebrasProvider: ProviderConfig = {
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
// Add to time segments
timeSegments.push({
type: 'model',
name: 'Final response',
@@ -391,14 +378,11 @@ export const cerebrasProvider: ProviderConfig = {
duration: thisModelTime,
})
// Add to model time
modelTime += thisModelTime
if (finalResponse.choices[0]?.message?.content) {
content = finalResponse.choices[0].message.content
}
// Update final token counts
if (finalResponse.usage) {
tokens.prompt += finalResponse.usage.prompt_tokens || 0
tokens.completion += finalResponse.usage.completion_tokens || 0
@@ -408,18 +392,13 @@ export const cerebrasProvider: ProviderConfig = {
break
}
// Only continue if we haven't processed any tool calls and haven't seen repeats
if (!processedAnyToolCall && !hasRepeatedToolCalls) {
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
// Time the next model call
const nextModelStartTime = Date.now()
// Make the next request
currentResponse = (await client.chat.completions.create(
nextPayload
)) as CerebrasResponse
@@ -427,7 +406,6 @@ export const cerebrasProvider: ProviderConfig = {
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
// Add to time segments
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
@@ -436,10 +414,7 @@ export const cerebrasProvider: ProviderConfig = {
duration: thisModelTime,
})
// Add to model time
modelTime += thisModelTime
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
@@ -453,33 +428,48 @@ export const cerebrasProvider: ProviderConfig = {
logger.error('Error in Cerebras tool processing:', { error })
}
// Calculate overall timing
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
// POST-TOOL-STREAMING: stream after tool calls if requested
if (request.stream) {
logger.info('Using streaming for final Cerebras response after tool processing')
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
// This prevents the API from trying to force tool usage again in the final streaming response
const streamingPayload = {
...payload,
messages: currentMessages,
tool_choice: 'auto',
stream: true,
}
const streamResponse: any = await client.chat.completions.create(streamingPayload)
// Create a StreamingExecution response with all collected data
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
const streamingResult = {
stream: createReadableStreamFromCerebrasStream(streamResponse, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.prompt_tokens,
completion: tokens.completion + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}
const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model || 'cerebras/llama-3.3-70b',
tokens: {
prompt: tokens.prompt,
@@ -504,12 +494,12 @@ export const cerebrasProvider: ProviderConfig = {
timeSegments: timeSegments,
},
cost: {
total: (tokens.total || 0) * 0.0001,
input: (tokens.prompt || 0) * 0.0001,
output: (tokens.completion || 0) * 0.0001,
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -519,7 +509,6 @@ export const cerebrasProvider: ProviderConfig = {
},
}
// Return the streaming execution object
return streamingResult as StreamingExecution
}
@@ -541,7 +530,6 @@ export const cerebrasProvider: ProviderConfig = {
},
}
} catch (error) {
// Include timing information even for errors
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -551,9 +539,8 @@ export const cerebrasProvider: ProviderConfig = {
duration: totalDuration,
})
// Create a new error with timing information
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
// @ts-ignore - Adding timing property to error for debugging
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,

View File

@@ -1,23 +1,26 @@
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'
interface CerebrasChunk {
choices?: Array<{
delta?: {
content?: string
}
}>
usage?: {
prompt_tokens?: number
completion_tokens?: number
total_tokens?: number
}
}
/**
* Creates a ReadableStream from a Cerebras streaming response.
* Uses the shared OpenAI-compatible streaming utility.
*/
export function createReadableStreamFromCerebrasStream(
cerebrasStream: AsyncIterable<CerebrasChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
return createOpenAICompatibleStream(cerebrasStream as any, 'Cerebras', onComplete)
}

View File

@@ -11,6 +11,7 @@ import type {
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
@@ -34,21 +35,17 @@ export const deepseekProvider: ProviderConfig = {
throw new Error('API key is required for Deepseek')
}
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
// Deepseek uses the OpenAI SDK with a custom baseURL
const deepseek = new OpenAI({
apiKey: request.apiKey,
baseURL: 'https://api.deepseek.com/v1',
})
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
@@ -56,7 +53,6 @@ export const deepseekProvider: ProviderConfig = {
})
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
@@ -64,12 +60,10 @@ export const deepseekProvider: ProviderConfig = {
})
}
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Transform tools to OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
@@ -82,15 +76,13 @@ export const deepseekProvider: ProviderConfig = {
: undefined
const payload: any = {
model: 'deepseek-chat',
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Handle tools and tool usage control
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
if (tools?.length) {
@@ -118,7 +110,6 @@ export const deepseekProvider: ProviderConfig = {
}
}
// EARLY STREAMING: if streaming requested and no tools to execute, stream directly
if (request.stream && (!tools || tools.length === 0)) {
logger.info('Using streaming response for DeepSeek request (no tools)')
@@ -127,22 +118,35 @@ export const deepseekProvider: ProviderConfig = {
stream: true,
})
const streamingResult = {
stream: createReadableStreamFromDeepseekStream(
streamResponse as any,
(content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.prompt_tokens,
completion: usage.completion_tokens,
total: usage.total_tokens,
}
const costResult = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
}
),
execution: {
success: true,
output: {
content: '',
model: request.model || 'deepseek-chat',
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
@@ -158,14 +162,9 @@ export const deepseekProvider: ProviderConfig = {
},
],
},
// Estimate token cost
cost: {
total: 0.0,
input: 0.0,
output: 0.0,
cost: { input: 0, output: 0, total: 0 },
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -175,17 +174,11 @@ export const deepseekProvider: ProviderConfig = {
},
}
// Return the streaming execution object
return streamingResult as StreamingExecution
}
// Make the initial API request
const initialCallTime = Date.now()
// Track the original tool_choice for forced tool tracking
const originalToolChoice = payload.tool_choice
// Track forced tools and their usage
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
@@ -194,11 +187,8 @@ export const deepseekProvider: ProviderConfig = {
let content = currentResponse.choices[0]?.message?.content || ''
// Clean up the response content if it exists
if (content) {
// Remove any markdown code block markers
content = content.replace(/```json\n?|\n?```/g, '')
// Trim any whitespace
content = content.trim()
}
@@ -211,15 +201,10 @@ export const deepseekProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
// Track if a forced tool has been used
let hasUsedForcedTool = false
// Track time spent in model vs tools
let modelTime = firstResponseTime
let toolsTime = 0
// Track each model and tool call segment with timestamps
const timeSegments: TimeSegment[] = [
{
type: 'model',
@@ -230,7 +215,6 @@ export const deepseekProvider: ProviderConfig = {
},
]
// Check if a forced tool was used in the first response
if (
typeof originalToolChoice === 'object' &&
currentResponse.choices[0]?.message?.tool_calls
@@ -250,50 +234,94 @@ export const deepseekProvider: ProviderConfig = {
try {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
// Track time for tool calls in this batch
const toolsStartTime = Date.now()
// Process each tool call
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
// Execute the tool
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.function.name
try {
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) return null
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
// Add to time segments for both success and failure
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call:', { error, toolName })
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
startTime: startTime,
endTime: endTime,
duration: duration,
})
// Prepare result content for the LLM
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
// Include error information so LLM can respond appropriately
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
@@ -304,79 +332,50 @@ export const deepseekProvider: ProviderConfig = {
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
// Add the tool call and result to messages (both success and failure)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
} catch (error) {
logger.error('Error processing tool call:', { error })
}
}
// Calculate tool call time for this iteration
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
// Update tool_choice based on which forced tools have been used
if (
typeof originalToolChoice === 'object' &&
hasUsedForcedTool &&
forcedTools.length > 0
) {
// If we have remaining forced tools, get the next one to force
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
if (remainingTools.length > 0) {
// Force the next tool
nextPayload.tool_choice = {
type: 'function',
function: { name: remainingTools[0] },
}
logger.info(`Forcing next tool: ${remainingTools[0]}`)
} else {
// All forced tools have been used, switch to auto
nextPayload.tool_choice = 'auto'
logger.info('All forced tools have been used, switching to auto tool_choice')
}
}
// Time the next model call
const nextModelStartTime = Date.now()
// Make the next request
currentResponse = await deepseek.chat.completions.create(nextPayload)
// Check if any forced tools were used in this response
if (
typeof nextPayload.tool_choice === 'object' &&
currentResponse.choices[0]?.message?.tool_calls
@@ -397,7 +396,6 @@ export const deepseekProvider: ProviderConfig = {
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
// Add to time segments
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
@@ -406,18 +404,14 @@ export const deepseekProvider: ProviderConfig = {
duration: thisModelTime,
})
// Add to model time
modelTime += thisModelTime
// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
// Clean up the response content
content = content.replace(/```json\n?|\n?```/g, '')
content = content.trim()
}
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
@@ -430,33 +424,51 @@ export const deepseekProvider: ProviderConfig = {
logger.error('Error in Deepseek request:', { error })
}
// Calculate overall timing
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
// POST-TOOL STREAMING: stream final response after tool calls if requested
if (request.stream) {
logger.info('Using streaming for final DeepSeek response after tool processing')
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
// This prevents the API from trying to force tool usage again in the final streaming response
const streamingPayload = {
...payload,
messages: currentMessages,
tool_choice: 'auto', // Always use 'auto' for the streaming response after tool calls
tool_choice: 'auto',
stream: true,
}
const streamResponse = await deepseek.chat.completions.create(streamingPayload)
// Create a StreamingExecution response with all collected data
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
const streamingResult = {
stream: createReadableStreamFromDeepseekStream(streamResponse),
stream: createReadableStreamFromDeepseekStream(
streamResponse as any,
(content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.prompt_tokens,
completion: tokens.completion + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}
const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}
),
execution: {
success: true,
output: {
content: '', // Will be filled by the callback
content: '',
model: request.model || 'deepseek-chat',
tokens: {
prompt: tokens.prompt,
@@ -481,12 +493,12 @@ export const deepseekProvider: ProviderConfig = {
timeSegments: timeSegments,
},
cost: {
total: (tokens.total || 0) * 0.0001,
input: (tokens.prompt || 0) * 0.0001,
output: (tokens.completion || 0) * 0.0001,
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [], // No block logs at provider level
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -496,7 +508,6 @@ export const deepseekProvider: ProviderConfig = {
},
}
// Return the streaming execution object
return streamingResult as StreamingExecution
}
@@ -518,7 +529,6 @@ export const deepseekProvider: ProviderConfig = {
},
}
} catch (error) {
// Include timing information even for errors
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -528,9 +538,8 @@ export const deepseekProvider: ProviderConfig = {
duration: totalDuration,
})
// Create a new error with timing information
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
// @ts-ignore
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
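/*
 * A minimal, self-contained sketch of the parallel tool-call pattern the providers in
 * this commit now share: every tool call from a response is executed concurrently via
 * Promise.allSettled, a single assistant message carries all tool_calls, and each
 * result is appended as its own `tool` message. The SketchToolCall/SketchMessage shapes
 * and the runTool executor are simplified placeholders, not the repo's actual
 * executeTool/prepareToolExecution helpers.
 */
interface SketchToolCall {
  id: string
  function: { name: string; arguments: string }
}

type SketchMessage =
  | { role: 'assistant'; content: null; tool_calls: SketchToolCall[] }
  | { role: 'tool'; tool_call_id: string; content: string }

async function runTool(name: string, args: Record<string, unknown>): Promise<unknown> {
  // Placeholder executor; the real implementation dispatches to the tool registry.
  return { echoed: { name, args } }
}

async function executeToolCallsInParallel(
  toolCalls: SketchToolCall[],
  messages: SketchMessage[]
): Promise<void> {
  // Kick off every tool call at once and wait for all of them to settle.
  const settled = await Promise.allSettled(
    toolCalls.map(async (toolCall) => {
      const args = JSON.parse(toolCall.function.arguments) as Record<string, unknown>
      const output = await runTool(toolCall.function.name, args)
      return { toolCall, output }
    })
  )

  // One assistant message references every tool call issued in this turn.
  messages.push({ role: 'assistant', content: null, tool_calls: toolCalls })

  // Each settled result becomes its own `tool` message, including failures,
  // so the model can react to errors on the next iteration.
  for (let i = 0; i < settled.length; i++) {
    const result = settled[i]
    const content =
      result.status === 'fulfilled'
        ? JSON.stringify(result.value.output)
        : JSON.stringify({ error: true, message: String(result.reason) })
    messages.push({ role: 'tool', tool_call_id: toolCalls[i].id, content })
  }
}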


@@ -1,21 +1,14 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'
/**
* Helper function to convert a DeepSeek (OpenAI-compatible) stream to a ReadableStream
* of text chunks that can be consumed by the browser.
* Creates a ReadableStream from a DeepSeek streaming response.
* Uses the shared OpenAI-compatible streaming utility.
*/
export function createReadableStreamFromDeepseekStream(deepseekStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of deepseekStream) {
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
export function createReadableStreamFromDeepseekStream(
deepseekStream: AsyncIterable<ChatCompletionChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
return createOpenAICompatibleStream(deepseekStream, 'Deepseek', onComplete)
}
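/*
 * The shared createOpenAICompatibleStream utility referenced above is not shown in this
 * diff; the sketch below is an assumption of what such an adapter generally looks like:
 * enqueue each text delta as UTF-8 bytes, remember the last usage payload, and hand the
 * accumulated content plus usage to onComplete when the stream finishes. The chunk and
 * usage shapes are reduced to the fields this sketch needs.
 */
interface SketchUsage {
  prompt_tokens: number
  completion_tokens: number
  total_tokens: number
}

interface SketchChunk {
  choices: Array<{ delta?: { content?: string | null } }>
  usage?: SketchUsage | null
}

function sketchOpenAICompatibleStream(
  chunks: AsyncIterable<SketchChunk>,
  providerName: string,
  onComplete?: (content: string, usage: SketchUsage) => void
): ReadableStream<Uint8Array> {
  const encoder = new TextEncoder()
  let fullContent = ''
  let usage: SketchUsage = { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }

  return new ReadableStream<Uint8Array>({
    async start(controller) {
      try {
        for await (const chunk of chunks) {
          // Providers that support stream_options: { include_usage: true } emit usage
          // on the final chunk; keep whichever usage payload arrives last.
          if (chunk.usage) usage = chunk.usage
          const text = chunk.choices[0]?.delta?.content ?? ''
          if (text) {
            fullContent += text
            controller.enqueue(encoder.encode(text))
          }
        }
        onComplete?.(fullContent, usage)
        controller.close()
      } catch (error) {
        console.error(`${providerName} stream failed`, error)
        controller.error(error)
      }
    },
  })
}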


@@ -15,6 +15,7 @@ import type {
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
@@ -23,15 +24,15 @@ import { executeTool } from '@/tools'
const logger = createLogger('GoogleProvider')
/**
* Creates a ReadableStream from Google's Gemini stream response
*/
interface GeminiStreamUsage {
promptTokenCount: number
candidatesTokenCount: number
totalTokenCount: number
}
function createReadableStreamFromGeminiStream(
response: Response,
onComplete?: (
content: string,
usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
) => void
onComplete?: (content: string, usage: GeminiStreamUsage) => void
): ReadableStream<Uint8Array> {
const reader = response.body?.getReader()
if (!reader) {
@@ -43,11 +44,32 @@ function createReadableStreamFromGeminiStream(
try {
let buffer = ''
let fullContent = ''
let usageData: {
promptTokenCount?: number
candidatesTokenCount?: number
totalTokenCount?: number
} | null = null
let promptTokenCount = 0
let candidatesTokenCount = 0
let totalTokenCount = 0
const updateUsage = (metadata: any) => {
if (metadata) {
promptTokenCount = metadata.promptTokenCount ?? promptTokenCount
candidatesTokenCount = metadata.candidatesTokenCount ?? candidatesTokenCount
totalTokenCount = metadata.totalTokenCount ?? totalTokenCount
}
}
const buildUsage = (): GeminiStreamUsage => ({
promptTokenCount,
candidatesTokenCount,
totalTokenCount,
})
const complete = () => {
if (onComplete) {
if (promptTokenCount === 0 && candidatesTokenCount === 0) {
logger.warn('Gemini stream completed without usage metadata')
}
onComplete(fullContent, buildUsage())
}
}
while (true) {
const { done, value } = await reader.read()
@@ -55,20 +77,15 @@ function createReadableStreamFromGeminiStream(
if (buffer.trim()) {
try {
const data = JSON.parse(buffer.trim())
if (data.usageMetadata) {
usageData = data.usageMetadata
}
updateUsage(data.usageMetadata)
const candidate = data.candidates?.[0]
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in final buffer, ending stream to execute tool',
{
logger.debug('Function call detected in final buffer', {
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
})
complete()
controller.close()
return
}
@@ -84,20 +101,15 @@ function createReadableStreamFromGeminiStream(
const dataArray = JSON.parse(buffer.trim())
if (Array.isArray(dataArray)) {
for (const item of dataArray) {
if (item.usageMetadata) {
usageData = item.usageMetadata
}
updateUsage(item.usageMetadata)
const candidate = item.candidates?.[0]
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in array item, ending stream to execute tool',
{
logger.debug('Function call detected in array item', {
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
})
complete()
controller.close()
return
}
@@ -109,13 +121,11 @@ function createReadableStreamFromGeminiStream(
}
}
}
} catch (arrayError) {
// Buffer is not valid JSON array
} catch (_) {}
}
}
}
}
if (onComplete) onComplete(fullContent, usageData || undefined)
complete()
controller.close()
break
}
@@ -162,25 +172,17 @@ function createReadableStreamFromGeminiStream(
try {
const data = JSON.parse(jsonStr)
if (data.usageMetadata) {
usageData = data.usageMetadata
}
updateUsage(data.usageMetadata)
const candidate = data.candidates?.[0]
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
logger.warn('Gemini returned UNEXPECTED_TOOL_CALL in streaming mode', {
finishReason: candidate.finishReason,
hasContent: !!candidate?.content,
hasParts: !!candidate?.content?.parts,
})
logger.warn('Gemini returned UNEXPECTED_TOOL_CALL in streaming mode')
const textContent = extractTextContent(candidate)
if (textContent) {
fullContent += textContent
controller.enqueue(new TextEncoder().encode(textContent))
}
if (onComplete) onComplete(fullContent, usageData || undefined)
complete()
controller.close()
return
}
@@ -188,13 +190,10 @@ function createReadableStreamFromGeminiStream(
if (candidate?.content?.parts) {
const functionCall = extractFunctionCall(candidate)
if (functionCall) {
logger.debug(
'Function call detected in stream, ending stream to execute tool',
{
logger.debug('Function call detected in stream', {
functionName: functionCall.name,
}
)
if (onComplete) onComplete(fullContent, usageData || undefined)
})
complete()
controller.close()
return
}
@@ -214,7 +213,6 @@ function createReadableStreamFromGeminiStream(
buffer = buffer.substring(closeBrace + 1)
searchIndex = 0
} else {
// No complete JSON object found, wait for more data
break
}
}
@@ -395,6 +393,7 @@ export const googleProvider: ProviderConfig = {
},
],
},
cost: { input: 0, output: 0, total: 0 },
},
logs: [],
metadata: {
@@ -410,6 +409,22 @@ export const googleProvider: ProviderConfig = {
response,
(content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.promptTokenCount,
completion: usage.candidatesTokenCount,
total: usage.totalTokenCount || usage.promptTokenCount + usage.candidatesTokenCount,
}
const costResult = calculateCost(
request.model,
usage.promptTokenCount,
usage.candidatesTokenCount
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
@@ -426,16 +441,6 @@ export const googleProvider: ProviderConfig = {
streamEndTime - providerStartTime
}
}
if (usage) {
streamingResult.execution.output.tokens = {
prompt: usage.promptTokenCount || 0,
completion: usage.candidatesTokenCount || 0,
total:
usage.totalTokenCount ||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0),
}
}
}
)
@@ -463,6 +468,11 @@ export const googleProvider: ProviderConfig = {
completion: 0,
total: 0,
}
const cost = {
input: 0,
output: 0,
total: 0,
}
const toolCalls = []
const toolResults = []
let iterationCount = 0
@@ -705,6 +715,15 @@ export const googleProvider: ProviderConfig = {
tokens.total +=
(checkResult.usageMetadata.promptTokenCount || 0) +
(checkResult.usageMetadata.candidatesTokenCount || 0)
const iterationCost = calculateCost(
request.model,
checkResult.usageMetadata.promptTokenCount || 0,
checkResult.usageMetadata.candidatesTokenCount || 0
)
cost.input += iterationCost.input
cost.output += iterationCost.output
cost.total += iterationCost.total
}
const nextModelEndTime = Date.now()
@@ -724,8 +743,6 @@ export const googleProvider: ProviderConfig = {
}
logger.info('No function call detected, proceeding with streaming response')
// Apply structured output for the final response if responseFormat is specified
// This works regardless of whether tools were forced or auto
if (request.responseFormat) {
streamingPayload.tools = undefined
streamingPayload.toolConfig = undefined
@@ -806,6 +823,7 @@ export const googleProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments,
},
cost,
},
logs: [],
metadata: {
@@ -822,6 +840,28 @@ export const googleProvider: ProviderConfig = {
(content, usage) => {
streamingExecution.execution.output.content = content
const existingTokens = streamingExecution.execution.output.tokens
streamingExecution.execution.output.tokens = {
prompt: (existingTokens?.prompt ?? 0) + usage.promptTokenCount,
completion: (existingTokens?.completion ?? 0) + usage.candidatesTokenCount,
total:
(existingTokens?.total ?? 0) +
(usage.totalTokenCount ||
usage.promptTokenCount + usage.candidatesTokenCount),
}
const streamCost = calculateCost(
request.model,
usage.promptTokenCount,
usage.candidatesTokenCount
)
const existingCost = streamingExecution.execution.output.cost as any
streamingExecution.execution.output.cost = {
input: (existingCost?.input ?? 0) + streamCost.input,
output: (existingCost?.output ?? 0) + streamCost.output,
total: (existingCost?.total ?? 0) + streamCost.total,
}
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
@@ -831,23 +871,6 @@ export const googleProvider: ProviderConfig = {
streamingExecution.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
}
if (usage) {
const existingTokens = streamingExecution.execution.output.tokens || {
prompt: 0,
completion: 0,
total: 0,
}
streamingExecution.execution.output.tokens = {
prompt: (existingTokens.prompt || 0) + (usage.promptTokenCount || 0),
completion:
(existingTokens.completion || 0) + (usage.candidatesTokenCount || 0),
total:
(existingTokens.total || 0) +
(usage.totalTokenCount ||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0)),
}
}
}
)
@@ -926,7 +949,6 @@ export const googleProvider: ProviderConfig = {
const nextFunctionCall = extractFunctionCall(nextCandidate)
if (!nextFunctionCall) {
// If responseFormat is specified, make one final request with structured output
if (request.responseFormat) {
const finalPayload = {
...payload,
@@ -969,6 +991,15 @@ export const googleProvider: ProviderConfig = {
tokens.total +=
(finalResult.usageMetadata.promptTokenCount || 0) +
(finalResult.usageMetadata.candidatesTokenCount || 0)
const iterationCost = calculateCost(
request.model,
finalResult.usageMetadata.promptTokenCount || 0,
finalResult.usageMetadata.candidatesTokenCount || 0
)
cost.input += iterationCost.input
cost.output += iterationCost.output
cost.total += iterationCost.total
}
} else {
logger.warn(
@@ -1054,7 +1085,7 @@ export const googleProvider: ProviderConfig = {
})
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
// @ts-ignore
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
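/*
 * A sketch of the per-iteration cost accumulation used above, assuming the repo's
 * calculateCost converts per-1M-token pricing into dollars (ModelPricing rates are
 * documented as "per 1M tokens"). The pricing numbers and sketchCalculateCost are
 * illustrative placeholders, not real rates or repo code.
 */
interface SketchCost {
  input: number
  output: number
  total: number
}

const SKETCH_PRICING: Record<string, { input: number; output: number }> = {
  // Hypothetical $ per 1M tokens, for illustration only.
  'gemini-2.5-pro': { input: 1.25, output: 10 },
}

function sketchCalculateCost(
  model: string,
  promptTokens: number,
  completionTokens: number
): SketchCost {
  const pricing = SKETCH_PRICING[model] ?? { input: 0, output: 0 }
  const input = (promptTokens / 1_000_000) * pricing.input
  const output = (completionTokens / 1_000_000) * pricing.output
  return { input, output, total: input + output }
}

// Accumulate cost the same way the provider loop does: one calculateCost call per
// model response, summed into a running total.
const runningCost: SketchCost = { input: 0, output: 0, total: 0 }
for (const usage of [
  { promptTokenCount: 1200, candidatesTokenCount: 300 },
  { promptTokenCount: 1800, candidatesTokenCount: 450 },
]) {
  const iterationCost = sketchCalculateCost(
    'gemini-2.5-pro',
    usage.promptTokenCount,
    usage.candidatesTokenCount
  )
  runningCost.input += iterationCost.input
  runningCost.output += iterationCost.output
  runningCost.total += iterationCost.total
}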


@@ -1,3 +1,4 @@
import type { Candidate } from '@google/genai'
import type { ProviderRequest } from '@/providers/types'
/**
@@ -23,7 +24,7 @@ export function cleanSchemaForGemini(schema: any): any {
/**
* Extracts text content from a Gemini response candidate, handling structured output
*/
export function extractTextContent(candidate: any): string {
export function extractTextContent(candidate: Candidate | undefined): string {
if (!candidate?.content?.parts) return ''
if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) {
@@ -32,9 +33,7 @@ export function extractTextContent(candidate: any): string {
try {
JSON.parse(text)
return text
} catch (_e) {
/* Not valid JSON, continue with normal extraction */
}
} catch (_e) {}
}
}
@@ -47,32 +46,18 @@ export function extractTextContent(candidate: any): string {
/**
* Extracts a function call from a Gemini response candidate
*/
export function extractFunctionCall(candidate: any): { name: string; args: any } | null {
export function extractFunctionCall(
candidate: Candidate | undefined
): { name: string; args: any } | null {
if (!candidate?.content?.parts) return null
for (const part of candidate.content.parts) {
if (part.functionCall) {
const args = part.functionCall.args || {}
if (
typeof part.functionCall.args === 'string' &&
part.functionCall.args.trim().startsWith('{')
) {
try {
return { name: part.functionCall.name, args: JSON.parse(part.functionCall.args) }
} catch (_e) {
return { name: part.functionCall.name, args: part.functionCall.args }
return {
name: part.functionCall.name ?? '',
args: part.functionCall.args ?? {},
}
}
return { name: part.functionCall.name, args }
}
}
if (candidate.content.function_call) {
const args =
typeof candidate.content.function_call.arguments === 'string'
? JSON.parse(candidate.content.function_call.arguments || '{}')
: candidate.content.function_call.arguments || {}
return { name: candidate.content.function_call.name, args }
}
return null
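/*
 * Usage sketch for the Gemini helpers above. The candidate literal is a reduced,
 * hand-written example of the response shape (cast loosely), not an actual API payload.
 */
const sketchCandidate = {
  content: {
    parts: [
      {
        functionCall: {
          name: 'get_weather',
          args: { city: 'Berlin', unit: 'celsius' },
        },
      },
    ],
  },
} as any

// extractFunctionCall returns the call name and args, defaulting to '' / {} when the
// SDK leaves those fields undefined.
const call = extractFunctionCall(sketchCandidate)
// => { name: 'get_weather', args: { city: 'Berlin', unit: 'celsius' } }

// extractTextContent returns '' here because the candidate has no text parts.
const text = extractTextContent(sketchCandidate)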


@@ -11,6 +11,7 @@ import type {
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
@@ -34,13 +35,10 @@ export const groqProvider: ProviderConfig = {
throw new Error('API key is required for Groq')
}
// Create Groq client
const groq = new Groq({ apiKey: request.apiKey })
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
@@ -48,7 +46,6 @@ export const groqProvider: ProviderConfig = {
})
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
@@ -56,12 +53,10 @@ export const groqProvider: ProviderConfig = {
})
}
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Transform tools to function format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
@@ -73,7 +68,6 @@ export const groqProvider: ProviderConfig = {
}))
: undefined
// Build the request payload
const payload: any = {
model: (request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct').replace(
'groq/',
@@ -82,20 +76,20 @@ export const groqProvider: ProviderConfig = {
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add response format for structured output if specified
if (request.responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: request.responseFormat.name || 'response_schema',
schema: request.responseFormat.schema || request.responseFormat,
strict: request.responseFormat.strict !== false,
},
}
}
// Handle tools and tool usage control
// Groq supports full OpenAI-compatible tool_choice including forcing specific tools
let originalToolChoice: any
let forcedTools: string[] = []
let hasFilteredTools = false
@@ -120,12 +114,9 @@ export const groqProvider: ProviderConfig = {
}
}
// EARLY STREAMING: if caller requested streaming and there are no tools to execute,
// we can directly stream the completion.
if (request.stream && (!tools || tools.length === 0)) {
logger.info('Using streaming response for Groq request (no tools)')
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
@@ -134,22 +125,32 @@ export const groqProvider: ProviderConfig = {
stream: true,
})
// Start collecting token usage
const tokenUsage = {
prompt: 0,
completion: 0,
total: 0,
const streamingResult = {
stream: createReadableStreamFromGroqStream(streamResponse as any, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.prompt_tokens,
completion: usage.completion_tokens,
total: usage.total_tokens,
}
// Create a StreamingExecution response with a readable stream
const streamingResult = {
stream: createReadableStreamFromGroqStream(streamResponse),
const costResult = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
}),
execution: {
success: true,
output: {
content: '', // Will be filled by streaming content in chat component
content: '',
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
tokens: tokenUsage,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
@@ -165,13 +166,9 @@ export const groqProvider: ProviderConfig = {
},
],
},
cost: {
total: 0.0,
input: 0.0,
output: 0.0,
cost: { input: 0, output: 0, total: 0 },
},
},
logs: [], // No block logs for direct streaming
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -181,16 +178,13 @@ export const groqProvider: ProviderConfig = {
},
}
// Return the streaming execution object
return streamingResult as StreamingExecution
}
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
// Make the initial API request
const initialCallTime = Date.now()
let currentResponse = await groq.chat.completions.create(payload)
@@ -206,12 +200,9 @@ export const groqProvider: ProviderConfig = {
const toolResults = []
const currentMessages = [...allMessages]
let iterationCount = 0
// Track time spent in model vs tools
let modelTime = firstResponseTime
let toolsTime = 0
// Track each model and tool call segment with timestamps
const timeSegments: TimeSegment[] = [
{
type: 'model',
@@ -224,50 +215,94 @@ export const groqProvider: ProviderConfig = {
try {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
// Track time for tool calls in this batch
const toolsStartTime = Date.now()
// Process each tool call
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
// Get the tool from the tools registry
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
// Execute the tool
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.function.name
try {
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) return null
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
// Add to time segments for both success and failure
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call:', { error, toolName })
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
startTime: startTime,
endTime: endTime,
duration: duration,
})
// Prepare result content for the LLM
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
// Include error information so LLM can respond appropriately
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
@@ -278,44 +313,23 @@ export const groqProvider: ProviderConfig = {
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
// Add the tool call and result to messages (both success and failure)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
} catch (error) {
logger.error('Error processing tool call:', { error })
}
}
// Calculate tool call time for this iteration
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
// Check if we used any forced tools and update tool_choice for the next iteration
let usedForcedTools: string[] = []
if (typeof originalToolChoice === 'object' && forcedTools.length > 0) {
const toolTracking = trackForcedToolUsage(
@@ -329,31 +343,24 @@ export const groqProvider: ProviderConfig = {
usedForcedTools = toolTracking.usedForcedTools
const nextToolChoice = toolTracking.nextToolChoice
// Update tool_choice for next iteration if we're still forcing tools
if (nextToolChoice && typeof nextToolChoice === 'object') {
payload.tool_choice = nextToolChoice
} else if (nextToolChoice === 'auto' || !nextToolChoice) {
// All forced tools have been used, switch to auto
payload.tool_choice = 'auto'
}
}
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
// Time the next model call
const nextModelStartTime = Date.now()
// Make the next request
currentResponse = await groq.chat.completions.create(nextPayload)
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
// Add to time segments
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
@@ -362,15 +369,12 @@ export const groqProvider: ProviderConfig = {
duration: thisModelTime,
})
// Add to model time
modelTime += thisModelTime
// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
@@ -383,28 +387,44 @@ export const groqProvider: ProviderConfig = {
logger.error('Error in Groq request:', { error })
}
// After all tool processing is complete, if streaming was requested and we have messages, use streaming for the final response
if (request.stream) {
logger.info('Using streaming for final Groq response after tool processing')
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
// This prevents the API from trying to force tool usage again in the final streaming response
const streamingPayload = {
...payload,
messages: currentMessages,
tool_choice: 'auto', // Always use 'auto' for the streaming response after tool calls
tool_choice: 'auto',
stream: true,
}
const streamResponse = await groq.chat.completions.create(streamingPayload)
// Create a StreamingExecution response with all collected data
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
const streamingResult = {
stream: createReadableStreamFromGroqStream(streamResponse),
stream: createReadableStreamFromGroqStream(streamResponse as any, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.prompt_tokens,
completion: tokens.completion + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}
const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
success: true,
output: {
content: '', // Will be filled by the callback
content: '',
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
tokens: {
prompt: tokens.prompt,
@@ -429,12 +449,12 @@ export const groqProvider: ProviderConfig = {
timeSegments: timeSegments,
},
cost: {
total: (tokens.total || 0) * 0.0001,
input: (tokens.prompt || 0) * 0.0001,
output: (tokens.completion || 0) * 0.0001,
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [], // No block logs at provider level
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -444,11 +464,9 @@ export const groqProvider: ProviderConfig = {
},
}
// Return the streaming execution object
return streamingResult as StreamingExecution
}
// Calculate overall timing
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -471,7 +489,6 @@ export const groqProvider: ProviderConfig = {
},
}
} catch (error) {
// Include timing information even for errors
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -481,9 +498,8 @@ export const groqProvider: ProviderConfig = {
duration: totalDuration,
})
// Create a new error with timing information
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
// @ts-ignore
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
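/*
 * A sketch of the forced-tool cycling the provider above performs between iterations:
 * force the next unused tool via tool_choice, then fall back to 'auto' once every
 * forced tool has been called. This is a simplified stand-in for the repo's
 * trackForcedToolUsage helper.
 */
type SketchToolChoice = 'auto' | { type: 'function'; function: { name: string } }

function sketchNextToolChoice(
  forcedTools: string[],
  usedForcedTools: string[]
): SketchToolChoice {
  const remaining = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
  if (remaining.length > 0) {
    // Still have unused forced tools: force the next one explicitly.
    return { type: 'function', function: { name: remaining[0] } }
  }
  // Every forced tool has run at least once: let the model decide from here on.
  return 'auto'
}

// Example: two forced tools, one already used.
sketchNextToolChoice(['search_docs', 'send_email'], ['search_docs'])
// => { type: 'function', function: { name: 'send_email' } }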


@@ -1,23 +1,14 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'
/**
* Helper to wrap Groq streaming into a browser-friendly ReadableStream
* of raw assistant text chunks.
*
* @param groqStream - The Groq streaming response
* @returns A ReadableStream that emits text chunks
* Creates a ReadableStream from a Groq streaming response.
* Uses the shared OpenAI-compatible streaming utility.
*/
export function createReadableStreamFromGroqStream(groqStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of groqStream) {
if (chunk.choices[0]?.delta?.content) {
controller.enqueue(new TextEncoder().encode(chunk.choices[0].delta.content))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
export function createReadableStreamFromGroqStream(
groqStream: AsyncIterable<ChatCompletionChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
return createOpenAICompatibleStream(groqStream, 'Groq', onComplete)
}
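/*
 * A sketch of how a caller can consume the ReadableStream returned above: read
 * Uint8Array chunks, decode them incrementally, and append to the rendered text.
 * The appendToChat callback is a placeholder for whatever the UI actually does.
 */
async function sketchConsumeStream(
  stream: ReadableStream<Uint8Array>,
  appendToChat: (text: string) => void
): Promise<string> {
  const reader = stream.getReader()
  const decoder = new TextDecoder()
  let full = ''
  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    // stream: true keeps multi-byte characters intact across chunk boundaries.
    const text = decoder.decode(value, { stream: true })
    full += text
    appendToChat(text)
  }
  full += decoder.decode() // flush any buffered bytes
  return full
}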


@@ -1,4 +1,5 @@
import OpenAI from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
@@ -11,6 +12,7 @@ import type {
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
@@ -100,8 +102,6 @@ export const mistralProvider: ProviderConfig = {
strict: request.responseFormat.strict !== false,
},
}
logger.info('Added JSON schema response format to request')
}
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
@@ -138,24 +138,32 @@ export const mistralProvider: ProviderConfig = {
if (request.stream && (!tools || tools.length === 0)) {
logger.info('Using streaming response for Mistral request')
const streamResponse = await mistral.chat.completions.create({
const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
stream: true,
stream_options: { include_usage: true },
})
const tokenUsage = {
prompt: 0,
completion: 0,
total: 0,
}
let _streamContent = ''
const streamResponse = await mistral.chat.completions.create(streamingParams)
const streamingResult = {
stream: createReadableStreamFromMistralStream(streamResponse, (content, usage) => {
_streamContent = content
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.prompt_tokens,
completion: usage.completion_tokens,
total: usage.total_tokens,
}
const costResult = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
@@ -172,23 +180,13 @@ export const mistralProvider: ProviderConfig = {
streamEndTime - providerStartTime
}
}
if (usage) {
const newTokens = {
prompt: usage.prompt_tokens || tokenUsage.prompt,
completion: usage.completion_tokens || tokenUsage.completion,
total: usage.total_tokens || tokenUsage.total,
}
streamingResult.execution.output.tokens = newTokens
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: tokenUsage,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
@@ -204,6 +202,7 @@ export const mistralProvider: ProviderConfig = {
},
],
},
cost: { input: 0, output: 0, total: 0 },
},
logs: [],
metadata: {
@@ -275,6 +274,10 @@ export const mistralProvider: ProviderConfig = {
checkForForcedToolUsage(currentResponse, originalToolChoice)
while (iterationCount < MAX_TOOL_ITERATIONS) {
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
@@ -285,28 +288,75 @@ export const mistralProvider: ProviderConfig = {
)
const toolsStartTime = Date.now()
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.function.name
try {
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) return null
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call:', { error, toolName })
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
startTime: startTime,
endTime: endTime,
duration: duration,
})
let resultContent: any
@@ -324,39 +374,17 @@ export const mistralProvider: ProviderConfig = {
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
} catch (error) {
logger.error('Error processing tool call:', {
error,
toolName: toolCall?.function?.name,
})
}
}
const thisToolsTime = Date.now() - toolsStartTime
@@ -417,31 +445,35 @@ export const mistralProvider: ProviderConfig = {
if (request.stream) {
logger.info('Using streaming for final response after tool processing')
const streamingPayload = {
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
messages: currentMessages,
tool_choice: 'auto',
stream: true,
stream_options: { include_usage: true },
}
const streamResponse = await mistral.chat.completions.create(streamingPayload)
let _streamContent = ''
const streamResponse = await mistral.chat.completions.create(streamingParams)
const streamingResult = {
stream: createReadableStreamFromMistralStream(streamResponse, (content, usage) => {
_streamContent = content
streamingResult.execution.output.content = content
if (usage) {
const newTokens = {
prompt: usage.prompt_tokens || tokens.prompt,
completion: usage.completion_tokens || tokens.completion,
total: usage.total_tokens || tokens.total,
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.prompt_tokens,
completion: tokens.completion + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}
streamingResult.execution.output.tokens = newTokens
const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
@@ -471,6 +503,11 @@ export const mistralProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
cost: {
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [],
metadata: {
@@ -516,7 +553,7 @@ export const mistralProvider: ProviderConfig = {
})
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
// @ts-ignore - Adding timing property to error for debugging
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,
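/*
 * A sketch of how the provider above merges usage from the tool-calling iterations with
 * the usage reported by the final stream: totals accumulated before streaming are added
 * to the numbers from the stream's last chunk. Shapes are simplified placeholders.
 */
interface SketchTokens {
  prompt: number
  completion: number
  total: number
}

interface SketchStreamUsage {
  prompt_tokens: number
  completion_tokens: number
  total_tokens: number
}

function sketchMergeUsage(
  accumulated: SketchTokens,
  streamUsage: SketchStreamUsage
): SketchTokens {
  return {
    prompt: accumulated.prompt + streamUsage.prompt_tokens,
    completion: accumulated.completion + streamUsage.completion_tokens,
    total: accumulated.total + streamUsage.total_tokens,
  }
}

// Example: 900 prompt / 120 completion tokens spent on tool iterations, plus the final
// streamed answer's 400 / 250.
sketchMergeUsage(
  { prompt: 900, completion: 120, total: 1020 },
  { prompt_tokens: 400, completion_tokens: 250, total_tokens: 650 }
)
// => { prompt: 1300, completion: 370, total: 1670 }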


@@ -1,39 +1,14 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'
/**
* Creates a ReadableStream from a Mistral AI streaming response
* @param mistralStream - The Mistral AI stream object
* @param onComplete - Optional callback when streaming completes
* @returns A ReadableStream that yields text chunks
* Creates a ReadableStream from a Mistral streaming response.
* Uses the shared OpenAI-compatible streaming utility.
*/
export function createReadableStreamFromMistralStream(
mistralStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of mistralStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
mistralStream: AsyncIterable<ChatCompletionChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
return createOpenAICompatibleStream(mistralStream, 'Mistral', onComplete)
}


@@ -23,13 +23,7 @@ import {
VllmIcon,
xAIIcon,
} from '@/components/icons'
export interface ModelPricing {
input: number // Per 1M tokens
cachedInput?: number // Per 1M tokens (if supported)
output: number // Per 1M tokens
updatedAt: string
}
import type { ModelPricing } from '@/providers/types'
export interface ModelCapabilities {
temperature?: {
@@ -38,6 +32,7 @@ export interface ModelCapabilities {
}
toolUsageControl?: boolean
computerUse?: boolean
nativeStructuredOutputs?: boolean
reasoningEffort?: {
values: string[]
}
@@ -50,7 +45,7 @@ export interface ModelDefinition {
id: string
pricing: ModelPricing
capabilities: ModelCapabilities
contextWindow?: number // Maximum context window in tokens (may be undefined for dynamic providers)
contextWindow?: number
}
export interface ProviderDefinition {
@@ -62,13 +57,9 @@ export interface ProviderDefinition {
modelPatterns?: RegExp[]
icon?: React.ComponentType<{ className?: string }>
capabilities?: ModelCapabilities
// Indicates whether reliable context window information is available for this provider's models
contextInformationAvailable?: boolean
}
/**
* Comprehensive provider definitions, single source of truth
*/
export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
openrouter: {
id: 'openrouter',
@@ -616,6 +607,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 1 },
nativeStructuredOutputs: true,
},
contextWindow: 200000,
},
@@ -629,6 +621,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 1 },
nativeStructuredOutputs: true,
},
contextWindow: 200000,
},
@@ -655,6 +648,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 1 },
nativeStructuredOutputs: true,
},
contextWindow: 200000,
},
@@ -668,6 +662,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
capabilities: {
temperature: { min: 0, max: 1 },
nativeStructuredOutputs: true,
},
contextWindow: 200000,
},
@@ -1619,23 +1614,14 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
},
}
/**
* Get all models for a specific provider
*/
export function getProviderModels(providerId: string): string[] {
return PROVIDER_DEFINITIONS[providerId]?.models.map((m) => m.id) || []
}
/**
* Get the default model for a specific provider
*/
export function getProviderDefaultModel(providerId: string): string {
return PROVIDER_DEFINITIONS[providerId]?.defaultModel || ''
}
/**
* Get pricing information for a specific model
*/
export function getModelPricing(modelId: string): ModelPricing | null {
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase())
@@ -1646,20 +1632,15 @@ export function getModelPricing(modelId: string): ModelPricing | null {
return null
}
/**
* Get capabilities for a specific model
*/
export function getModelCapabilities(modelId: string): ModelCapabilities | null {
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase())
if (model) {
// Merge provider capabilities with model capabilities, model takes precedence
const capabilities: ModelCapabilities = { ...provider.capabilities, ...model.capabilities }
return capabilities
}
}
// If no model found, check for provider-level capabilities for dynamically fetched models
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
if (provider.modelPatterns) {
for (const pattern of provider.modelPatterns) {
@@ -1673,9 +1654,6 @@ export function getModelCapabilities(modelId: string): ModelCapabilities | null
return null
}
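/*
 * A sketch of the capability merge performed in getModelCapabilities above: the object
 * spread lets model-level capabilities override provider-level defaults. The literals
 * are illustrative, not real provider definitions.
 */
const providerCapabilities = { temperature: { min: 0, max: 2 }, toolUsageControl: true }
const modelCapabilities = { temperature: { min: 0, max: 1 }, nativeStructuredOutputs: true }

// Model entries win on conflicts, so the merged temperature range is 0..1.
const merged = { ...providerCapabilities, ...modelCapabilities }
// => { temperature: { min: 0, max: 1 }, toolUsageControl: true, nativeStructuredOutputs: true }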
/**
* Get all models that support temperature
*/
export function getModelsWithTemperatureSupport(): string[] {
const models: string[] = []
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
@@ -1688,9 +1666,6 @@ export function getModelsWithTemperatureSupport(): string[] {
return models
}
/**
* Get all models with temperature range 0-1
*/
export function getModelsWithTempRange01(): string[] {
const models: string[] = []
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
@@ -1703,9 +1678,6 @@ export function getModelsWithTempRange01(): string[] {
return models
}
/**
* Get all models with temperature range 0-2
*/
export function getModelsWithTempRange02(): string[] {
const models: string[] = []
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
@@ -1718,9 +1690,6 @@ export function getModelsWithTempRange02(): string[] {
return models
}
/**
* Get all providers that support tool usage control
*/
export function getProvidersWithToolUsageControl(): string[] {
const providers: string[] = []
for (const [providerId, provider] of Object.entries(PROVIDER_DEFINITIONS)) {
@@ -1731,9 +1700,6 @@ export function getProvidersWithToolUsageControl(): string[] {
return providers
}
/**
* Get all models that are hosted (don't require user API keys)
*/
export function getHostedModels(): string[] {
return [
...getProviderModels('openai'),
@@ -1742,9 +1708,6 @@ export function getHostedModels(): string[] {
]
}
/**
* Get all computer use models
*/
export function getComputerUseModels(): string[] {
const models: string[] = []
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
@@ -1757,32 +1720,20 @@ export function getComputerUseModels(): string[] {
return models
}
/**
* Check if a model supports temperature
*/
export function supportsTemperature(modelId: string): boolean {
const capabilities = getModelCapabilities(modelId)
return !!capabilities?.temperature
}
/**
* Get maximum temperature for a model
*/
export function getMaxTemperature(modelId: string): number | undefined {
const capabilities = getModelCapabilities(modelId)
return capabilities?.temperature?.max
}
/**
* Check if a provider supports tool usage control
*/
export function supportsToolUsageControl(providerId: string): boolean {
return getProvidersWithToolUsageControl().includes(providerId)
}
/**
* Update Ollama models dynamically
*/
export function updateOllamaModels(models: string[]): void {
PROVIDER_DEFINITIONS.ollama.models = models.map((modelId) => ({
id: modelId,
@@ -1795,9 +1746,6 @@ export function updateOllamaModels(models: string[]): void {
}))
}
/**
* Update vLLM models dynamically
*/
export function updateVLLMModels(models: string[]): void {
PROVIDER_DEFINITIONS.vllm.models = models.map((modelId) => ({
id: modelId,
@@ -1810,9 +1758,6 @@ export function updateVLLMModels(models: string[]): void {
}))
}
/**
* Update OpenRouter models dynamically
*/
export function updateOpenRouterModels(models: string[]): void {
PROVIDER_DEFINITIONS.openrouter.models = models.map((modelId) => ({
id: modelId,
@@ -1825,9 +1770,6 @@ export function updateOpenRouterModels(models: string[]): void {
}))
}
/**
* Embedding model pricing - separate from chat models
*/
export const EMBEDDING_MODEL_PRICING: Record<string, ModelPricing> = {
'text-embedding-3-small': {
input: 0.02, // $0.02 per 1M tokens
@@ -1846,16 +1788,10 @@ export const EMBEDDING_MODEL_PRICING: Record<string, ModelPricing> = {
},
}
/**
* Get pricing for embedding models specifically
*/
export function getEmbeddingModelPricing(modelId: string): ModelPricing | null {
return EMBEDDING_MODEL_PRICING[modelId] || null
}
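/*
 * Worked example for the embedding pricing above: rates are dollars per 1M tokens, so
 * embedding 250,000 tokens with text-embedding-3-small at $0.02 / 1M costs
 * 250_000 / 1_000_000 * 0.02 = $0.005. The helper below is a sketch, not repo code.
 */
function sketchEmbeddingCost(modelId: string, tokens: number): number | null {
  const pricing = getEmbeddingModelPricing(modelId)
  if (!pricing) return null
  return (tokens / 1_000_000) * pricing.input
}

sketchEmbeddingCost('text-embedding-3-small', 250_000) // => 0.005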
/**
* Get all models that support reasoning effort
*/
export function getModelsWithReasoningEffort(): string[] {
const models: string[] = []
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
@@ -1882,9 +1818,6 @@ export function getReasoningEffortValuesForModel(modelId: string): string[] | nu
return null
}
/**
* Get all models that support verbosity
*/
export function getModelsWithVerbosity(): string[] {
const models: string[] = []
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
@@ -1910,3 +1843,24 @@ export function getVerbosityValuesForModel(modelId: string): string[] | null {
}
return null
}
/**
* Check if a model supports native structured outputs.
* Handles model IDs with date suffixes (e.g., claude-sonnet-4-5-20250514).
*/
export function supportsNativeStructuredOutputs(modelId: string): boolean {
const normalizedModelId = modelId.toLowerCase()
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
for (const model of provider.models) {
if (model.capabilities.nativeStructuredOutputs) {
const baseModelId = model.id.toLowerCase()
// Check exact match or date-suffixed version (e.g., claude-sonnet-4-5-20250514)
if (normalizedModelId === baseModelId || normalizedModelId.startsWith(`${baseModelId}-`)) {
return true
}
}
}
}
return false
}
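/*
 * Usage sketch for supportsNativeStructuredOutputs above. The ids are examples; the
 * actual result depends on which models carry the nativeStructuredOutputs capability
 * in PROVIDER_DEFINITIONS.
 */
supportsNativeStructuredOutputs('claude-sonnet-4-5') // matches on the exact base id, if flagged
supportsNativeStructuredOutputs('claude-sonnet-4-5-20250514') // matches via the `${baseModelId}-` prefix check
supportsNativeStructuredOutputs('some-unlisted-model') // false: no flagged base id matches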


@@ -1,4 +1,5 @@
import OpenAI from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import { env } from '@/lib/core/config/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
@@ -11,7 +12,7 @@ import type {
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import { prepareToolExecution } from '@/providers/utils'
import { calculateCost, prepareToolExecution } from '@/providers/utils'
import { useProvidersStore } from '@/stores/providers/store'
import { executeTool } from '@/tools'
@@ -23,10 +24,9 @@ export const ollamaProvider: ProviderConfig = {
name: 'Ollama',
description: 'Local Ollama server for LLM inference',
version: '1.0.0',
models: [], // Will be populated dynamically
models: [],
defaultModel: '',
// Initialize the provider by fetching available models
async initialize() {
if (typeof window !== 'undefined') {
logger.info('Skipping Ollama initialization on client side to avoid CORS issues')
@@ -63,16 +63,13 @@ export const ollamaProvider: ProviderConfig = {
stream: !!request.stream,
})
// Create Ollama client using OpenAI-compatible API
const ollama = new OpenAI({
apiKey: 'empty',
baseURL: `${OLLAMA_HOST}/v1`,
})
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
@@ -80,7 +77,6 @@ export const ollamaProvider: ProviderConfig = {
})
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
@@ -88,12 +84,10 @@ export const ollamaProvider: ProviderConfig = {
})
}
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Transform tools to OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
@@ -105,19 +99,15 @@ export const ollamaProvider: ProviderConfig = {
}))
: undefined
// Build the request payload
const payload: any = {
model: request.model,
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add response format for structured output if specified
if (request.responseFormat) {
// Use OpenAI's JSON schema format (Ollama supports this)
payload.response_format = {
type: 'json_schema',
json_schema: {
@@ -130,21 +120,13 @@ export const ollamaProvider: ProviderConfig = {
logger.info('Added JSON schema response format to Ollama request')
}
// Handle tools and tool usage control
// NOTE: Ollama does NOT support the tool_choice parameter beyond basic 'auto' behavior
// According to official documentation, tool_choice is silently ignored
// Ollama only supports basic function calling where the model autonomously decides
if (tools?.length) {
// Filter out tools with usageControl='none'
// Treat 'force' as 'auto' since Ollama doesn't support forced tool selection
const filteredTools = tools.filter((tool) => {
const toolId = tool.function?.name
const toolConfig = request.tools?.find((t) => t.id === toolId)
// Only filter out 'none', treat 'force' as 'auto'
return toolConfig?.usageControl !== 'none'
})
// Check if any tools were forcibly marked
const hasForcedTools = tools.some((tool) => {
const toolId = tool.function?.name
const toolConfig = request.tools?.find((t) => t.id === toolId)
@@ -160,55 +142,58 @@ export const ollamaProvider: ProviderConfig = {
if (filteredTools?.length) {
payload.tools = filteredTools
// Ollama only supports 'auto' behavior - model decides whether to use tools
payload.tool_choice = 'auto'
logger.info('Ollama request configuration:', {
toolCount: filteredTools.length,
toolChoice: 'auto', // Ollama always uses auto
toolChoice: 'auto',
forcedToolsIgnored: hasForcedTools,
model: request.model,
})
}
}
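/*
 * Sketch of the Ollama tool filtering above: tools marked usageControl 'none' are
 * dropped, while 'force' is effectively downgraded to 'auto' because Ollama ignores
 * tool_choice beyond basic auto behaviour. The tool shape here is a reduced placeholder.
 */
interface SketchRequestTool {
  id: string
  usageControl?: 'auto' | 'force' | 'none'
}

function sketchFilterOllamaTools(requestTools: SketchRequestTool[]): {
  tools: SketchRequestTool[]
  forcedToolsIgnored: boolean
} {
  const tools = requestTools.filter((tool) => tool.usageControl !== 'none')
  const forcedToolsIgnored = requestTools.some((tool) => tool.usageControl === 'force')
  return { tools, forcedToolsIgnored }
}

sketchFilterOllamaTools([
  { id: 'search', usageControl: 'force' },
  { id: 'delete_file', usageControl: 'none' },
])
// => { tools: [{ id: 'search', usageControl: 'force' }], forcedToolsIgnored: true }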
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
// Check if we can stream directly (no tools required)
if (request.stream && (!tools || tools.length === 0)) {
logger.info('Using streaming response for Ollama request')
// Create a streaming request with token usage tracking
const streamResponse = await ollama.chat.completions.create({
const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
stream: true,
stream_options: { include_usage: true },
})
// Start collecting token usage from the stream
const tokenUsage = {
prompt: 0,
completion: 0,
total: 0,
}
const streamResponse = await ollama.chat.completions.create(streamingParams)
// Create a StreamingExecution response with a callback to update content and tokens
const streamingResult = {
stream: createReadableStreamFromOllamaStream(streamResponse, (content, usage) => {
// Update the execution data with the final content and token usage
streamingResult.execution.output.content = content
// Clean up the response content
if (content) {
streamingResult.execution.output.content = content
.replace(/```json\n?|\n?```/g, '')
.trim()
}
// Update the timing information with the actual completion time
streamingResult.execution.output.tokens = {
prompt: usage.prompt_tokens,
completion: usage.completion_tokens,
total: usage.total_tokens,
}
const costResult = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
@@ -217,7 +202,6 @@ export const ollamaProvider: ProviderConfig = {
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
// Update the time segment as well
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
streamEndTime
@@ -225,24 +209,13 @@ export const ollamaProvider: ProviderConfig = {
streamEndTime - providerStartTime
}
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
@@ -258,8 +231,9 @@ export const ollamaProvider: ProviderConfig = {
},
],
},
cost: { input: 0, output: 0, total: 0 },
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -268,11 +242,9 @@ export const ollamaProvider: ProviderConfig = {
},
} as StreamingExecution
// Return the streaming execution object
return streamingResult as StreamingExecution
}
// Make the initial API request
const initialCallTime = Date.now()
let currentResponse = await ollama.chat.completions.create(payload)
@@ -280,13 +252,11 @@ export const ollamaProvider: ProviderConfig = {
let content = currentResponse.choices[0]?.message?.content || ''
// Clean up the response content if it exists
if (content) {
content = content.replace(/```json\n?|\n?```/g, '')
content = content.trim()
}
// Collect token information
const tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
@@ -297,11 +267,9 @@ export const ollamaProvider: ProviderConfig = {
const currentMessages = [...allMessages]
let iterationCount = 0
// Track time spent in model vs tools
let modelTime = firstResponseTime
let toolsTime = 0
// Track each model and tool call segment with timestamps
const timeSegments: TimeSegment[] = [
{
type: 'model',
@@ -313,7 +281,10 @@ export const ollamaProvider: ProviderConfig = {
]
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
@@ -323,43 +294,85 @@ export const ollamaProvider: ProviderConfig = {
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
// Track time for tool calls in this batch
const toolsStartTime = Date.now()
// Process each tool call
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
  const toolCallStartTime = Date.now()
  const toolName = toolCall.function.name
  try {
    const toolArgs = JSON.parse(toolCall.function.arguments)
    const tool = request.tools?.find((t) => t.id === toolName)
    if (!tool) return null
    const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
    const result = await executeTool(toolName, executionParams, true)
    const toolCallEndTime = Date.now()
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call:', { error, toolName })
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: startTime,
endTime: endTime,
duration: duration,
})
// Prepare result content for the LLM
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
// Include error information so LLM can respond appropriately
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
@@ -370,62 +383,35 @@ export const ollamaProvider: ProviderConfig = {
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
}
// Calculate tool call time for this iteration
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
// Time the next model call
const nextModelStartTime = Date.now()
// Make the next request
currentResponse = await ollama.chat.completions.create(nextPayload)
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
// Add to time segments
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
@@ -434,18 +420,14 @@ export const ollamaProvider: ProviderConfig = {
duration: thisModelTime,
})
// Add to model time
modelTime += thisModelTime
// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
// Clean up the response content
content = content.replace(/```json\n?|\n?```/g, '')
content = content.trim()
}
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
@@ -455,48 +437,51 @@ export const ollamaProvider: ProviderConfig = {
iterationCount++
}
// After all tool processing complete, if streaming was requested and we have messages, use streaming for the final response
if (request.stream) {
logger.info('Using streaming for final response after tool processing')
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
const streamingParams: ChatCompletionCreateParamsStreaming = {
  ...payload,
  messages: currentMessages,
  tool_choice: 'auto',
  stream: true,
  stream_options: { include_usage: true },
}
const streamResponse = await ollama.chat.completions.create(streamingParams)
// Create the StreamingExecution object with all collected data
const streamingResult = {
stream: createReadableStreamFromOllamaStream(streamResponse, (content, usage) => {
// Update the execution data with the final content and token usage
streamingResult.execution.output.content = content
// Clean up the response content
if (content) {
streamingResult.execution.output.content = content
.replace(/```json\n?|\n?```/g, '')
.trim()
}
streamingResult.execution.output.tokens = {
  prompt: tokens.prompt + usage.prompt_tokens,
  completion: tokens.completion + usage.completion_tokens,
  total: tokens.total + usage.total_tokens,
}
const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: {
prompt: tokens.prompt,
@@ -520,8 +505,13 @@ export const ollamaProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
cost: {
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -530,11 +520,9 @@ export const ollamaProvider: ProviderConfig = {
},
} as StreamingExecution
// Return the streaming execution object
return streamingResult as StreamingExecution
}
// Calculate overall timing
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -557,7 +545,6 @@ export const ollamaProvider: ProviderConfig = {
},
}
} catch (error) {
// Include timing information even for errors
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -567,9 +554,8 @@ export const ollamaProvider: ProviderConfig = {
duration: totalDuration,
})
// Create a new error with timing information
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,

View File

@@ -1,37 +1,14 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'
/**
 * Creates a ReadableStream from an Ollama streaming response.
 * Uses the shared OpenAI-compatible streaming utility.
 */
export function createReadableStreamFromOllamaStream(
  ollamaStream: AsyncIterable<ChatCompletionChunk>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  return createOpenAICompatibleStream(ollamaStream, 'Ollama', onComplete)
}
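// The shared createOpenAICompatibleStream helper lives in @/providers/utils and is not shown
// in this diff. A minimal sketch of its assumed shape, mirroring the per-provider stream
// readers it consolidates (exact implementation may differ):
//
// export function createOpenAICompatibleStream(
//   stream: AsyncIterable<ChatCompletionChunk>,
//   providerName: string,
//   onComplete?: (content: string, usage: CompletionUsage) => void
// ): ReadableStream<Uint8Array> {
//   let fullContent = ''
//   let usage: CompletionUsage | undefined
//   return new ReadableStream({
//     async start(controller) {
//       try {
//         for await (const chunk of stream) {
//           if (chunk.usage) usage = chunk.usage
//           const delta = chunk.choices[0]?.delta?.content || ''
//           if (delta) {
//             fullContent += delta
//             controller.enqueue(new TextEncoder().encode(delta))
//           }
//         }
//         if (onComplete && usage) onComplete(fullContent, usage)
//         controller.close()
//       } catch (error) {
//         controller.error(error)
//       }
//     },
//   })
// }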

View File

@@ -1,4 +1,5 @@
import OpenAI from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
@@ -11,6 +12,7 @@ import type {
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
@@ -19,9 +21,6 @@ import { executeTool } from '@/tools'
const logger = createLogger('OpenAIProvider')
/**
* OpenAI provider configuration
*/
export const openaiProvider: ProviderConfig = {
id: 'openai',
name: 'OpenAI',
@@ -43,13 +42,10 @@ export const openaiProvider: ProviderConfig = {
stream: !!request.stream,
})
// API key is now handled server-side before this function is called
const openai = new OpenAI({ apiKey: request.apiKey })
// Start with an empty array for all messages
const allMessages = []
// Add system prompt if present
if (request.systemPrompt) {
allMessages.push({
role: 'system',
@@ -57,7 +53,6 @@ export const openaiProvider: ProviderConfig = {
})
}
// Add context if present
if (request.context) {
allMessages.push({
role: 'user',
@@ -65,12 +60,10 @@ export const openaiProvider: ProviderConfig = {
})
}
// Add remaining messages
if (request.messages) {
allMessages.push(...request.messages)
}
// Transform tools to OpenAI format if provided
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
@@ -82,23 +75,18 @@ export const openaiProvider: ProviderConfig = {
}))
: undefined
// Build the request payload
const payload: any = {
model: request.model || 'gpt-4o',
messages: allMessages,
}
// Add optional parameters
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
// Add GPT-5 specific parameters
if (request.reasoningEffort !== undefined) payload.reasoning_effort = request.reasoningEffort
if (request.verbosity !== undefined) payload.verbosity = request.verbosity
// Add response format for structured output if specified
if (request.responseFormat) {
// Use OpenAI's JSON schema format
payload.response_format = {
type: 'json_schema',
json_schema: {
@@ -111,7 +99,6 @@ export const openaiProvider: ProviderConfig = {
logger.info('Added JSON schema response format to request')
}
// Handle tools and tool usage control
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
if (tools?.length) {
@@ -139,39 +126,40 @@ export const openaiProvider: ProviderConfig = {
}
}
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
// Check if we can stream directly (no tools required)
if (request.stream && (!tools || tools.length === 0)) {
logger.info('Using streaming response for OpenAI request')
// Create a streaming request with token usage tracking
const streamingParams: ChatCompletionCreateParamsStreaming = {
  ...payload,
  stream: true,
  stream_options: { include_usage: true },
}
const streamResponse = await openai.chat.completions.create(streamingParams)
let _streamContent = ''
// Create a StreamingExecution response with a callback to update content and tokens
const streamingResult = {
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
// Update the execution data with the final content and token usage
_streamContent = content
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.prompt_tokens,
completion: usage.completion_tokens,
total: usage.total_tokens,
}
const costResult = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
// Update the timing information with the actual completion time
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
@@ -180,7 +168,6 @@ export const openaiProvider: ProviderConfig = {
streamingResult.execution.output.providerTiming.duration =
streamEndTime - providerStartTime
// Update the time segment as well
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
streamEndTime
@@ -188,25 +175,13 @@ export const openaiProvider: ProviderConfig = {
streamEndTime - providerStartTime
}
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
@@ -222,9 +197,9 @@ export const openaiProvider: ProviderConfig = {
},
],
},
cost: { input: 0, output: 0, total: 0 },
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -233,21 +208,19 @@ export const openaiProvider: ProviderConfig = {
},
} as StreamingExecution
// Return the streaming execution object with explicit casting
return streamingResult as StreamingExecution
}
// Make the initial API request
const initialCallTime = Date.now()
// Track the original tool_choice for forced tool tracking
const originalToolChoice = payload.tool_choice
// Track forced tools and their usage
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
/**
* Helper function to check for forced tool usage in responses
*/
const checkForForcedToolUsage = (
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }
@@ -271,7 +244,6 @@ export const openaiProvider: ProviderConfig = {
const firstResponseTime = Date.now() - initialCallTime
let content = currentResponse.choices[0]?.message?.content || ''
// Collect token information but don't calculate costs - that will be done in logger.ts
const tokens = {
prompt: currentResponse.usage?.prompt_tokens || 0,
completion: currentResponse.usage?.completion_tokens || 0,
@@ -282,14 +254,11 @@ export const openaiProvider: ProviderConfig = {
const currentMessages = [...allMessages]
let iterationCount = 0
// Track time spent in model vs tools
let modelTime = firstResponseTime
let toolsTime = 0
// Track if a forced tool has been used
let hasUsedForcedTool = false
// Track each model and tool call segment with timestamps
const timeSegments: TimeSegment[] = [
{
type: 'model',
@@ -300,57 +269,103 @@ export const openaiProvider: ProviderConfig = {
},
]
// Check if a forced tool was used in the first response
checkForForcedToolUsage(currentResponse, originalToolChoice)
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
logger.info(
`Processing ${toolCallsInResponse.length} tool calls in parallel (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
)
// Track time for tool calls in this batch
const toolsStartTime = Date.now()
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
  const toolCallStartTime = Date.now()
  const toolName = toolCall.function.name
  try {
    const toolArgs = JSON.parse(toolCall.function.arguments)
    const tool = request.tools?.find((t) => t.id === toolName)
    if (!tool) {
      return null
    }
    const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
    const result = await executeTool(toolName, executionParams, true)
    const toolCallEndTime = Date.now()
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call:', { error, toolName })
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
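// Design note: each mapped promise catches its own errors and resolves to a fallback result,
// so Promise.allSettled is largely defensive here; it guarantees that one slow or throwing
// tool call cannot reject the whole batch while the others execute in parallel.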
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: startTime,
endTime: endTime,
duration: duration,
})
// Prepare result content for the LLM
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
// Include error information so LLM can respond appropriately
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
@@ -361,84 +376,52 @@ export const openaiProvider: ProviderConfig = {
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
}
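// For reference, one parallel round-trip leaves the conversation shaped roughly like this
// (ids, tool names, and arguments are illustrative only):
//
// { role: 'assistant', content: null, tool_calls: [
//     { id: 'call_1', type: 'function', function: { name: 'search', arguments: '{"q":"..."}' } },
//     { id: 'call_2', type: 'function', function: { name: 'fetch_page', arguments: '{"url":"..."}' } },
//   ] },
// { role: 'tool', tool_call_id: 'call_1', content: '{"results":["..."]}' },
// { role: 'tool', tool_call_id: 'call_2', content: '{"error":true,"message":"Tool execution failed"}' },
//
// The Chat Completions API expects every tool_call id in the assistant message to be answered
// by a matching tool message before the next model call.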
// Calculate tool call time for this iteration
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
// Make the next request with updated messages
const nextPayload = {
...payload,
messages: currentMessages,
}
// Update tool_choice based on which forced tools have been used
if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) {
// If we have remaining forced tools, get the next one to force
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
if (remainingTools.length > 0) {
// Force the next tool
nextPayload.tool_choice = {
type: 'function',
function: { name: remainingTools[0] },
}
logger.info(`Forcing next tool: ${remainingTools[0]}`)
} else {
// All forced tools have been used, switch to auto
nextPayload.tool_choice = 'auto'
logger.info('All forced tools have been used, switching to auto tool_choice')
}
}
// Time the next model call
const nextModelStartTime = Date.now()
// Make the next request
currentResponse = await openai.chat.completions.create(nextPayload)
// Check if any forced tools were used in this response
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
// Add to time segments
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
@@ -447,15 +430,8 @@ export const openaiProvider: ProviderConfig = {
duration: thisModelTime,
})
// Add to model time
modelTime += thisModelTime
// Update content if we have a text response
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
// Update token counts
if (currentResponse.usage) {
tokens.prompt += currentResponse.usage.prompt_tokens || 0
tokens.completion += currentResponse.usage.completion_tokens || 0
@@ -465,46 +441,44 @@ export const openaiProvider: ProviderConfig = {
iterationCount++
}
// After all tool processing complete, if streaming was requested, use streaming for the final response
if (request.stream) {
logger.info('Using streaming for final response after tool processing')
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
// This prevents OpenAI API from trying to force tool usage again in the final streaming response
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
const streamingParams: ChatCompletionCreateParamsStreaming = {
  ...payload,
  messages: currentMessages,
  tool_choice: 'auto',
  stream: true,
  stream_options: { include_usage: true },
}
let _streamContent = ''
const streamResponse = await openai.chat.completions.create(streamingParams)
const streamingResult = {
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
// Update the execution data with the final content and token usage
_streamContent = content
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
  prompt: tokens.prompt + usage.prompt_tokens,
  completion: tokens.completion + usage.completion_tokens,
  total: tokens.total + usage.total_tokens,
}
const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
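// Total cost is the tool-loop usage priced up front (accumulatedCost) plus the final streamed
// response (streamCost). Hypothetical example at $2.50/M input and $10/M output: 2,000 prompt
// tokens and 700 completion tokens overall cost (2000 / 1e6) * 2.50 + (700 / 1e6) * 10 ≈ $0.012.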
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: {
prompt: tokens.prompt,
@@ -528,9 +502,13 @@ export const openaiProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
cost: {
  input: accumulatedCost.input,
  output: accumulatedCost.output,
  total: accumulatedCost.total,
},
},
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -539,11 +517,9 @@ export const openaiProvider: ProviderConfig = {
},
} as StreamingExecution
// Return the streaming execution object with explicit casting
return streamingResult as StreamingExecution
}
// Calculate overall timing
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -564,10 +540,8 @@ export const openaiProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
// We're not calculating cost here as it will be handled in logger.ts
}
} catch (error) {
// Include timing information even for errors
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -577,7 +551,6 @@ export const openaiProvider: ProviderConfig = {
duration: totalDuration,
})
// Create a new error with timing information
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
enhancedError.timing = {

View File

@@ -1,37 +1,15 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import type { Stream } from 'openai/streaming'
import { createOpenAICompatibleStream } from '@/providers/utils'
/**
 * Creates a ReadableStream from an OpenAI streaming response.
 * Uses the shared OpenAI-compatible streaming utility.
 */
export function createReadableStreamFromOpenAIStream(
  openaiStream: Stream<ChatCompletionChunk>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  return createOpenAICompatibleStream(openaiStream, 'OpenAI', onComplete)
}

View File

@@ -1,4 +1,5 @@
import OpenAI from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
@@ -6,6 +7,7 @@ import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
import {
checkForForcedToolUsage,
createReadableStreamFromOpenAIStream,
supportsNativeStructuredOutputs,
} from '@/providers/openrouter/utils'
import type {
ProviderConfig,
@@ -13,11 +15,49 @@ import type {
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import {
  calculateCost,
  generateSchemaInstructions,
  prepareToolExecution,
  prepareToolsWithUsageControl,
} from '@/providers/utils'
import { executeTool } from '@/tools'
const logger = createLogger('OpenRouterProvider')
/**
* Applies structured output configuration to a payload based on model capabilities.
 * Uses json_schema with require_parameters for models that support it; otherwise falls
 * back to json_object mode with schema instructions appended to the prompt.
*/
async function applyResponseFormat(
targetPayload: any,
messages: any[],
responseFormat: any,
model: string
): Promise<any[]> {
const useNative = await supportsNativeStructuredOutputs(model)
if (useNative) {
logger.info('Using native structured outputs for OpenRouter model', { model })
targetPayload.response_format = {
type: 'json_schema',
json_schema: {
name: responseFormat.name || 'response_schema',
schema: responseFormat.schema || responseFormat,
strict: responseFormat.strict !== false,
},
}
targetPayload.provider = { ...targetPayload.provider, require_parameters: true }
return messages
}
logger.info('Using json_object mode with prompt instructions for OpenRouter model', { model })
const schema = responseFormat.schema || responseFormat
const schemaInstructions = generateSchemaInstructions(schema, responseFormat.name)
targetPayload.response_format = { type: 'json_object' }
return [...messages, { role: 'user', content: schemaInstructions }]
}
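// Illustrative outcomes of applyResponseFormat (schema contents are examples only):
//
// Native path (model advertises structured_outputs support):
//   targetPayload.response_format = {
//     type: 'json_schema',
//     json_schema: { name: 'response_schema', schema, strict: true },
//   }
//   targetPayload.provider = { ...targetPayload.provider, require_parameters: true }
//
// Fallback path:
//   targetPayload.response_format = { type: 'json_object' }
//   returned messages = [...messages, { role: 'user', content: generateSchemaInstructions(schema) }]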
export const openRouterProvider: ProviderConfig = {
id: 'openrouter',
name: 'OpenRouter',
@@ -83,17 +123,6 @@ export const openRouterProvider: ProviderConfig = {
if (request.temperature !== undefined) payload.temperature = request.temperature
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
if (request.responseFormat) {
payload.response_format = {
type: 'json_schema',
json_schema: {
name: request.responseFormat.name || 'response_schema',
schema: request.responseFormat.schema || request.responseFormat,
strict: request.responseFormat.strict !== false,
},
}
}
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
let hasActiveTools = false
if (tools?.length) {
@@ -110,26 +139,43 @@ export const openRouterProvider: ProviderConfig = {
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
if (request.responseFormat && !hasActiveTools) {
payload.messages = await applyResponseFormat(
payload,
payload.messages,
request.responseFormat,
requestedModel
)
}
if (request.stream && (!tools || tools.length === 0 || !hasActiveTools)) {
const streamingParams: ChatCompletionCreateParamsStreaming = {
  ...payload,
  stream: true,
  stream_options: { include_usage: true },
}
const streamResponse = await client.chat.completions.create(streamingParams)
const streamingResult = {
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.prompt_tokens,
completion: usage.completion_tokens,
total: usage.total_tokens,
}
const costResult = calculateCost(
requestedModel,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
const end = Date.now()
const endISO = new Date(end).toISOString()
if (streamingResult.execution.output.providerTiming) {
@@ -147,7 +193,7 @@ export const openRouterProvider: ProviderConfig = {
output: {
content: '',
model: requestedModel,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
@@ -163,6 +209,7 @@ export const openRouterProvider: ProviderConfig = {
},
],
},
cost: { input: 0, output: 0, total: 0 },
},
logs: [],
metadata: {
@@ -217,31 +264,90 @@ export const openRouterProvider: ProviderConfig = {
usedForcedTools = forcedToolResult.usedForcedTools
while (iterationCount < MAX_TOOL_ITERATIONS) {
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
const toolsStartTime = Date.now()
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
  const toolCallStartTime = Date.now()
  const toolName = toolCall.function.name
  try {
    const toolArgs = JSON.parse(toolCall.function.arguments)
    const tool = request.tools?.find((t) => t.id === toolName)
    if (!tool) return null
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call (OpenRouter):', {
error: error instanceof Error ? error.message : String(error),
toolName,
})
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: startTime,
endTime: endTime,
duration: duration,
})
let resultContent: any
@@ -259,45 +365,24 @@ export const openRouterProvider: ProviderConfig = {
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
}
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
const nextPayload = {
...payload,
messages: currentMessages,
}
@@ -343,26 +428,52 @@ export const openRouterProvider: ProviderConfig = {
}
if (request.stream) {
const accumulatedCost = calculateCost(requestedModel, tokens.prompt, tokens.completion)
const streamingParams: ChatCompletionCreateParamsStreaming & { provider?: any } = {
  model: payload.model,
  messages: [...currentMessages],
  stream: true,
  stream_options: { include_usage: true },
}
if (payload.temperature !== undefined) {
streamingParams.temperature = payload.temperature
}
if (payload.max_tokens !== undefined) {
streamingParams.max_tokens = payload.max_tokens
}
if (request.responseFormat) {
;(streamingParams as any).messages = await applyResponseFormat(
streamingParams as any,
streamingParams.messages,
request.responseFormat,
requestedModel
)
}
const streamResponse = await client.chat.completions.create(streamingParams)
const streamingResult = {
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.prompt_tokens,
completion: tokens.completion + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}
const streamCost = calculateCost(
requestedModel,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
success: true,
@@ -387,6 +498,11 @@ export const openRouterProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
cost: {
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [],
metadata: {
@@ -400,6 +516,49 @@ export const openRouterProvider: ProviderConfig = {
return streamingResult as StreamingExecution
}
if (request.responseFormat && hasActiveTools && toolCalls.length > 0) {
const finalPayload: any = {
model: payload.model,
messages: [...currentMessages],
}
if (payload.temperature !== undefined) {
finalPayload.temperature = payload.temperature
}
if (payload.max_tokens !== undefined) {
finalPayload.max_tokens = payload.max_tokens
}
finalPayload.messages = await applyResponseFormat(
finalPayload,
finalPayload.messages,
request.responseFormat,
requestedModel
)
const finalStartTime = Date.now()
const finalResponse = await client.chat.completions.create(finalPayload)
const finalEndTime = Date.now()
const finalDuration = finalEndTime - finalStartTime
timeSegments.push({
type: 'model',
name: 'Final structured response',
startTime: finalStartTime,
endTime: finalEndTime,
duration: finalDuration,
})
modelTime += finalDuration
if (finalResponse.choices[0]?.message?.content) {
content = finalResponse.choices[0].message.content
}
if (finalResponse.usage) {
tokens.prompt += finalResponse.usage.prompt_tokens || 0
tokens.completion += finalResponse.usage.completion_tokens || 0
tokens.total += finalResponse.usage.total_tokens || 0
}
}
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -425,10 +584,21 @@ export const openRouterProvider: ProviderConfig = {
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
const errorDetails: Record<string, any> = {
  error: error instanceof Error ? error.message : String(error),
  duration: totalDuration,
}
if (error && typeof error === 'object') {
const err = error as any
if (err.status) errorDetails.status = err.status
if (err.code) errorDetails.code = err.code
if (err.type) errorDetails.type = err.type
if (err.error?.message) errorDetails.providerMessage = err.error.message
if (err.error?.metadata) errorDetails.metadata = err.error.metadata
}
logger.error('Error in OpenRouter request:', errorDetails)
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore
enhancedError.timing = {

View File

@@ -1,55 +1,107 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createLogger } from '@/lib/logs/console/logger'
import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils'
const logger = createLogger('OpenRouterUtils')
interface OpenRouterModelData {
id: string
supported_parameters?: string[]
}
interface ModelCapabilities {
supportsStructuredOutputs: boolean
supportsTools: boolean
}
let modelCapabilitiesCache: Map<string, ModelCapabilities> | null = null
let cacheTimestamp = 0
const CACHE_TTL_MS = 5 * 60 * 1000 // 5 minutes
/**
 * Fetches and caches OpenRouter model capabilities from their API.
 */
async function fetchModelCapabilities(): Promise<Map<string, ModelCapabilities>> {
  try {
    const response = await fetch('https://openrouter.ai/api/v1/models', {
      headers: { 'Content-Type': 'application/json' },
    })
if (!response.ok) {
logger.warn('Failed to fetch OpenRouter model capabilities', {
status: response.status,
})
return new Map()
}
const data = await response.json()
const capabilities = new Map<string, ModelCapabilities>()
for (const model of (data.data ?? []) as OpenRouterModelData[]) {
const supportedParams = model.supported_parameters ?? []
capabilities.set(model.id, {
supportsStructuredOutputs: supportedParams.includes('structured_outputs'),
supportsTools: supportedParams.includes('tools'),
})
}
logger.info('Cached OpenRouter model capabilities', {
modelCount: capabilities.size,
withStructuredOutputs: Array.from(capabilities.values()).filter(
(c) => c.supportsStructuredOutputs
).length,
})
return capabilities
} catch (error) {
logger.error('Error fetching OpenRouter model capabilities', {
error: error instanceof Error ? error.message : String(error),
})
return new Map()
}
}
/**
 * Gets capabilities for a specific OpenRouter model.
 * Fetches from API if cache is stale or empty.
 */
export async function getOpenRouterModelCapabilities(
modelId: string
): Promise<ModelCapabilities | null> {
const now = Date.now()
if (!modelCapabilitiesCache || now - cacheTimestamp > CACHE_TTL_MS) {
modelCapabilitiesCache = await fetchModelCapabilities()
cacheTimestamp = now
}
const normalizedId = modelId.replace(/^openrouter\//, '')
return modelCapabilitiesCache.get(normalizedId) ?? null
}
/**
* Checks if a model supports native structured outputs (json_schema).
*/
export async function supportsNativeStructuredOutputs(modelId: string): Promise<boolean> {
const capabilities = await getOpenRouterModelCapabilities(modelId)
return capabilities?.supportsStructuredOutputs ?? false
}
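// Example usage (the model id is illustrative):
//
// if (await supportsNativeStructuredOutputs('openai/gpt-4o')) {
//   // safe to send response_format: { type: 'json_schema', ... } with require_parameters
// } else {
//   // fall back to { type: 'json_object' } plus schema instructions in the prompt
// }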
/**
* Creates a ReadableStream from an OpenRouter streaming response.
* Uses the shared OpenAI-compatible streaming utility.
*/
export function createReadableStreamFromOpenAIStream(
openaiStream: AsyncIterable<ChatCompletionChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
return createOpenAICompatibleStream(openaiStream, 'OpenRouter', onComplete)
}
/**
* Checks if a forced tool was used in an OpenRouter response.
* Uses the shared OpenAI-compatible forced tool usage helper.
*/
export function checkForForcedToolUsage(
response: any,
@@ -57,22 +109,11 @@ export function checkForForcedToolUsage(
forcedTools: string[],
usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
  return checkForForcedToolUsageOpenAI(
    response,
    toolChoice,
    logger,
    'OpenRouter',
    forcedTools,
    usedForcedTools
  )
}

View File

@@ -19,15 +19,12 @@ export type ProviderId =
* Model pricing information per million tokens
*/
export interface ModelPricing {
input: number // Per 1M tokens
cachedInput?: number // Per 1M tokens (if supported)
output: number // Per 1M tokens
updatedAt: string // Last updated date
}
/**
* Map of model IDs to their pricing information
*/
export type ModelPricingMap = Record<string, ModelPricing>
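// Example entry and the per-1M-token cost arithmetic this interface implies
// (rates are illustrative, not actual pricing):
//
// const pricing: ModelPricing = { input: 2.5, output: 10, updatedAt: '2025-01-01' }
// const inputCost = (promptTokens / 1_000_000) * pricing.input
// const outputCost = (completionTokens / 1_000_000) * pricing.output
// const totalCost = inputCost + outputCost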
export interface TokenInfo {
@@ -84,20 +81,20 @@ export interface ProviderResponse {
toolCalls?: FunctionCallResponse[]
toolResults?: any[]
timing?: {
startTime: string
endTime: string
duration: number
modelTime?: number
toolsTime?: number
firstResponseTime?: number
iterations?: number
timeSegments?: TimeSegment[]
}
cost?: {
input: number
output: number
total: number
pricing: ModelPricing
}
}
@@ -150,27 +147,23 @@ export interface ProviderRequest {
strict?: boolean
}
local_execution?: boolean
workflowId?: string
workspaceId?: string
chatId?: string
userId?: string
stream?: boolean
streamToolCalls?: boolean
environmentVariables?: Record<string, string>
workflowVariables?: Record<string, any>
blockData?: Record<string, any>
blockNameMapping?: Record<string, string>
isCopilotRequest?: boolean
azureEndpoint?: string
azureApiVersion?: string
// Vertex AI specific parameters
vertexProject?: string
vertexLocation?: string
// GPT-5 specific parameters
reasoningEffort?: string
verbosity?: string
}
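// A minimal request, for reference; only fields exercised elsewhere in this change are shown
// and all values are illustrative:
//
// const request: ProviderRequest = {
//   model: 'gpt-4o',
//   systemPrompt: 'You are a helpful assistant.',
//   messages: [{ role: 'user', content: 'Summarize this workflow.' }],
//   apiKey: 'sk-...',
//   stream: true,
// }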
// Map of provider IDs to their configurations
export const providers: Record<string, ProviderConfig> = {}

View File

@@ -1,6 +1,8 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { getEnv, isTruthy } from '@/lib/core/config/env'
import { isHosted } from '@/lib/core/config/feature-flags'
import { createLogger, type Logger } from '@/lib/logs/console/logger'
import { anthropicProvider } from '@/providers/anthropic'
import { azureOpenAIProvider } from '@/providers/azure-openai'
import { cerebrasProvider } from '@/providers/cerebras'
@@ -40,9 +42,6 @@ import { useProvidersStore } from '@/stores/providers/store'
const logger = createLogger('ProviderUtils')
/**
* Provider configurations - built from the comprehensive definitions
*/
export const providers: Record<
ProviderId,
ProviderConfig & {
@@ -213,7 +212,6 @@ export function getProviderFromModel(model: string): ProviderId {
}
export function getProvider(id: string): ProviderConfig | undefined {
// Handle both formats: 'openai' and 'openai/chat'
const providerId = id.split('/')[0] as ProviderId
return providers[providerId]
}
@@ -273,14 +271,27 @@ export function filterBlacklistedModels(models: string[]): string[] {
return models.filter((model) => !isModelBlacklisted(model))
}
/**
* Get provider icon for a given model
*/
export function getProviderIcon(model: string): React.ComponentType<{ className?: string }> | null {
const providerId = getProviderFromModel(model)
return PROVIDER_DEFINITIONS[providerId]?.icon || null
}
/**
* Generates prompt instructions for structured JSON output from a JSON schema.
* Used as a fallback when native structured outputs are not supported.
*/
export function generateSchemaInstructions(schema: any, schemaName?: string): string {
const name = schemaName || 'response'
return `IMPORTANT: You must respond with a valid JSON object that conforms to the following schema.
Do not include any text before or after the JSON object. Only output the JSON.
Schema name: ${name}
JSON Schema:
${JSON.stringify(schema, null, 2)}
Your response must be valid JSON that exactly matches this schema structure.`
}
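// For example, generateSchemaInstructions({ type: 'object', properties: { answer: { type: 'string' } } }, 'qa_response')
// produces (abridged):
//
// IMPORTANT: You must respond with a valid JSON object that conforms to the following schema.
// Do not include any text before or after the JSON object. Only output the JSON.
// Schema name: qa_response
// JSON Schema:
// { "type": "object", "properties": { "answer": { "type": "string" } } }
// Your response must be valid JSON that exactly matches this schema structure.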
export function generateStructuredOutputInstructions(responseFormat: any): string {
if (!responseFormat) return ''
@@ -479,7 +490,6 @@ export async function transformBlockTool(
const llmSchema = await createLLMToolSchema(toolConfig, userProvidedParams)
// Create unique tool ID by appending resource ID for multi-instance tools
let uniqueToolId = toolConfig.id
if (toolId === 'workflow_executor' && userProvidedParams.workflowId) {
uniqueToolId = `${toolConfig.id}_${userProvidedParams.workflowId}`
@@ -554,9 +564,6 @@ export function calculateCost(
}
}
/**
* Get pricing information for a specific model (including embedding models)
*/
export function getModelPricing(modelId: string): any {
const embeddingPricing = getEmbeddingModelPricing(modelId)
if (embeddingPricing) {
@@ -615,29 +622,25 @@ export function shouldBillModelUsage(model: string): boolean {
* For use server-side only
*/
export function getApiKey(provider: string, model: string, userProvidedKey?: string): string {
// If user provided a key, use it as a fallback
const hasUserKey = !!userProvidedKey
// Ollama and vLLM models don't require API keys
const isOllamaModel =
provider === 'ollama' || useProvidersStore.getState().providers.ollama.models.includes(model)
if (isOllamaModel) {
return 'empty'
}
const isVllmModel =
provider === 'vllm' || useProvidersStore.getState().providers.vllm.models.includes(model)
if (isVllmModel) {
return userProvidedKey || 'empty' // vLLM uses 'empty' as a placeholder if no key provided
return userProvidedKey || 'empty'
}
// Use server key rotation for all OpenAI models, Anthropic's Claude models, and Google's Gemini models on the hosted platform
const isOpenAIModel = provider === 'openai'
const isClaudeModel = provider === 'anthropic'
const isGeminiModel = provider === 'google'
if (isHosted && (isOpenAIModel || isClaudeModel || isGeminiModel)) {
// Only use server key if model is explicitly in our hosted list
const hostedModels = getHostedModels()
const isModelHosted = hostedModels.some((m) => m.toLowerCase() === model.toLowerCase())
@@ -656,7 +659,6 @@ export function getApiKey(provider: string, model: string, userProvidedKey?: str
}
}
// For all other cases, require user-provided key
if (!hasUserKey) {
throw new Error(`API key is required for ${provider} ${model}`)
}
@@ -688,16 +690,14 @@ export function prepareToolsWithUsageControl(
| { type: 'any'; any: { model: string; name: string } }
| undefined
toolConfig?: {
// Add toolConfig for Google's format
functionCallingConfig: {
mode: 'AUTO' | 'ANY' | 'NONE'
allowedFunctionNames?: string[]
}
}
hasFilteredTools: boolean
forcedTools: string[] // Return all forced tool IDs
forcedTools: string[]
} {
// If no tools, return early
if (!tools || tools.length === 0) {
return {
tools: undefined,
@@ -707,14 +707,12 @@ export function prepareToolsWithUsageControl(
}
}
// Filter out tools marked with usageControl='none'
const filteredTools = tools.filter((tool) => {
const toolId = tool.function?.name || tool.name
const toolConfig = providerTools?.find((t) => t.id === toolId)
return toolConfig?.usageControl !== 'none'
})
// Check if any tools were filtered out
const hasFilteredTools = filteredTools.length < tools.length
if (hasFilteredTools) {
logger.info(
@@ -722,7 +720,6 @@ export function prepareToolsWithUsageControl(
)
}
// If all tools were filtered out, return empty
if (filteredTools.length === 0) {
logger.info('All tools were filtered out due to usageControl="none"')
return {
@@ -733,11 +730,9 @@ export function prepareToolsWithUsageControl(
}
}
// Get all tools that should be forced
const forcedTools = providerTools?.filter((tool) => tool.usageControl === 'force') || []
const forcedToolIds = forcedTools.map((tool) => tool.id)
// Determine tool_choice setting
let toolChoice:
| 'auto'
| 'none'
@@ -745,7 +740,6 @@ export function prepareToolsWithUsageControl(
| { type: 'tool'; name: string }
| { type: 'any'; any: { model: string; name: string } } = 'auto'
// For Google, we'll use a separate toolConfig object
let toolConfig:
| {
functionCallingConfig: {
@@ -756,30 +750,22 @@ export function prepareToolsWithUsageControl(
| undefined
if (forcedTools.length > 0) {
// Force the first tool that has usageControl='force'
const forcedTool = forcedTools[0]
// Adjust format based on provider
if (provider === 'anthropic') {
toolChoice = {
type: 'tool',
name: forcedTool.id,
}
} else if (provider === 'google') {
// Google Gemini format uses a separate toolConfig object
toolConfig = {
functionCallingConfig: {
mode: 'ANY',
allowedFunctionNames:
forcedTools.length === 1
? [forcedTool.id] // If only one tool, specify just that one
: forcedToolIds, // If multiple tools, include all of them
allowedFunctionNames: forcedTools.length === 1 ? [forcedTool.id] : forcedToolIds,
},
}
// Keep toolChoice as 'auto' since we use toolConfig instead
toolChoice = 'auto'
} else {
// Default OpenAI format
toolChoice = {
type: 'function',
function: { name: forcedTool.id },
@@ -794,7 +780,6 @@ export function prepareToolsWithUsageControl(
)
}
} else {
// Default to auto if no forced tools
toolChoice = 'auto'
if (provider === 'google') {
toolConfig = { functionCallingConfig: { mode: 'AUTO' } }
@@ -845,7 +830,6 @@ export function trackForcedToolUsage(
}
}
} {
// Default to keeping the original tool_choice
let hasUsedForcedTool = false
let nextToolChoice = originalToolChoice
let nextToolConfig:
@@ -859,13 +843,10 @@ export function trackForcedToolUsage(
const updatedUsedForcedTools = [...usedForcedTools]
// Special handling for Google format
const isGoogleFormat = provider === 'google'
// Get the name of the current forced tool(s)
let forcedToolNames: string[] = []
if (isGoogleFormat && originalToolChoice?.functionCallingConfig?.allowedFunctionNames) {
// For Google format
forcedToolNames = originalToolChoice.functionCallingConfig.allowedFunctionNames
} else if (
typeof originalToolChoice === 'object' &&
@@ -873,7 +854,6 @@ export function trackForcedToolUsage(
(originalToolChoice?.type === 'tool' && originalToolChoice?.name) ||
(originalToolChoice?.type === 'any' && originalToolChoice?.any?.name))
) {
// For other providers
forcedToolNames = [
originalToolChoice?.function?.name ||
originalToolChoice?.name ||
@@ -881,27 +861,20 @@ export function trackForcedToolUsage(
].filter(Boolean)
}
// If we're forcing specific tools and we have tool calls in the response
if (forcedToolNames.length > 0 && toolCallsResponse && toolCallsResponse.length > 0) {
// Check if any of the tool calls used the forced tools
const toolNames = toolCallsResponse.map((tc) => tc.function?.name || tc.name || tc.id)
// Find any forced tools that were used
const usedTools = forcedToolNames.filter((toolName) => toolNames.includes(toolName))
if (usedTools.length > 0) {
// At least one forced tool was used
hasUsedForcedTool = true
updatedUsedForcedTools.push(...usedTools)
// Find the next tools to force that haven't been used yet
const remainingTools = forcedTools.filter((tool) => !updatedUsedForcedTools.includes(tool))
if (remainingTools.length > 0) {
// There are still forced tools to use
const nextToolToForce = remainingTools[0]
// Format based on provider
if (provider === 'anthropic') {
nextToolChoice = {
type: 'tool',
@@ -912,13 +885,10 @@ export function trackForcedToolUsage(
functionCallingConfig: {
mode: 'ANY',
allowedFunctionNames:
remainingTools.length === 1
? [nextToolToForce] // If only one tool left, specify just that one
: remainingTools, // If multiple tools, include all remaining
remainingTools.length === 1 ? [nextToolToForce] : remainingTools,
},
}
} else {
// Default OpenAI format
nextToolChoice = {
type: 'function',
function: { name: nextToolToForce },
@@ -929,9 +899,7 @@ export function trackForcedToolUsage(
`Forced tool(s) ${usedTools.join(', ')} used, switching to next forced tool(s): ${remainingTools.join(', ')}`
)
} else {
// All forced tools have been used, switch to auto mode
if (provider === 'anthropic') {
// Anthropic: return null to signal the parameter should be deleted/omitted
nextToolChoice = null
} else if (provider === 'google') {
nextToolConfig = { functionCallingConfig: { mode: 'AUTO' } }
@@ -963,9 +931,6 @@ export const MODELS_WITH_REASONING_EFFORT = getModelsWithReasoningEffort()
export const MODELS_WITH_VERBOSITY = getModelsWithVerbosity()
export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = getProvidersWithToolUsageControl()
/**
* Check if a model supports temperature parameter
*/
export function supportsTemperature(model: string): boolean {
return supportsTemperatureFromDefinitions(model)
}
@@ -978,9 +943,6 @@ export function getMaxTemperature(model: string): number | undefined {
return getMaxTempFromDefinitions(model)
}
/**
* Check if a provider supports tool usage control
*/
export function supportsToolUsageControl(provider: string): boolean {
return supportsToolUsageControlFromDefinitions(provider)
}
@@ -1021,8 +983,6 @@ export function prepareToolExecution(
toolParams: Record<string, any>
executionParams: Record<string, any>
} {
// Filter out empty/null/undefined values from user params
// so that cleared fields don't override LLM-generated values
const filteredUserParams: Record<string, any> = {}
if (tool.params) {
for (const [key, value] of Object.entries(tool.params)) {
@@ -1032,13 +992,11 @@ export function prepareToolExecution(
}
}
// User-provided params take precedence over LLM-generated params
const toolParams = {
...llmArgs,
...filteredUserParams,
}
// Add system parameters for execution
const executionParams = {
...toolParams,
...(request.workflowId
@@ -1055,9 +1013,107 @@ export function prepareToolExecution(
...(request.workflowVariables ? { workflowVariables: request.workflowVariables } : {}),
...(request.blockData ? { blockData: request.blockData } : {}),
...(request.blockNameMapping ? { blockNameMapping: request.blockNameMapping } : {}),
// Pass tool schema for MCP tools to skip discovery
...(tool.parameters ? { _toolSchema: tool.parameters } : {}),
}
return { toolParams, executionParams }
}
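A small illustration (not from the commit) of the precedence the code above implements: empty user-provided fields are dropped so they do not clobber LLM-generated arguments, while explicit user values win. All values below are invented and the filter is an approximation of the described behavior.

```ts
const llmArgs = { query: 'latest release notes', limit: 5 }
const userParams = { query: '', limit: 10 } // user cleared the query but set an explicit limit

// Approximation of the filtering + merge order used in prepareToolExecution.
const filteredUserParams = Object.fromEntries(
  Object.entries(userParams).filter(([, v]) => v !== '' && v !== null && v !== undefined)
)
const toolParams = { ...llmArgs, ...filteredUserParams }
// toolParams => { query: 'latest release notes', limit: 10 }
```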
/**
* Creates a ReadableStream from an OpenAI-compatible streaming response.
* This is a shared utility used by all OpenAI-compatible providers:
* OpenAI, Groq, DeepSeek, xAI, OpenRouter, Mistral, Ollama, vLLM, Azure OpenAI, Cerebras
*
* @param stream - The async iterable stream from the provider
* @param providerName - Name of the provider for logging purposes
* @param onComplete - Optional callback called when stream completes with full content and usage
* @returns A ReadableStream that can be used for streaming responses
*/
export function createOpenAICompatibleStream(
stream: AsyncIterable<ChatCompletionChunk>,
providerName: string,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
const streamLogger = createLogger(`${providerName}Utils`)
let fullContent = ''
let promptTokens = 0
let completionTokens = 0
let totalTokens = 0
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of stream) {
if (chunk.usage) {
promptTokens = chunk.usage.prompt_tokens ?? 0
completionTokens = chunk.usage.completion_tokens ?? 0
totalTokens = chunk.usage.total_tokens ?? 0
}
const content = chunk.choices?.[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
if (promptTokens === 0 && completionTokens === 0) {
streamLogger.warn(`${providerName} stream completed without usage data`)
}
onComplete(fullContent, {
prompt_tokens: promptTokens,
completion_tokens: completionTokens,
total_tokens: totalTokens || promptTokens + completionTokens,
})
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
}
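A usage sketch (editor's illustration, not part of the diff) of how an OpenAI-compatible provider can feed an SDK stream into this helper. The client setup, model name, and logging are assumptions; the helper signature matches the one added above.

```ts
import OpenAI from 'openai'
import { calculateCost, createOpenAICompatibleStream } from '@/providers/utils'

async function streamChat(apiKey: string): Promise<ReadableStream<Uint8Array>> {
  const client = new OpenAI({ apiKey })

  const upstream = await client.chat.completions.create({
    model: 'gpt-4o-mini', // assumed model, for illustration only
    messages: [{ role: 'user', content: 'Say hello.' }],
    stream: true,
    stream_options: { include_usage: true }, // usage arrives on the final chunk
  })

  // onComplete fires once the upstream iterator is exhausted.
  return createOpenAICompatibleStream(upstream, 'OpenAI', (content, usage) => {
    const cost = calculateCost('gpt-4o-mini', usage.prompt_tokens, usage.completion_tokens)
    console.log({ chars: content.length, tokens: usage.total_tokens, cost: cost.total })
  })
}
```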
/**
* Checks if a forced tool was used in an OpenAI-compatible response and updates tracking.
* This is a shared utility used by OpenAI-compatible providers:
* OpenAI, Groq, DeepSeek, xAI, OpenRouter, Mistral, Ollama, vLLM, Azure OpenAI, Cerebras
*
* @param response - The API response containing tool calls
* @param toolChoice - The tool choice configuration (string or object)
* @param providerName - Name of the provider for logging purposes
* @param forcedTools - Array of forced tool names
* @param usedForcedTools - Array of already used forced tools
* @param customLogger - Optional custom logger instance
* @returns Object with hasUsedForcedTool flag and updated usedForcedTools array
*/
export function checkForForcedToolUsageOpenAI(
response: any,
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
providerName: string,
forcedTools: string[],
usedForcedTools: string[],
customLogger?: Logger
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
const checkLogger = customLogger || createLogger(`${providerName}Utils`)
let hasUsedForcedTool = false
let updatedUsedForcedTools = [...usedForcedTools]
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
checkLogger,
providerName.toLowerCase().replace(/\s+/g, '-'),
forcedTools,
updatedUsedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
updatedUsedForcedTools = result.usedForcedTools
}
return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
}
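A simplified sketch (not part of the diff) of how a provider can use this helper between tool-call iterations. The tool names, response shape, and the fallback to `'auto'` are assumptions for illustration.

```ts
import { checkForForcedToolUsageOpenAI } from '@/providers/utils'

const forcedTools = ['search_docs', 'fetch_page'] // hypothetical tools marked usageControl: 'force'
let usedForcedTools: string[] = []
let toolChoice: any = { type: 'function', function: { name: forcedTools[0] } }

// Called after each chat completion that may contain tool_calls.
function afterModelResponse(response: any): void {
  const result = checkForForcedToolUsageOpenAI(
    response,
    toolChoice,
    'ExampleProvider',
    forcedTools,
    usedForcedTools
  )
  usedForcedTools = result.usedForcedTools

  const remaining = forcedTools.filter((t) => !usedForcedTools.includes(t))
  if (remaining.length > 0) {
    toolChoice = { type: 'function', function: { name: remaining[0] } } // force the next tool
  } else {
    toolChoice = 'auto' // every forced tool has run; let the model decide from here on
  }
}
```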

View File

@@ -16,6 +16,7 @@ import type {
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
@@ -238,14 +239,21 @@ export const vertexProvider: ProviderConfig = {
}
}
if (usage) {
const promptTokens = usage?.promptTokenCount || 0
const completionTokens = usage?.candidatesTokenCount || 0
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
streamingResult.execution.output.tokens = {
prompt: usage.promptTokenCount || 0,
completion: usage.candidatesTokenCount || 0,
total:
usage.totalTokenCount ||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0),
prompt: promptTokens,
completion: completionTokens,
total: totalTokens,
}
const costResult = calculateCost(request.model, promptTokens, completionTokens)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
}
)
@@ -541,8 +549,6 @@ export const vertexProvider: ProviderConfig = {
logger.info('No function call detected, proceeding with streaming response')
// Apply structured output for the final response if responseFormat is specified
// This works regardless of whether tools were forced or auto
if (request.responseFormat) {
streamingPayload.tools = undefined
streamingPayload.toolConfig = undefined
@@ -654,21 +660,40 @@ export const vertexProvider: ProviderConfig = {
streamEndTime - providerStartTime
}
if (usage) {
const promptTokens = usage?.promptTokenCount || 0
const completionTokens = usage?.candidatesTokenCount || 0
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
const existingTokens = streamingExecution.execution.output.tokens || {
prompt: 0,
completion: 0,
total: 0,
}
const existingPrompt = existingTokens.prompt || 0
const existingCompletion = existingTokens.completion || 0
const existingTotal = existingTokens.total || 0
streamingExecution.execution.output.tokens = {
prompt: (existingTokens.prompt || 0) + (usage.promptTokenCount || 0),
completion:
(existingTokens.completion || 0) + (usage.candidatesTokenCount || 0),
total:
(existingTokens.total || 0) +
(usage.totalTokenCount ||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0)),
prompt: existingPrompt + promptTokens,
completion: existingCompletion + completionTokens,
total: existingTotal + totalTokens,
}
const accumulatedCost = calculateCost(
request.model,
existingPrompt,
existingCompletion
)
const streamCost = calculateCost(
request.model,
promptTokens,
completionTokens
)
streamingExecution.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}
)
@@ -753,7 +778,6 @@ export const vertexProvider: ProviderConfig = {
const nextFunctionCall = extractFunctionCall(nextCandidate)
if (!nextFunctionCall) {
// If responseFormat is specified, make one final request with structured output
if (request.responseFormat) {
const finalPayload = {
...payload,
@@ -886,7 +910,7 @@ export const vertexProvider: ProviderConfig = {
})
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
// @ts-ignore
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,

View File

@@ -89,9 +89,7 @@ export function createReadableStreamFromVertexStream(
}
}
}
} catch (arrayError) {
// Buffer is not valid JSON array
}
} catch (arrayError) {}
}
}
}

View File

@@ -1,4 +1,5 @@
import OpenAI from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import { env } from '@/lib/core/config/env'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
@@ -11,6 +12,7 @@ import type {
TimeSegment,
} from '@/providers/types'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
trackForcedToolUsage,
@@ -180,17 +182,12 @@ export const vllmProvider: ProviderConfig = {
if (request.stream && (!tools || tools.length === 0 || !hasActiveTools)) {
logger.info('Using streaming response for vLLM request')
const streamResponse = await vllm.chat.completions.create({
const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
stream: true,
stream_options: { include_usage: true },
})
const tokenUsage = {
prompt: 0,
completion: 0,
total: 0,
}
const streamResponse = await vllm.chat.completions.create(streamingParams)
const streamingResult = {
stream: createReadableStreamFromVLLMStream(streamResponse, (content, usage) => {
@@ -200,6 +197,22 @@ export const vllmProvider: ProviderConfig = {
}
streamingResult.execution.output.content = cleanContent
streamingResult.execution.output.tokens = {
prompt: usage.prompt_tokens,
completion: usage.completion_tokens,
total: usage.total_tokens,
}
const costResult = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
const streamEndTime = Date.now()
const streamEndTimeISO = new Date(streamEndTime).toISOString()
@@ -216,23 +229,13 @@ export const vllmProvider: ProviderConfig = {
streamEndTime - providerStartTime
}
}
if (usage) {
const newTokens = {
prompt: usage.prompt_tokens || tokenUsage.prompt,
completion: usage.completion_tokens || tokenUsage.completion,
total: usage.total_tokens || tokenUsage.total,
}
streamingResult.execution.output.tokens = newTokens
}
}),
execution: {
success: true,
output: {
content: '',
model: request.model,
tokens: tokenUsage,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
@@ -248,6 +251,7 @@ export const vllmProvider: ProviderConfig = {
},
],
},
cost: { input: 0, output: 0, total: 0 },
},
logs: [],
metadata: {
@@ -324,6 +328,13 @@ export const vllmProvider: ProviderConfig = {
checkForForcedToolUsage(currentResponse, originalToolChoice)
while (iterationCount < MAX_TOOL_ITERATIONS) {
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
if (request.responseFormat) {
content = content.replace(/```json\n?|\n?```/g, '').trim()
}
}
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
@@ -335,27 +346,76 @@ export const vllmProvider: ProviderConfig = {
const toolsStartTime = Date.now()
for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) continue
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.function.name
try {
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) return null
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('Error processing tool call:', { error, toolName })
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
startTime: startTime,
endTime: endTime,
duration: duration,
})
let resultContent: any
@@ -373,39 +433,18 @@ export const vllmProvider: ProviderConfig = {
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
} catch (error) {
logger.error('Error processing tool call:', {
error,
toolName: toolCall?.function?.name,
})
}
}
const thisToolsTime = Date.now() - toolsStartTime
@@ -469,15 +508,16 @@ export const vllmProvider: ProviderConfig = {
if (request.stream) {
logger.info('Using streaming for final response after tool processing')
const streamingPayload = {
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
messages: currentMessages,
tool_choice: 'auto',
stream: true,
stream_options: { include_usage: true },
}
const streamResponse = await vllm.chat.completions.create(streamingPayload)
const streamResponse = await vllm.chat.completions.create(streamingParams)
const streamingResult = {
stream: createReadableStreamFromVLLMStream(streamResponse, (content, usage) => {
@@ -487,15 +527,21 @@ export const vllmProvider: ProviderConfig = {
}
streamingResult.execution.output.content = cleanContent
if (usage) {
const newTokens = {
prompt: usage.prompt_tokens || tokens.prompt,
completion: usage.completion_tokens || tokens.completion,
total: usage.total_tokens || tokens.total,
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.prompt_tokens,
completion: tokens.completion + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}
streamingResult.execution.output.tokens = newTokens
const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
@@ -525,6 +571,11 @@ export const vllmProvider: ProviderConfig = {
iterations: iterationCount + 1,
timeSegments: timeSegments,
},
cost: {
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [],
metadata: {
@@ -585,7 +636,7 @@ export const vllmProvider: ProviderConfig = {
})
const enhancedError = new Error(errorMessage)
// @ts-ignore - Adding timing and vLLM error properties
// @ts-ignore
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,

View File

@@ -1,37 +1,14 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'
/**
* Helper function to convert a vLLM stream to a standard ReadableStream
* and collect completion metrics
* Creates a ReadableStream from a vLLM streaming response.
* Uses the shared OpenAI-compatible streaming utility.
*/
export function createReadableStreamFromVLLMStream(
vllmStream: any,
onComplete?: (content: string, usage?: any) => void
): ReadableStream {
let fullContent = ''
let usageData: any = null
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of vllmStream) {
if (chunk.usage) {
usageData = chunk.usage
}
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
fullContent += content
controller.enqueue(new TextEncoder().encode(content))
}
}
if (onComplete) {
onComplete(fullContent, usageData)
}
controller.close()
} catch (error) {
controller.error(error)
}
},
})
vllmStream: AsyncIterable<ChatCompletionChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
return createOpenAICompatibleStream(vllmStream, 'vLLM', onComplete)
}

View File

@@ -1,4 +1,5 @@
import OpenAI from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
@@ -9,7 +10,11 @@ import type {
ProviderResponse,
TimeSegment,
} from '@/providers/types'
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
import {
calculateCost,
prepareToolExecution,
prepareToolsWithUsageControl,
} from '@/providers/utils'
import {
checkForForcedToolUsage,
createReadableStreamFromXAIStream,
@@ -66,8 +71,6 @@ export const xAIProvider: ProviderConfig = {
if (request.messages) {
allMessages.push(...request.messages)
}
// Set up tools
const tools = request.tools?.length
? request.tools.map((tool) => ({
type: 'function',
@@ -78,15 +81,11 @@ export const xAIProvider: ProviderConfig = {
},
}))
: undefined
// Log tools and response format conflict detection
if (tools?.length && request.responseFormat) {
logger.warn(
'XAI Provider - Detected both tools and response format. Using tools first, then response format for final response.'
)
}
// Build the base request payload
const basePayload: any = {
model: request.model || 'grok-3-latest',
messages: allMessages,
@@ -94,52 +93,54 @@ export const xAIProvider: ProviderConfig = {
if (request.temperature !== undefined) basePayload.temperature = request.temperature
if (request.maxTokens !== undefined) basePayload.max_tokens = request.maxTokens
// Handle tools and tool usage control
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
if (tools?.length) {
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'xai')
}
// EARLY STREAMING: if caller requested streaming and there are no tools to execute,
// we can directly stream the completion with response format if needed
if (request.stream && (!tools || tools.length === 0)) {
logger.info('XAI Provider - Using direct streaming (no tools)')
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
// Use response format payload if needed, otherwise use base payload
const streamingPayload = request.responseFormat
? createResponseFormatPayload(basePayload, allMessages, request.responseFormat)
: { ...basePayload, stream: true }
if (!request.responseFormat) {
streamingPayload.stream = true
} else {
streamingPayload.stream = true
const streamingParams: ChatCompletionCreateParamsStreaming = request.responseFormat
? {
...createResponseFormatPayload(basePayload, allMessages, request.responseFormat),
stream: true,
stream_options: { include_usage: true },
}
: { ...basePayload, stream: true, stream_options: { include_usage: true } }
const streamResponse = await xai.chat.completions.create(streamingPayload)
const streamResponse = await xai.chat.completions.create(streamingParams)
// Start collecting token usage
const tokenUsage = {
prompt: 0,
completion: 0,
total: 0,
}
// Create a StreamingExecution response with a readable stream
const streamingResult = {
stream: createReadableStreamFromXAIStream(streamResponse),
stream: createReadableStreamFromXAIStream(streamResponse, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: usage.prompt_tokens,
completion: usage.completion_tokens,
total: usage.total_tokens,
}
const costResult = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: costResult.input,
output: costResult.output,
total: costResult.total,
}
}),
execution: {
success: true,
output: {
content: '', // Will be filled by streaming content in chat component
content: '',
model: request.model || 'grok-3-latest',
tokens: tokenUsage,
tokens: { prompt: 0, completion: 0, total: 0 },
toolCalls: undefined,
providerTiming: {
startTime: providerStartTimeISO,
@@ -155,14 +156,9 @@ export const xAIProvider: ProviderConfig = {
},
],
},
// Estimate token cost
cost: {
total: 0.0,
input: 0.0,
output: 0.0,
cost: { input: 0, output: 0, total: 0 },
},
},
logs: [], // No block logs for direct streaming
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -172,26 +168,18 @@ export const xAIProvider: ProviderConfig = {
},
}
// Return the streaming execution object
return streamingResult as StreamingExecution
}
// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()
try {
// Make the initial API request
const initialCallTime = Date.now()
// For the initial request with tools, we NEVER include response_format
// This is the key fix: tools and response_format cannot be used together with xAI
// xAI cannot use tools and response_format together in the same request
const initialPayload = { ...basePayload }
// Track the original tool_choice for forced tool tracking
let originalToolChoice: any
// Track forced tools and their usage
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []
@@ -201,7 +189,6 @@ export const xAIProvider: ProviderConfig = {
initialPayload.tool_choice = toolChoice
originalToolChoice = toolChoice
} else if (request.responseFormat) {
// Only add response format if there are no tools
const responseFormatPayload = createResponseFormatPayload(
basePayload,
allMessages,
@@ -224,14 +211,9 @@ export const xAIProvider: ProviderConfig = {
const currentMessages = [...allMessages]
let iterationCount = 0
// Track if a forced tool has been used
let hasUsedForcedTool = false
// Track time spent in model vs tools
let modelTime = firstResponseTime
let toolsTime = 0
// Track each model and tool call segment with timestamps
const timeSegments: TimeSegment[] = [
{
type: 'model',
@@ -241,8 +223,6 @@ export const xAIProvider: ProviderConfig = {
duration: firstResponseTime,
},
]
// Check if a forced tool was used in the first response
if (originalToolChoice) {
const result = checkForForcedToolUsage(
currentResponse,
@@ -256,56 +236,102 @@ export const xAIProvider: ProviderConfig = {
try {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}
// Track time for tool calls in this batch
const toolsStartTime = Date.now()
for (const toolCall of toolCallsInResponse) {
try {
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)
try {
const toolArgs = JSON.parse(toolCall.function.arguments)
const tool = request.tools?.find((t) => t.id === toolName)
if (!tool) {
logger.warn('XAI Provider - Tool not found:', { toolName })
continue
return null
}
const toolCallStartTime = Date.now()
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime
// Add to time segments for both success and failure
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('XAI Provider - Error processing tool call:', {
error: error instanceof Error ? error.message : String(error),
toolCall: toolCall.function.name,
})
return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})
const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})
for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value
timeSegments.push({
type: 'tool',
name: toolName,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
startTime: startTime,
endTime: endTime,
duration: duration,
})
// Prepare result content for the LLM
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
// Include error information so LLM can respond appropriately
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
tool: toolName,
}
logger.warn('XAI Provider - Tool execution failed:', {
toolName,
error: result.error,
@@ -315,60 +341,31 @@ export const xAIProvider: ProviderConfig = {
toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
// Add the tool call and result to messages (both success and failure)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
} catch (error) {
logger.error('XAI Provider - Error processing tool call:', {
error: error instanceof Error ? error.message : String(error),
toolCall: toolCall.function.name,
})
}
}
// Calculate tool call time for this iteration
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime
// After tool calls, create next payload based on whether we need more tools or final response
let nextPayload: any
// Update tool_choice based on which forced tools have been used
if (
typeof originalToolChoice === 'object' &&
hasUsedForcedTool &&
forcedTools.length > 0
) {
// If we have remaining forced tools, get the next one to force
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
if (remainingTools.length > 0) {
// Force the next tool - continue with tools, no response format
nextPayload = {
...basePayload,
messages: currentMessages,
@@ -379,7 +376,6 @@ export const xAIProvider: ProviderConfig = {
},
}
} else {
// All forced tools have been used, check if we need response format for final response
if (request.responseFormat) {
nextPayload = createResponseFormatPayload(
basePayload,
@@ -397,9 +393,7 @@ export const xAIProvider: ProviderConfig = {
}
}
} else {
// Normal tool processing - check if this might be the final response
if (request.responseFormat) {
// Use response format for what might be the final response
nextPayload = createResponseFormatPayload(
basePayload,
allMessages,
@@ -416,12 +410,9 @@ export const xAIProvider: ProviderConfig = {
}
}
// Time the next model call
const nextModelStartTime = Date.now()
currentResponse = await xai.chat.completions.create(nextPayload)
// Check if any forced tools were used in this response
if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') {
const result = checkForForcedToolUsage(
currentResponse,
@@ -435,8 +426,6 @@ export const xAIProvider: ProviderConfig = {
const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime
// Add to time segments
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
@@ -445,7 +434,6 @@ export const xAIProvider: ProviderConfig = {
duration: thisModelTime,
})
// Add to model time
modelTime += thisModelTime
if (currentResponse.choices[0]?.message?.content) {
@@ -466,14 +454,10 @@ export const xAIProvider: ProviderConfig = {
iterationCount,
})
}
// After all tool processing complete, if streaming was requested, use streaming for the final response
if (request.stream) {
// For final streaming response, choose between tools (auto) or response_format (never both)
let finalStreamingPayload: any
if (request.responseFormat) {
// Use response format, no tools
finalStreamingPayload = {
...createResponseFormatPayload(
basePayload,
@@ -484,7 +468,6 @@ export const xAIProvider: ProviderConfig = {
stream: true,
}
} else {
// Use tools with auto choice
finalStreamingPayload = {
...basePayload,
messages: currentMessages,
@@ -494,15 +477,34 @@ export const xAIProvider: ProviderConfig = {
}
}
const streamResponse = await xai.chat.completions.create(finalStreamingPayload)
const streamResponse = await xai.chat.completions.create(finalStreamingPayload as any)
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
// Create a StreamingExecution response with all collected data
const streamingResult = {
stream: createReadableStreamFromXAIStream(streamResponse),
stream: createReadableStreamFromXAIStream(streamResponse as any, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.prompt_tokens,
completion: tokens.completion + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}
const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
success: true,
output: {
content: '', // Will be filled by the callback
content: '',
model: request.model || 'grok-3-latest',
tokens: {
prompt: tokens.prompt,
@@ -527,12 +529,12 @@ export const xAIProvider: ProviderConfig = {
timeSegments: timeSegments,
},
cost: {
total: (tokens.total || 0) * 0.0001,
input: (tokens.prompt || 0) * 0.0001,
output: (tokens.completion || 0) * 0.0001,
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [], // No block logs at provider level
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -542,11 +544,8 @@ export const xAIProvider: ProviderConfig = {
},
}
// Return the streaming execution object
return streamingResult as StreamingExecution
}
// Calculate overall timing
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -577,7 +576,6 @@ export const xAIProvider: ProviderConfig = {
},
}
} catch (error) {
// Include timing information even for errors
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -589,9 +587,8 @@ export const xAIProvider: ProviderConfig = {
hasResponseFormat: !!request.responseFormat,
})
// Create a new error with timing information
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
// @ts-ignore - Adding timing property to error for debugging
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,

View File

@@ -1,32 +1,20 @@
import { createLogger } from '@/lib/logs/console/logger'
import { trackForcedToolUsage } from '@/providers/utils'
const logger = createLogger('XAIProvider')
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils'
/**
* Helper to wrap XAI (OpenAI-compatible) streaming into a browser-friendly
* ReadableStream of raw assistant text chunks.
* Creates a ReadableStream from an xAI streaming response.
* Uses the shared OpenAI-compatible streaming utility.
*/
export function createReadableStreamFromXAIStream(xaiStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of xaiStream) {
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
export function createReadableStreamFromXAIStream(
xaiStream: AsyncIterable<ChatCompletionChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
return createOpenAICompatibleStream(xaiStream, 'xAI', onComplete)
}
/**
* Creates a response format payload for XAI API requests.
* Creates a response format payload for xAI requests with JSON schema.
*/
export function createResponseFormatPayload(
basePayload: any,
@@ -54,7 +42,8 @@ export function createResponseFormatPayload(
}
/**
* Helper function to check for forced tool usage in responses.
* Checks if a forced tool was used in an xAI response.
* Uses the shared OpenAI-compatible forced tool usage helper.
*/
export function checkForForcedToolUsage(
response: any,
@@ -62,22 +51,5 @@ export function checkForForcedToolUsage(
forcedTools: string[],
usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
let hasUsedForcedTool = false
let updatedUsedForcedTools = usedForcedTools
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'xai',
forcedTools,
updatedUsedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
updatedUsedForcedTools = result.usedForcedTools
}
return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
return checkForForcedToolUsageOpenAI(response, toolChoice, 'xAI', forcedTools, usedForcedTools)
}

View File

@@ -1,6 +1,6 @@
import { create } from 'zustand'
import { createLogger } from '@/lib/logs/console/logger'
import type { ProvidersStore } from '@/stores/providers/types'
import type { OpenRouterModelInfo, ProvidersStore } from '@/stores/providers/types'
const logger = createLogger('ProvidersStore')
@@ -11,6 +11,7 @@ export const useProvidersStore = create<ProvidersStore>((set, get) => ({
vllm: { models: [], isLoading: false },
openrouter: { models: [], isLoading: false },
},
openRouterModelInfo: {},
setProviderModels: (provider, models) => {
logger.info(`Updated ${provider} models`, { count: models.length })
@@ -37,7 +38,22 @@ export const useProvidersStore = create<ProvidersStore>((set, get) => ({
}))
},
setOpenRouterModelInfo: (modelInfo: Record<string, OpenRouterModelInfo>) => {
const structuredOutputCount = Object.values(modelInfo).filter(
(m) => m.supportsStructuredOutputs
).length
logger.info('Updated OpenRouter model info', {
count: Object.keys(modelInfo).length,
withStructuredOutputs: structuredOutputCount,
})
set({ openRouterModelInfo: modelInfo })
},
getProvider: (provider) => {
return get().providers[provider]
},
getOpenRouterModelInfo: (modelId: string) => {
return get().openRouterModelInfo[modelId]
},
}))

View File

@@ -1,5 +1,16 @@
export type ProviderName = 'ollama' | 'vllm' | 'openrouter' | 'base'
export interface OpenRouterModelInfo {
id: string
contextLength?: number
supportsStructuredOutputs?: boolean
supportsTools?: boolean
pricing?: {
input: number
output: number
}
}
export interface ProviderState {
models: string[]
isLoading: boolean
@@ -7,7 +18,10 @@ export interface ProviderState {
export interface ProvidersStore {
providers: Record<ProviderName, ProviderState>
openRouterModelInfo: Record<string, OpenRouterModelInfo>
setProviderModels: (provider: ProviderName, models: string[]) => void
setProviderLoading: (provider: ProviderName, isLoading: boolean) => void
setOpenRouterModelInfo: (modelInfo: Record<string, OpenRouterModelInfo>) => void
getProvider: (provider: ProviderName) => ProviderState
getOpenRouterModelInfo: (modelId: string) => OpenRouterModelInfo | undefined
}
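A brief sketch (editor's illustration, not part of the diff) of reading the new OpenRouter metadata from the store, for example to choose between native structured outputs and the prompt-based fallback. The model id and decision logic are assumptions.

```ts
import { useProvidersStore } from '@/stores/providers/store'

const modelId = 'anthropic/claude-3.5-sonnet' // assumed OpenRouter model id
const info = useProvidersStore.getState().getOpenRouterModelInfo(modelId)

const useNativeStructuredOutputs = info?.supportsStructuredOutputs ?? false
console.log({ modelId, contextLength: info?.contextLength, useNativeStructuredOutputs })
```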

View File

@@ -13,7 +13,7 @@ export const memoryAddTool: ToolConfig<any, MemoryResponse> = {
type: 'string',
required: false,
description:
'Conversation identifier (e.g., user-123, session-abc). If a memory with this conversationId already exists for this block, the new message will be appended to it.',
'Conversation identifier (e.g., user-123, session-abc). If a memory with this conversationId already exists, the new message will be appended to it.',
},
id: {
type: 'string',
@@ -31,12 +31,6 @@ export const memoryAddTool: ToolConfig<any, MemoryResponse> = {
required: true,
description: 'Content for agent memory',
},
blockId: {
type: 'string',
required: false,
description:
'Optional block ID. If not provided, uses the current block ID from execution context, or defaults to "default".',
},
},
request: {
@@ -46,29 +40,24 @@ export const memoryAddTool: ToolConfig<any, MemoryResponse> = {
'Content-Type': 'application/json',
}),
body: (params) => {
const workflowId = params._context?.workflowId
const contextBlockId = params._context?.blockId
const workspaceId = params._context?.workspaceId
if (!workflowId) {
if (!workspaceId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message: 'workflowId is required and must be provided in execution context',
message: 'workspaceId is required and must be provided in execution context',
},
},
},
}
}
// Use 'id' as fallback for 'conversationId' for backwards compatibility
const conversationId = params.conversationId || params.id
// Default blockId to 'default' if not provided in params or context
const blockId = params.blockId || contextBlockId || 'default'
if (!conversationId || conversationId.trim() === '') {
return {
_errorResponse: {
@@ -97,11 +86,11 @@ export const memoryAddTool: ToolConfig<any, MemoryResponse> = {
}
}
const key = buildMemoryKey(conversationId, blockId)
const key = buildMemoryKey(conversationId)
const body: Record<string, any> = {
key,
workflowId,
workspaceId,
data: {
role: params.role,
content: params.content,

View File

@@ -4,8 +4,7 @@ import type { ToolConfig } from '@/tools/types'
export const memoryDeleteTool: ToolConfig<any, MemoryResponse> = {
id: 'memory_delete',
name: 'Delete Memory',
description:
'Delete memories by conversationId, blockId, blockName, or a combination. Supports bulk deletion.',
description: 'Delete memories by conversationId.',
version: '1.0.0',
params: {
@@ -13,7 +12,7 @@ export const memoryDeleteTool: ToolConfig<any, MemoryResponse> = {
type: 'string',
required: false,
description:
'Conversation identifier (e.g., user-123, session-abc). If provided alone, deletes all memories for this conversation across all blocks.',
'Conversation identifier (e.g., user-123, session-abc). Deletes all memories for this conversation.',
},
id: {
type: 'string',
@@ -21,50 +20,36 @@ export const memoryDeleteTool: ToolConfig<any, MemoryResponse> = {
description:
'Legacy parameter for conversation identifier. Use conversationId instead. Provided for backwards compatibility.',
},
blockId: {
type: 'string',
required: false,
description:
'Block identifier. If provided alone, deletes all memories for this block across all conversations. If provided with conversationId, deletes memories for that specific conversation in this block.',
},
blockName: {
type: 'string',
required: false,
description:
'Block name. Alternative to blockId. If provided alone, deletes all memories for blocks with this name. If provided with conversationId, deletes memories for that conversation in blocks with this name.',
},
},
request: {
url: (params): any => {
const workflowId = params._context?.workflowId
const workspaceId = params._context?.workspaceId
if (!workflowId) {
if (!workspaceId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message: 'workflowId is required and must be provided in execution context',
message: 'workspaceId is required and must be provided in execution context',
},
},
},
}
}
// Use 'id' as fallback for 'conversationId' for backwards compatibility
const conversationId = params.conversationId || params.id
if (!conversationId && !params.blockId && !params.blockName) {
if (!conversationId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message:
'At least one of conversationId, id, blockId, or blockName must be provided',
message: 'conversationId or id must be provided',
},
},
},
@@ -72,17 +57,8 @@ export const memoryDeleteTool: ToolConfig<any, MemoryResponse> = {
}
const url = new URL('/api/memory', 'http://dummy')
url.searchParams.set('workflowId', workflowId)
if (conversationId) {
url.searchParams.set('workspaceId', workspaceId)
url.searchParams.set('conversationId', conversationId)
}
if (params.blockId) {
url.searchParams.set('blockId', params.blockId)
}
if (params.blockName) {
url.searchParams.set('blockName', params.blockName)
}
return url.pathname + url.search
},

View File

@@ -5,8 +5,7 @@ import type { ToolConfig } from '@/tools/types'
export const memoryGetTool: ToolConfig<any, MemoryResponse> = {
id: 'memory_get',
name: 'Get Memory',
description:
'Retrieve memory by conversationId, blockId, blockName, or a combination. Returns all matching memories.',
description: 'Retrieve memory by conversationId. Returns matching memories.',
version: '1.0.0',
params: {
@@ -14,7 +13,7 @@ export const memoryGetTool: ToolConfig<any, MemoryResponse> = {
type: 'string',
required: false,
description:
'Conversation identifier (e.g., user-123, session-abc). If provided alone, returns all memories for this conversation across all blocks.',
'Conversation identifier (e.g., user-123, session-abc). Returns memories for this conversation.',
},
id: {
type: 'string',
@@ -22,75 +21,47 @@ export const memoryGetTool: ToolConfig<any, MemoryResponse> = {
description:
'Legacy parameter for conversation identifier. Use conversationId instead. Provided for backwards compatibility.',
},
blockId: {
type: 'string',
required: false,
description:
'Block identifier. If provided alone, returns all memories for this block across all conversations. If provided with conversationId, returns memories for that specific conversation in this block.',
},
blockName: {
type: 'string',
required: false,
description:
'Block name. Alternative to blockId. If provided alone, returns all memories for blocks with this name. If provided with conversationId, returns memories for that conversation in blocks with this name.',
},
},
request: {
url: (params): any => {
const workflowId = params._context?.workflowId
const workspaceId = params._context?.workspaceId
if (!workflowId) {
if (!workspaceId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message: 'workflowId is required and must be provided in execution context',
message: 'workspaceId is required and must be provided in execution context',
},
},
},
}
}
// Use 'id' as fallback for 'conversationId' for backwards compatibility
const conversationId = params.conversationId || params.id
if (!conversationId && !params.blockId && !params.blockName) {
if (!conversationId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message:
'At least one of conversationId, id, blockId, or blockName must be provided',
message: 'conversationId or id must be provided',
},
},
},
}
}
let query = ''
if (conversationId && params.blockId) {
query = buildMemoryKey(conversationId, params.blockId)
} else if (conversationId) {
// Also check for legacy format (conversationId without blockId)
query = `${conversationId}:`
} else if (params.blockId) {
query = `:${params.blockId}`
}
const query = buildMemoryKey(conversationId)
const url = new URL('/api/memory', 'http://dummy')
url.searchParams.set('workflowId', workflowId)
if (query) {
url.searchParams.set('workspaceId', workspaceId)
url.searchParams.set('query', query)
}
if (params.blockName) {
url.searchParams.set('blockName', params.blockName)
}
url.searchParams.set('limit', '1000')
return url.pathname + url.search
@@ -128,8 +99,7 @@ export const memoryGetTool: ToolConfig<any, MemoryResponse> = {
success: { type: 'boolean', description: 'Whether the memory was retrieved successfully' },
memories: {
type: 'array',
description:
'Array of memory objects with conversationId, blockId, blockName, and data fields',
description: 'Array of memory objects with conversationId and data fields',
},
message: { type: 'string', description: 'Success or error message' },
error: { type: 'string', description: 'Error message if operation failed' },

View File

@@ -11,23 +11,23 @@ export const memoryGetAllTool: ToolConfig<any, MemoryResponse> = {
request: {
url: (params): any => {
const workflowId = params._context?.workflowId
const workspaceId = params._context?.workspaceId
if (!workflowId) {
if (!workspaceId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message: 'workflowId is required and must be provided in execution context',
message: 'workspaceId is required and must be provided in execution context',
},
},
},
}
}
return `/api/memory?workflowId=${encodeURIComponent(workflowId)}`
return `/api/memory?workspaceId=${encodeURIComponent(workspaceId)}`
},
method: 'GET',
headers: () => ({
@@ -64,8 +64,7 @@ export const memoryGetAllTool: ToolConfig<any, MemoryResponse> = {
success: { type: 'boolean', description: 'Whether all memories were retrieved successfully' },
memories: {
type: 'array',
description:
'Array of all memory objects with key, conversationId, blockId, blockName, and data fields',
description: 'Array of all memory objects with key, conversationId, and data fields',
},
message: { type: 'string', description: 'Success or error message' },
error: { type: 'string', description: 'Error message if operation failed' },

View File

@@ -1,45 +1,25 @@
/**
* Parse memory key into conversationId and blockId
* Supports two formats:
* - New format: conversationId:blockId (splits on LAST colon to handle IDs with colons)
* - Legacy format: id (without colon, treated as conversationId with blockId='default')
* @param key The memory key to parse
* @returns Object with conversationId and blockId, or null if invalid
* Parse memory key to extract conversationId
* Memory is now thread-scoped, so the key is just the conversationId
* @param key The memory key (conversationId)
* @returns Object with conversationId, or null if invalid
*/
export function parseMemoryKey(key: string): { conversationId: string; blockId: string } | null {
export function parseMemoryKey(key: string): { conversationId: string } | null {
if (!key) {
return null
}
const lastColonIndex = key.lastIndexOf(':')
// Legacy format: no colon found
if (lastColonIndex === -1) {
return {
conversationId: key,
blockId: 'default',
}
}
// Invalid: colon at start or end
if (lastColonIndex === 0 || lastColonIndex === key.length - 1) {
return null
}
// New format: split on last colon to handle IDs with colons
// This allows conversationIds like "user:123" to work correctly
return {
conversationId: key.substring(0, lastColonIndex),
blockId: key.substring(lastColonIndex + 1),
}
return { conversationId: key }
}
/**
* Build memory key from conversationId and blockId
* Build memory key from conversationId
* Memory is thread-scoped, so key is just the conversationId
* @param conversationId The conversation ID
* @param blockId The block ID
* @returns The memory key in format conversationId:blockId
* @returns The memory key (same as conversationId)
*/
export function buildMemoryKey(conversationId: string, blockId: string): string {
return `${conversationId}:${blockId}`
export function buildMemoryKey(conversationId: string): string {
return conversationId
}
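To make the behavioral change concrete, here is a hedged before/after sketch. The body of the new `parseMemoryKey` is not fully visible in the hunk, so a plain pass-through return is assumed.

```ts
// Assumed final shape of the simplified helpers; the new parseMemoryKey
// body is elided in the hunk, so a pass-through return is assumed here.
function parseMemoryKey(key: string): { conversationId: string } | null {
  if (!key) return null
  return { conversationId: key }
}

function buildMemoryKey(conversationId: string): string {
  return conversationId
}

// Previously "user:123:default" was split on the last colon into
// { conversationId: "user:123", blockId: "default" }. Now the whole
// string is the conversationId and round-trips unchanged:
parseMemoryKey('user:123:default') // -> { conversationId: 'user:123:default' }
buildMemoryKey('session-abc') // -> 'session-abc'
```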

View File

@@ -16,8 +16,6 @@ export interface MemoryRecord {
id: string
key: string
conversationId: string
blockId: string
blockName: string
data: AgentMemoryData[]
createdAt: string
updatedAt: string
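For orientation, a record under the trimmed interface might look like the following — the `AgentMemoryData` entries are assumed to be `{ role, content }` pairs, and every value is illustrative.

```ts
// Illustrative record matching the trimmed MemoryRecord fields above.
// AgentMemoryData entries are assumed to be { role, content } pairs.
const exampleRecord = {
  id: 'mem_01',
  key: 'user-123', // the key is now just the conversationId
  conversationId: 'user-123',
  data: [
    { role: 'user', content: 'Remember that my favorite color is green.' },
    { role: 'assistant', content: 'Noted: favorite color is green.' },
  ],
  createdAt: '2025-12-22T00:00:00.000Z',
  updatedAt: '2025-12-22T00:00:00.000Z',
}
```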

View File

@@ -37,7 +37,7 @@
"drizzle-kit": "^0.31.4",
"husky": "9.1.7",
"lint-staged": "16.0.0",
"turbo": "2.7.0",
"turbo": "2.7.1",
},
},
"apps/docs": {
@@ -88,6 +88,7 @@
"@browserbasehq/stagehand": "^3.0.5",
"@cerebras/cerebras_cloud_sdk": "^1.23.0",
"@e2b/code-interpreter": "^2.0.0",
"@google/genai": "1.34.0",
"@hookform/resolvers": "^4.1.3",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/exporter-jaeger": "2.1.0",
@@ -627,7 +628,7 @@
"@google-cloud/precise-date": ["@google-cloud/precise-date@4.0.0", "", {}, "sha512-1TUx3KdaU3cN7nfCdNf+UVqA/PSX29Cjcox3fZZBtINlRrXVTmUkQnCKv2MbBUbCopbK4olAT1IHl76uZyCiVA=="],
"@google/genai": ["@google/genai@1.33.0", "", { "dependencies": { "google-auth-library": "^10.3.0", "ws": "^8.18.0" }, "peerDependencies": { "@modelcontextprotocol/sdk": "^1.24.0" }, "optionalPeers": ["@modelcontextprotocol/sdk"] }, "sha512-ThUjFZ1N0DU88peFjnQkb8K198EWaW2RmmnDShFQ+O+xkIH9itjpRe358x3L/b4X/A7dimkvq63oz49Vbh7Cog=="],
"@google/genai": ["@google/genai@1.34.0", "", { "dependencies": { "google-auth-library": "^10.3.0", "ws": "^8.18.0" }, "peerDependencies": { "@modelcontextprotocol/sdk": "^1.24.0" }, "optionalPeers": ["@modelcontextprotocol/sdk"] }, "sha512-vu53UMPvjmb7PGzlYu6Tzxso8Dfhn+a7eQFaS2uNemVtDZKwzSpJ5+ikqBbXplF7RGB1STcVDqCkPvquiwb2sw=="],
"@graphql-typed-document-node/core": ["@graphql-typed-document-node/core@3.2.0", "", { "peerDependencies": { "graphql": "^0.8.0 || ^0.9.0 || ^0.10.0 || ^0.11.0 || ^0.12.0 || ^0.13.0 || ^14.0.0 || ^15.0.0 || ^16.0.0 || ^17.0.0" } }, "sha512-mB9oAsNCm9aM3/SOv4YtBMqZbYj10R7dkq8byBqxGY/ncFwhf2oQzMV+LCRlWoDSEBJ3COiR1yeDvMtsoOsuFQ=="],
@@ -3303,19 +3304,19 @@
"tunnel-agent": ["tunnel-agent@0.6.0", "", { "dependencies": { "safe-buffer": "^5.0.1" } }, "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w=="],
"turbo": ["turbo@2.7.0", "", { "optionalDependencies": { "turbo-darwin-64": "2.7.0", "turbo-darwin-arm64": "2.7.0", "turbo-linux-64": "2.7.0", "turbo-linux-arm64": "2.7.0", "turbo-windows-64": "2.7.0", "turbo-windows-arm64": "2.7.0" }, "bin": { "turbo": "bin/turbo" } }, "sha512-1dUGwi6cSSVZts1BwJa/Gh7w5dPNNGsNWZEAuRKxXWME44hTKWpQZrgiPnqMc5jJJOovzPK5N6tL+PHYRYL5Wg=="],
"turbo": ["turbo@2.7.1", "", { "optionalDependencies": { "turbo-darwin-64": "2.7.1", "turbo-darwin-arm64": "2.7.1", "turbo-linux-64": "2.7.1", "turbo-linux-arm64": "2.7.1", "turbo-windows-64": "2.7.1", "turbo-windows-arm64": "2.7.1" }, "bin": { "turbo": "bin/turbo" } }, "sha512-zAj9jGc7VDvuAo/5Jbos4QTtWz9uUpkMhMKGyTjDJkx//hdL2bM31qQoJSAbU+7JyK5vb0LPzpwf6DUt3zayqg=="],
"turbo-darwin-64": ["turbo-darwin-64@2.7.0", "", { "os": "darwin", "cpu": "x64" }, "sha512-gwqL7cJOSYrV/jNmhXM8a2uzSFn7GcUASOuen6OgmUsafUj9SSWcgXZ/q0w9hRoL917hpidkdI//UpbxbZbwwg=="],
"turbo-darwin-64": ["turbo-darwin-64@2.7.1", "", { "os": "darwin", "cpu": "x64" }, "sha512-EaA7UfYujbY9/Ku0WqPpvfctxm91h9LF7zo8vjielz+omfAPB54Si+ADmUoBczBDC6RoLgbURC3GmUW2alnjJg=="],
"turbo-darwin-arm64": ["turbo-darwin-arm64@2.7.0", "", { "os": "darwin", "cpu": "arm64" }, "sha512-f3F5DYOnfE6lR6v/rSld7QGZgartKsnlIYY7jcF/AA7Wz27za9XjxMHzb+3i4pvRhAkryFgf2TNq7eCFrzyTpg=="],
"turbo-darwin-arm64": ["turbo-darwin-arm64@2.7.1", "", { "os": "darwin", "cpu": "arm64" }, "sha512-/pWGSygtBugd7sKQOeMm+jKY3qN1vyB0RiHBM6bN/6qUOo2VHo8IQwBTIaSgINN4Ue6fzEU+WfePNvonSU9yXw=="],
"turbo-linux-64": ["turbo-linux-64@2.7.0", "", { "os": "linux", "cpu": "x64" }, "sha512-KsC+UuKlhjCL+lom10/IYoxUsdhJOsuEki72YSr7WGYUSRihcdJQnaUyIDTlm0nPOb+gVihVNBuVP4KsNg1UnA=="],
"turbo-linux-64": ["turbo-linux-64@2.7.1", "", { "os": "linux", "cpu": "x64" }, "sha512-Y5H11mdhASw/dJuRFyGtTCDFX5/MPT73EKsVEiHbw5MkFc77lx3nMc5L/Q7bKEhef/vYJAsAb61QuHsB6qdP8Q=="],
"turbo-linux-arm64": ["turbo-linux-arm64@2.7.0", "", { "os": "linux", "cpu": "arm64" }, "sha512-1tjIYULeJtpmE/ovoI9qPBFJCtUEM7mYfeIMOIs4bXR6t/8u+rHPwr3j+vRHcXanIc42V1n3Pz52VqmJtIAviw=="],
"turbo-linux-arm64": ["turbo-linux-arm64@2.7.1", "", { "os": "linux", "cpu": "arm64" }, "sha512-L/r77jD7cqIEXoyu2LGBUrTY5GJSi/XcGLsQ2nZ/fefk6x3MpljTvwsXUVG1BUkiBPc4zaKRj6yGyWMo5MbLxQ=="],
"turbo-windows-64": ["turbo-windows-64@2.7.0", "", { "os": "win32", "cpu": "x64" }, "sha512-KThkAeax46XiH+qICCQm7R8V2pPdeTTP7ArCSRrSLqnlO75ftNm8Ljx4VAllwIZkILrq/GDM8PlyhZdPeUdDxQ=="],
"turbo-windows-64": ["turbo-windows-64@2.7.1", "", { "os": "win32", "cpu": "x64" }, "sha512-rkeuviXZ/1F7lCare7TNKvYtT/SH9dZR55FAMrxrFRh88b+ZKwlXEBfq5/1OctEzRUo/VLIm+s5LJMOEy+QshA=="],
"turbo-windows-arm64": ["turbo-windows-arm64@2.7.0", "", { "os": "win32", "cpu": "arm64" }, "sha512-kzI6rsQ3Ejs+CkM9HEEP3Z4h5YMCRxwIlQXFQmgXSG3BIgorCkRF2Xr7iQ2i9AGwY/6jbiAYeJbvi3yCp+noFw=="],
"turbo-windows-arm64": ["turbo-windows-arm64@2.7.1", "", { "os": "win32", "cpu": "arm64" }, "sha512-1rZk9htm3+iP/rWCf/h4/DFQey9sMs2TJPC4T5QQfwqAdMWsphgrxBuFqHdxczlbBCgbWNhVw0CH2bTxe1/GFg=="],
"tweetnacl": ["tweetnacl@0.14.5", "", {}, "sha512-KXXFFdAbFXY4geFIwoyNK+f5Z1b7swfXABfL7HXCmoIWMKU3dmS26672A4EeQtDzLKy7SXmfBu51JolvEKwtGA=="],
@@ -3561,6 +3562,8 @@
"@browserbasehq/sdk/node-fetch": ["node-fetch@2.7.0", "", { "dependencies": { "whatwg-url": "^5.0.0" }, "peerDependencies": { "encoding": "^0.1.0" }, "optionalPeers": ["encoding"] }, "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A=="],
"@browserbasehq/stagehand/@google/genai": ["@google/genai@1.33.0", "", { "dependencies": { "google-auth-library": "^10.3.0", "ws": "^8.18.0" }, "peerDependencies": { "@modelcontextprotocol/sdk": "^1.24.0" }, "optionalPeers": ["@modelcontextprotocol/sdk"] }, "sha512-ThUjFZ1N0DU88peFjnQkb8K198EWaW2RmmnDShFQ+O+xkIH9itjpRe358x3L/b4X/A7dimkvq63oz49Vbh7Cog=="],
"@cerebras/cerebras_cloud_sdk/@types/node": ["@types/node@18.19.130", "", { "dependencies": { "undici-types": "~5.26.4" } }, "sha512-GRaXQx6jGfL8sKfaIDD6OupbIHBr9jv7Jnaml9tB7l4v068PAOXqfcujMMo5PhbIs6ggR1XODELqahT2R8v0fg=="],
"@cerebras/cerebras_cloud_sdk/node-fetch": ["node-fetch@2.7.0", "", { "dependencies": { "whatwg-url": "^5.0.0" }, "peerDependencies": { "encoding": "^0.1.0" }, "optionalPeers": ["encoding"] }, "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A=="],

View File

@@ -66,7 +66,7 @@
"drizzle-kit": "^0.31.4",
"husky": "9.1.7",
"lint-staged": "16.0.0",
"turbo": "2.7.0"
"turbo": "2.7.1"
},
"lint-staged": {
"*.{js,jsx,ts,tsx,json,css,scss}": [

View File

@@ -0,0 +1,32 @@
-- Step 1: Add workspace_id as nullable first
ALTER TABLE "memory" ADD COLUMN "workspace_id" text;
-- Step 2: Backfill workspace_id from workflow's workspace_id
UPDATE memory m
SET workspace_id = w.workspace_id
FROM workflow w
WHERE m.workflow_id = w.id
AND w.workspace_id IS NOT NULL;
-- Step 3: Delete rows where workspace_id couldn't be resolved
DELETE FROM memory WHERE workspace_id IS NULL;
-- Step 4: Now make workspace_id NOT NULL
ALTER TABLE "memory" ALTER COLUMN "workspace_id" SET NOT NULL;
-- Step 5: Drop old constraint and indexes
ALTER TABLE "memory" DROP CONSTRAINT IF EXISTS "memory_workflow_id_workflow_id_fk";
--> statement-breakpoint
DROP INDEX IF EXISTS "memory_workflow_idx";
--> statement-breakpoint
DROP INDEX IF EXISTS "memory_workflow_key_idx";
-- Step 6: Add new foreign key and indexes
ALTER TABLE "memory" ADD CONSTRAINT "memory_workspace_id_workspace_id_fk" FOREIGN KEY ("workspace_id") REFERENCES "public"."workspace"("id") ON DELETE cascade ON UPDATE no action;
--> statement-breakpoint
CREATE INDEX "memory_workspace_idx" ON "memory" USING btree ("workspace_id");
--> statement-breakpoint
CREATE UNIQUE INDEX "memory_workspace_key_idx" ON "memory" USING btree ("workspace_id","key");
-- Step 7: Drop old column
ALTER TABLE "memory" DROP COLUMN IF EXISTS "workflow_id";

File diff suppressed because it is too large

View File

@@ -904,6 +904,13 @@
"when": 1766275541149,
"tag": "0129_stormy_nightmare",
"breakpoints": true
},
{
"idx": 130,
"version": "7",
"when": 1766433914366,
"tag": "0130_bored_master_chief",
"breakpoints": true
}
]
}

View File

@@ -962,24 +962,21 @@ export const memory = pgTable(
'memory',
{
id: text('id').primaryKey(),
workflowId: text('workflow_id').references(() => workflow.id, { onDelete: 'cascade' }),
key: text('key').notNull(), // Conversation ID provided by user with format: conversationId:blockId
data: jsonb('data').notNull(), // Stores agent messages as array of {role, content} objects
workspaceId: text('workspace_id')
.notNull()
.references(() => workspace.id, { onDelete: 'cascade' }),
key: text('key').notNull(),
data: jsonb('data').notNull(),
createdAt: timestamp('created_at').notNull().defaultNow(),
updatedAt: timestamp('updated_at').notNull().defaultNow(),
deletedAt: timestamp('deleted_at'),
},
(table) => {
return {
// Add index on key for faster lookups
keyIdx: index('memory_key_idx').on(table.key),
// Add index on workflowId for faster filtering
workflowIdx: index('memory_workflow_idx').on(table.workflowId),
// Compound unique index to ensure keys are unique per workflow
uniqueKeyPerWorkflowIdx: uniqueIndex('memory_workflow_key_idx').on(
table.workflowId,
workspaceIdx: index('memory_workspace_idx').on(table.workspaceId),
uniqueKeyPerWorkspaceIdx: uniqueIndex('memory_workspace_key_idx').on(
table.workspaceId,
table.key
),
}
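The compound unique index on (`workspace_id`, `key`) is what lets an add-or-append write be expressed as a single upsert. A rough drizzle-orm sketch of that pattern follows — the `db` instance, the `./schema` import path, the id generation, and the JSONB-append merge strategy are all assumptions for illustration, not code from this PR.

```ts
// Rough upsert sketch leaning on the (workspace_id, key) unique index.
// The db instance, './schema' path, and merge strategy are assumptions.
import { sql } from 'drizzle-orm'
import { memory } from './schema'

async function addOrAppendMemory(
  db: any, // a drizzle Postgres database instance is assumed
  workspaceId: string,
  conversationId: string,
  messages: Array<{ role: 'user' | 'assistant'; content: string }>
) {
  await db
    .insert(memory)
    .values({
      id: crypto.randomUUID(),
      workspaceId,
      key: conversationId, // key is now just the conversationId
      data: messages,
    })
    .onConflictDoUpdate({
      target: [memory.workspaceId, memory.key],
      set: {
        // Append the new messages to the stored JSONB array
        data: sql`${memory.data} || ${JSON.stringify(messages)}::jsonb`,
        updatedAt: new Date(),
      },
    })
}
```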