Mirror of https://github.com/simstudioai/sim.git (synced 2026-01-09 15:07:55 -05:00)

fix(models): memory fixes, provider code typing, cost calculation cleanup (#2515)

* improvement(memory): should not be block scoped
* cleanup provider code
* update other providers
* cleanup fallback code
* remove flaky test
* fix memory
* move streaming fix to right level
* cleanup streaming server
* make memories workspace scoped
* update docs
* fix dedup logic
* fix streaming parsing issue for multiple onStream calls for same block
* fix(providers): support parallel agent tool calls, consolidate utils
* address greptile comments
* remove all comments
* fixed openrouter response format handling, groq & cerebras response formats
* removed duplicate type

---------

Co-authored-by: waleed <walif6@gmail.com>

Committed by GitHub
parent 086982c7a3
commit 8c2c49eb14
@@ -1,6 +1,6 @@
---
title: Memory
description: Add memory store
description: Store and retrieve conversation history
---

import { BlockInfoCard } from "@/components/ui/block-info-card"

@@ -10,100 +10,94 @@ import { BlockInfoCard } from "@/components/ui/block-info-card"
color="#F64F9E"
/>

## Usage Instructions

Integrate Memory into the workflow. Can add, get a memory, get all memories, and delete memories.
## Overview

The Memory block stores conversation history for agents. Each memory is identified by a `conversationId` that you provide. Multiple agents can share the same memory by using the same `conversationId`.

Memory stores only user and assistant messages. System messages are not stored—they are configured in the Agent block and prefixed at runtime.

## Tools

### `memory_add`

Add a new memory to the database or append to existing memory with the same ID.
Add a message to memory. Creates a new memory if the `conversationId` doesn't exist, or appends to existing memory.

#### Input

| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `conversationId` | string | No | Conversation identifier \(e.g., user-123, session-abc\). If a memory with this conversationId already exists for this block, the new message will be appended to it. |
| `id` | string | No | Legacy parameter for conversation identifier. Use conversationId instead. Provided for backwards compatibility. |
| `role` | string | Yes | Role for agent memory \(user, assistant, or system\) |
| `content` | string | Yes | Content for agent memory |
| `blockId` | string | No | Optional block ID. If not provided, uses the current block ID from execution context, or defaults to "default". |
| `conversationId` | string | Yes | Unique identifier for the conversation (e.g., `user-123`, `session-abc`) |
| `role` | string | Yes | Message role: `user` or `assistant` |
| `content` | string | Yes | Message content |

#### Output

| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `success` | boolean | Whether the memory was added successfully |
| `memories` | array | Array of memory objects including the new or updated memory |
| `error` | string | Error message if operation failed |
| `success` | boolean | Whether the operation succeeded |
| `memories` | array | Updated memory array |
| `error` | string | Error message if failed |

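A minimal sketch of what adding a message looks like over HTTP against the workspace-scoped API routes later in this diff. The route path and host are assumptions; the request body (`key`, `data`, `workspaceId`) and the response shape follow the POST handler in this change.

```ts
// Sketch only: the /api/memory path is an assumption.
// Body and response shapes mirror the POST handler later in this diff.
async function addMemory(
  workspaceId: string,
  conversationId: string,
  role: 'user' | 'assistant',
  content: string
) {
  const res = await fetch('/api/memory', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      key: conversationId,     // memories are keyed by conversationId within a workspace
      data: { role, content }, // the handler also accepts an array of messages
      workspaceId,             // workspace scope replaces the old per-workflow scope
    }),
  })
  return (await res.json()) as {
    success: boolean
    data?: { conversationId: string; data: unknown }
    error?: { message: string }
  }
}
```
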
### `memory_get`

Retrieve memory by conversationId, blockId, blockName, or a combination. Returns all matching memories.
Retrieve memory by conversation ID.

#### Input

| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `conversationId` | string | No | Conversation identifier \(e.g., user-123, session-abc\). If provided alone, returns all memories for this conversation across all blocks. |
| `id` | string | No | Legacy parameter for conversation identifier. Use conversationId instead. Provided for backwards compatibility. |
| `blockId` | string | No | Block identifier. If provided alone, returns all memories for this block across all conversations. If provided with conversationId, returns memories for that specific conversation in this block. |
| `blockName` | string | No | Block name. Alternative to blockId. If provided alone, returns all memories for blocks with this name. If provided with conversationId, returns memories for that conversation in blocks with this name. |
| `conversationId` | string | Yes | Conversation identifier |

#### Output

| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `success` | boolean | Whether the memory was retrieved successfully |
| `memories` | array | Array of memory objects with conversationId, blockId, blockName, and data fields |
| `message` | string | Success or error message |
| `error` | string | Error message if operation failed |
| `success` | boolean | Whether the operation succeeded |
| `memories` | array | Array of messages with `role` and `content` |
| `error` | string | Error message if failed |

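A companion sketch for retrieval, based on the GET-by-id handler later in this diff; the route path is an assumption, while the `workspaceId` query parameter and response shape come from that handler.

```ts
// Sketch only: the /api/memory/{id} path is an assumption; the query parameter and
// response shape follow the GET-by-id handler later in this diff.
async function getMemory(workspaceId: string, conversationId: string) {
  const url = `/api/memory/${encodeURIComponent(conversationId)}?workspaceId=${workspaceId}`
  const res = await fetch(url)
  if (res.status === 404) return null // no memory with this conversationId in the workspace
  const body = (await res.json()) as {
    success: boolean
    data?: { conversationId: string; data: Array<{ role: string; content: string }> }
    error?: { message: string }
  }
  return body.data ?? null
}
```
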
### `memory_get_all`

Retrieve all memories from the database

#### Input

| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
Retrieve all memories for the current workspace.

#### Output

| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `success` | boolean | Whether all memories were retrieved successfully |
| `memories` | array | Array of all memory objects with key, conversationId, blockId, blockName, and data fields |
| `message` | string | Success or error message |
| `error` | string | Error message if operation failed |
| `success` | boolean | Whether the operation succeeded |
| `memories` | array | All memory objects with `conversationId` and `data` fields |
| `error` | string | Error message if failed |

### `memory_delete`

Delete memories by conversationId, blockId, blockName, or a combination. Supports bulk deletion.
Delete memory by conversation ID.

#### Input

| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `conversationId` | string | No | Conversation identifier \(e.g., user-123, session-abc\). If provided alone, deletes all memories for this conversation across all blocks. |
| `id` | string | No | Legacy parameter for conversation identifier. Use conversationId instead. Provided for backwards compatibility. |
| `blockId` | string | No | Block identifier. If provided alone, deletes all memories for this block across all conversations. If provided with conversationId, deletes memories for that specific conversation in this block. |
| `blockName` | string | No | Block name. Alternative to blockId. If provided alone, deletes all memories for blocks with this name. If provided with conversationId, deletes memories for that conversation in blocks with this name. |
| `conversationId` | string | Yes | Conversation identifier to delete |

#### Output

| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `success` | boolean | Whether the memory was deleted successfully |
| `message` | string | Success or error message |
| `error` | string | Error message if operation failed |
| `success` | boolean | Whether the operation succeeded |
| `message` | string | Confirmation message |
| `error` | string | Error message if failed |

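And for deletion, a sketch based on the workspace-scoped DELETE handler later in this diff; the path is an assumption and the exact response payload may differ.

```ts
// Sketch only: the /api/memory path is an assumption; the required workspaceId and
// conversationId query parameters follow the DELETE handler later in this diff.
async function deleteMemory(workspaceId: string, conversationId: string) {
  const params = new URLSearchParams({ workspaceId, conversationId })
  const res = await fetch(`/api/memory?${params.toString()}`, { method: 'DELETE' })
  // The handler logs the number of deleted rows; the exact response fields are not
  // fully shown in this diff, so treat the payload loosely here.
  return (await res.json()) as { success: boolean; error?: { message: string } }
}
```
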
## Agent Memory Types

When using memory with an Agent block, you can configure how conversation history is managed:

| Type | Description |
| ---- | ----------- |
| **Full Conversation** | Stores all messages, limited by model's context window (uses 90% to leave room for response) |
| **Sliding Window (Messages)** | Keeps the last N messages (default: 10) |
| **Sliding Window (Tokens)** | Keeps messages that fit within a token limit (default: 4000) |

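The two sliding-window strategies in the table are easiest to see in code. This is an illustration only, not the Agent block's implementation; `estimateTokens` is a stand-in for a real tokenizer.

```ts
// Illustration only: shows the sliding-window strategies described above.
interface Message { role: 'user' | 'assistant'; content: string }

// Sliding Window (Messages): keep the last N messages.
function lastNMessages(history: Message[], n = 10): Message[] {
  return history.slice(-n)
}

// Sliding Window (Tokens): keep the newest messages that fit within a token budget.
function withinTokenLimit(history: Message[], maxTokens = 4000): Message[] {
  const estimateTokens = (m: Message) => Math.ceil(m.content.length / 4) // rough heuristic
  const kept: Message[] = []
  let total = 0
  // Walk from the newest message backwards, keeping messages until the budget is spent.
  for (let i = history.length - 1; i >= 0; i--) {
    const cost = estimateTokens(history[i])
    if (total + cost > maxTokens) break
    kept.unshift(history[i])
    total += cost
  }
  return kept
}
```
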
## Notes

- Category: `blocks`
- Type: `memory`
- Memory is scoped per workspace—workflows in the same workspace share the memory store
- Use unique `conversationId` values to keep conversations separate (e.g., session IDs, user IDs, or UUIDs)
- System messages belong in the Agent block configuration, not in memory

@@ -70,19 +70,6 @@ vi.mock('@/lib/core/utils/request', () => ({
|
||||
generateRequestId: vi.fn().mockReturnValue('test-request-id'),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/api/workflows/[id]/execute/route', () => ({
|
||||
createFilteredResult: vi.fn().mockImplementation((result: any) => ({
|
||||
...result,
|
||||
logs: undefined,
|
||||
metadata: result.metadata
|
||||
? {
|
||||
...result.metadata,
|
||||
workflowConnections: undefined,
|
||||
}
|
||||
: undefined,
|
||||
})),
|
||||
}))
|
||||
|
||||
describe('Chat Identifier API Route', () => {
|
||||
const mockAddCorsHeaders = vi.fn().mockImplementation((response) => response)
|
||||
const mockValidateChatAuth = vi.fn().mockResolvedValue({ authorized: true })
|
||||
|
||||
@@ -206,7 +206,6 @@ export async function POST(
|
||||
|
||||
const { createStreamingResponse } = await import('@/lib/workflows/streaming/streaming')
|
||||
const { SSE_HEADERS } = await import('@/lib/core/utils/sse')
|
||||
const { createFilteredResult } = await import('@/app/api/workflows/[id]/execute/route')
|
||||
|
||||
const workflowInput: any = { input, conversationId }
|
||||
if (files && Array.isArray(files) && files.length > 0) {
|
||||
@@ -267,7 +266,6 @@ export async function POST(
|
||||
isSecureMode: true,
|
||||
workflowTriggerType: 'chat',
|
||||
},
|
||||
createFilteredResult,
|
||||
executionId,
|
||||
})
|
||||
|
||||
|
||||
@@ -1,54 +1,16 @@
|
||||
import { db } from '@sim/db'
|
||||
import { memory, workflowBlocks } from '@sim/db/schema'
|
||||
import { memory, permissions, workspace } from '@sim/db/schema'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { checkHybridAuth } from '@/lib/auth/hybrid'
|
||||
import { generateRequestId } from '@/lib/core/utils/request'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getWorkflowAccessContext } from '@/lib/workflows/utils'
|
||||
|
||||
const logger = createLogger('MemoryByIdAPI')
|
||||
|
||||
/**
|
||||
* Parse memory key into conversationId and blockId
|
||||
* Key format: conversationId:blockId
|
||||
*/
|
||||
function parseMemoryKey(key: string): { conversationId: string; blockId: string } | null {
|
||||
const parts = key.split(':')
|
||||
if (parts.length !== 2) {
|
||||
return null
|
||||
}
|
||||
return {
|
||||
conversationId: parts[0],
|
||||
blockId: parts[1],
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Lookup block name from block ID
|
||||
*/
|
||||
async function getBlockName(blockId: string, workflowId: string): Promise<string | undefined> {
|
||||
try {
|
||||
const result = await db
|
||||
.select({ name: workflowBlocks.name })
|
||||
.from(workflowBlocks)
|
||||
.where(and(eq(workflowBlocks.id, blockId), eq(workflowBlocks.workflowId, workflowId)))
|
||||
.limit(1)
|
||||
|
||||
if (result.length === 0) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return result[0].name
|
||||
} catch (error) {
|
||||
logger.error('Error looking up block name', { error, blockId, workflowId })
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
const memoryQuerySchema = z.object({
|
||||
workflowId: z.string().uuid('Invalid workflow ID format'),
|
||||
workspaceId: z.string().uuid('Invalid workspace ID format'),
|
||||
})
|
||||
|
||||
const agentMemoryDataSchema = z.object({
|
||||
@@ -64,26 +26,56 @@ const memoryPutBodySchema = z.object({
|
||||
data: z.union([agentMemoryDataSchema, genericMemoryDataSchema], {
|
||||
errorMap: () => ({ message: 'Invalid memory data structure' }),
|
||||
}),
|
||||
workflowId: z.string().uuid('Invalid workflow ID format'),
|
||||
workspaceId: z.string().uuid('Invalid workspace ID format'),
|
||||
})
|
||||
|
||||
/**
|
||||
* Validates authentication and workflow access for memory operations
|
||||
* @param request - The incoming request
|
||||
* @param workflowId - The workflow ID to check access for
|
||||
* @param requestId - Request ID for logging
|
||||
* @param action - 'read' for GET, 'write' for PUT/DELETE
|
||||
* @returns Object with userId if successful, or error response if failed
|
||||
*/
|
||||
async function checkWorkspaceAccess(
|
||||
workspaceId: string,
|
||||
userId: string
|
||||
): Promise<{ hasAccess: boolean; canWrite: boolean }> {
|
||||
const [workspaceRow] = await db
|
||||
.select({ ownerId: workspace.ownerId })
|
||||
.from(workspace)
|
||||
.where(eq(workspace.id, workspaceId))
|
||||
.limit(1)
|
||||
|
||||
if (!workspaceRow) {
|
||||
return { hasAccess: false, canWrite: false }
|
||||
}
|
||||
|
||||
if (workspaceRow.ownerId === userId) {
|
||||
return { hasAccess: true, canWrite: true }
|
||||
}
|
||||
|
||||
const [permissionRow] = await db
|
||||
.select({ permissionType: permissions.permissionType })
|
||||
.from(permissions)
|
||||
.where(
|
||||
and(
|
||||
eq(permissions.userId, userId),
|
||||
eq(permissions.entityType, 'workspace'),
|
||||
eq(permissions.entityId, workspaceId)
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (!permissionRow) {
|
||||
return { hasAccess: false, canWrite: false }
|
||||
}
|
||||
|
||||
return {
|
||||
hasAccess: true,
|
||||
canWrite: permissionRow.permissionType === 'write' || permissionRow.permissionType === 'admin',
|
||||
}
|
||||
}
|
||||
|
||||
async function validateMemoryAccess(
|
||||
request: NextRequest,
|
||||
workflowId: string,
|
||||
workspaceId: string,
|
||||
requestId: string,
|
||||
action: 'read' | 'write'
|
||||
): Promise<{ userId: string } | { error: NextResponse }> {
|
||||
const authResult = await checkHybridAuth(request, {
|
||||
requireWorkflowId: false,
|
||||
})
|
||||
const authResult = await checkHybridAuth(request, { requireWorkflowId: false })
|
||||
if (!authResult.success || !authResult.userId) {
|
||||
logger.warn(`[${requestId}] Unauthorized memory ${action} attempt`)
|
||||
return {
|
||||
@@ -94,30 +86,20 @@ async function validateMemoryAccess(
|
||||
}
|
||||
}
|
||||
|
||||
const accessContext = await getWorkflowAccessContext(workflowId, authResult.userId)
|
||||
if (!accessContext) {
|
||||
logger.warn(`[${requestId}] Workflow ${workflowId} not found`)
|
||||
const { hasAccess, canWrite } = await checkWorkspaceAccess(workspaceId, authResult.userId)
|
||||
if (!hasAccess) {
|
||||
return {
|
||||
error: NextResponse.json(
|
||||
{ success: false, error: { message: 'Workflow not found' } },
|
||||
{ success: false, error: { message: 'Workspace not found' } },
|
||||
{ status: 404 }
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
const { isOwner, workspacePermission } = accessContext
|
||||
const hasAccess =
|
||||
action === 'read'
|
||||
? isOwner || workspacePermission !== null
|
||||
: isOwner || workspacePermission === 'write' || workspacePermission === 'admin'
|
||||
|
||||
if (!hasAccess) {
|
||||
logger.warn(
|
||||
`[${requestId}] User ${authResult.userId} denied ${action} access to workflow ${workflowId}`
|
||||
)
|
||||
if (action === 'write' && !canWrite) {
|
||||
return {
|
||||
error: NextResponse.json(
|
||||
{ success: false, error: { message: 'Access denied' } },
|
||||
{ success: false, error: { message: 'Write access denied' } },
|
||||
{ status: 403 }
|
||||
),
|
||||
}
|
||||
@@ -129,40 +111,28 @@ async function validateMemoryAccess(
|
||||
export const dynamic = 'force-dynamic'
|
||||
export const runtime = 'nodejs'
|
||||
|
||||
/**
|
||||
* GET handler for retrieving a specific memory by ID
|
||||
*/
|
||||
export async function GET(request: NextRequest, { params }: { params: Promise<{ id: string }> }) {
|
||||
const requestId = generateRequestId()
|
||||
const { id } = await params
|
||||
|
||||
try {
|
||||
logger.info(`[${requestId}] Processing memory get request for ID: ${id}`)
|
||||
|
||||
const url = new URL(request.url)
|
||||
const workflowId = url.searchParams.get('workflowId')
|
||||
|
||||
const validation = memoryQuerySchema.safeParse({ workflowId })
|
||||
const workspaceId = url.searchParams.get('workspaceId')
|
||||
|
||||
const validation = memoryQuerySchema.safeParse({ workspaceId })
|
||||
if (!validation.success) {
|
||||
const errorMessage = validation.error.errors
|
||||
.map((err) => `${err.path.join('.')}: ${err.message}`)
|
||||
.join(', ')
|
||||
logger.warn(`[${requestId}] Validation error: ${errorMessage}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: errorMessage,
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: errorMessage } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const { workflowId: validatedWorkflowId } = validation.data
|
||||
const { workspaceId: validatedWorkspaceId } = validation.data
|
||||
|
||||
const accessCheck = await validateMemoryAccess(request, validatedWorkflowId, requestId, 'read')
|
||||
const accessCheck = await validateMemoryAccess(request, validatedWorkspaceId, requestId, 'read')
|
||||
if ('error' in accessCheck) {
|
||||
return accessCheck.error
|
||||
}
|
||||
@@ -170,72 +140,33 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
|
||||
const memories = await db
|
||||
.select()
|
||||
.from(memory)
|
||||
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
|
||||
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
|
||||
.orderBy(memory.createdAt)
|
||||
.limit(1)
|
||||
|
||||
if (memories.length === 0) {
|
||||
logger.warn(`[${requestId}] Memory not found: ${id} for workflow: ${validatedWorkflowId}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Memory not found',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Memory not found' } },
|
||||
{ status: 404 }
|
||||
)
|
||||
}
|
||||
|
||||
const mem = memories[0]
|
||||
const parsed = parseMemoryKey(mem.key)
|
||||
|
||||
let enrichedMemory
|
||||
if (!parsed) {
|
||||
enrichedMemory = {
|
||||
conversationId: mem.key,
|
||||
blockId: 'unknown',
|
||||
blockName: 'unknown',
|
||||
data: mem.data,
|
||||
}
|
||||
} else {
|
||||
const { conversationId, blockId } = parsed
|
||||
const blockName = (await getBlockName(blockId, validatedWorkflowId)) || 'unknown'
|
||||
|
||||
enrichedMemory = {
|
||||
conversationId,
|
||||
blockId,
|
||||
blockName,
|
||||
data: mem.data,
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Memory retrieved successfully: ${id} for workflow: ${validatedWorkflowId}`
|
||||
)
|
||||
logger.info(`[${requestId}] Memory retrieved: ${id} for workspace: ${validatedWorkspaceId}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
data: enrichedMemory,
|
||||
},
|
||||
{ success: true, data: { conversationId: mem.key, data: mem.data } },
|
||||
{ status: 200 }
|
||||
)
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Error retrieving memory`, { error })
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: error.message || 'Failed to retrieve memory',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: error.message || 'Failed to retrieve memory' } },
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* DELETE handler for removing a specific memory
|
||||
*/
|
||||
export async function DELETE(
|
||||
request: NextRequest,
|
||||
{ params }: { params: Promise<{ id: string }> }
|
||||
@@ -244,32 +175,28 @@ export async function DELETE(
|
||||
const { id } = await params
|
||||
|
||||
try {
|
||||
logger.info(`[${requestId}] Processing memory delete request for ID: ${id}`)
|
||||
|
||||
const url = new URL(request.url)
|
||||
const workflowId = url.searchParams.get('workflowId')
|
||||
|
||||
const validation = memoryQuerySchema.safeParse({ workflowId })
|
||||
const workspaceId = url.searchParams.get('workspaceId')
|
||||
|
||||
const validation = memoryQuerySchema.safeParse({ workspaceId })
|
||||
if (!validation.success) {
|
||||
const errorMessage = validation.error.errors
|
||||
.map((err) => `${err.path.join('.')}: ${err.message}`)
|
||||
.join(', ')
|
||||
logger.warn(`[${requestId}] Validation error: ${errorMessage}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: errorMessage,
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: errorMessage } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const { workflowId: validatedWorkflowId } = validation.data
|
||||
const { workspaceId: validatedWorkspaceId } = validation.data
|
||||
|
||||
const accessCheck = await validateMemoryAccess(request, validatedWorkflowId, requestId, 'write')
|
||||
const accessCheck = await validateMemoryAccess(
|
||||
request,
|
||||
validatedWorkspaceId,
|
||||
requestId,
|
||||
'write'
|
||||
)
|
||||
if ('error' in accessCheck) {
|
||||
return accessCheck.error
|
||||
}
|
||||
@@ -277,61 +204,41 @@ export async function DELETE(
|
||||
const existingMemory = await db
|
||||
.select({ id: memory.id })
|
||||
.from(memory)
|
||||
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
|
||||
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
|
||||
.limit(1)
|
||||
|
||||
if (existingMemory.length === 0) {
|
||||
logger.warn(`[${requestId}] Memory not found: ${id} for workflow: ${validatedWorkflowId}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Memory not found',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Memory not found' } },
|
||||
{ status: 404 }
|
||||
)
|
||||
}
|
||||
|
||||
await db
|
||||
.delete(memory)
|
||||
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
|
||||
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Memory deleted successfully: ${id} for workflow: ${validatedWorkflowId}`
|
||||
)
|
||||
logger.info(`[${requestId}] Memory deleted: ${id} for workspace: ${validatedWorkspaceId}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
data: { message: 'Memory deleted successfully' },
|
||||
},
|
||||
{ success: true, data: { message: 'Memory deleted successfully' } },
|
||||
{ status: 200 }
|
||||
)
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Error deleting memory`, { error })
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: error.message || 'Failed to delete memory',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: error.message || 'Failed to delete memory' } },
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* PUT handler for updating a specific memory
|
||||
*/
|
||||
export async function PUT(request: NextRequest, { params }: { params: Promise<{ id: string }> }) {
|
||||
const requestId = generateRequestId()
|
||||
const { id } = await params
|
||||
|
||||
try {
|
||||
logger.info(`[${requestId}] Processing memory update request for ID: ${id}`)
|
||||
|
||||
let validatedData
|
||||
let validatedWorkflowId
|
||||
let validatedWorkspaceId
|
||||
try {
|
||||
const body = await request.json()
|
||||
const validation = memoryPutBodySchema.safeParse(body)
|
||||
@@ -340,34 +247,27 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{
|
||||
const errorMessage = validation.error.errors
|
||||
.map((err) => `${err.path.join('.')}: ${err.message}`)
|
||||
.join(', ')
|
||||
logger.warn(`[${requestId}] Validation error: ${errorMessage}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: `Invalid request body: ${errorMessage}`,
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: `Invalid request body: ${errorMessage}` } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
validatedData = validation.data.data
|
||||
validatedWorkflowId = validation.data.workflowId
|
||||
} catch (error: any) {
|
||||
logger.warn(`[${requestId}] Failed to parse request body: ${error.message}`)
|
||||
validatedWorkspaceId = validation.data.workspaceId
|
||||
} catch {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Invalid JSON in request body',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Invalid JSON in request body' } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const accessCheck = await validateMemoryAccess(request, validatedWorkflowId, requestId, 'write')
|
||||
const accessCheck = await validateMemoryAccess(
|
||||
request,
|
||||
validatedWorkspaceId,
|
||||
requestId,
|
||||
'write'
|
||||
)
|
||||
if ('error' in accessCheck) {
|
||||
return accessCheck.error
|
||||
}
|
||||
@@ -375,18 +275,12 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{
|
||||
const existingMemories = await db
|
||||
.select()
|
||||
.from(memory)
|
||||
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
|
||||
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
|
||||
.limit(1)
|
||||
|
||||
if (existingMemories.length === 0) {
|
||||
logger.warn(`[${requestId}] Memory not found: ${id} for workflow: ${validatedWorkflowId}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Memory not found',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Memory not found' } },
|
||||
{ status: 404 }
|
||||
)
|
||||
}
|
||||
@@ -396,14 +290,8 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{
|
||||
const errorMessage = agentValidation.error.errors
|
||||
.map((err) => `${err.path.join('.')}: ${err.message}`)
|
||||
.join(', ')
|
||||
logger.warn(`[${requestId}] Agent memory validation error: ${errorMessage}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: `Invalid agent memory data: ${errorMessage}`,
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: `Invalid agent memory data: ${errorMessage}` } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
@@ -411,59 +299,26 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{
|
||||
const now = new Date()
|
||||
await db
|
||||
.update(memory)
|
||||
.set({
|
||||
data: validatedData,
|
||||
updatedAt: now,
|
||||
})
|
||||
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
|
||||
.set({ data: validatedData, updatedAt: now })
|
||||
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
|
||||
|
||||
const updatedMemories = await db
|
||||
.select()
|
||||
.from(memory)
|
||||
.where(and(eq(memory.key, id), eq(memory.workflowId, validatedWorkflowId)))
|
||||
.where(and(eq(memory.key, id), eq(memory.workspaceId, validatedWorkspaceId)))
|
||||
.limit(1)
|
||||
|
||||
const mem = updatedMemories[0]
|
||||
const parsed = parseMemoryKey(mem.key)
|
||||
|
||||
let enrichedMemory
|
||||
if (!parsed) {
|
||||
enrichedMemory = {
|
||||
conversationId: mem.key,
|
||||
blockId: 'unknown',
|
||||
blockName: 'unknown',
|
||||
data: mem.data,
|
||||
}
|
||||
} else {
|
||||
const { conversationId, blockId } = parsed
|
||||
const blockName = (await getBlockName(blockId, validatedWorkflowId)) || 'unknown'
|
||||
|
||||
enrichedMemory = {
|
||||
conversationId,
|
||||
blockId,
|
||||
blockName,
|
||||
data: mem.data,
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Memory updated successfully: ${id} for workflow: ${validatedWorkflowId}`
|
||||
)
|
||||
logger.info(`[${requestId}] Memory updated: ${id} for workspace: ${validatedWorkspaceId}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
data: enrichedMemory,
|
||||
},
|
||||
{ success: true, data: { conversationId: mem.key, data: mem.data } },
|
||||
{ status: 200 }
|
||||
)
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Error updating memory`, { error })
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: error.message || 'Failed to update memory',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: error.message || 'Failed to update memory' } },
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,42 +1,56 @@
|
||||
import { db } from '@sim/db'
|
||||
import { memory, workflowBlocks } from '@sim/db/schema'
|
||||
import { and, eq, inArray, isNull, like } from 'drizzle-orm'
|
||||
import { memory, permissions, workspace } from '@sim/db/schema'
|
||||
import { and, eq, isNull, like } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { checkHybridAuth } from '@/lib/auth/hybrid'
|
||||
import { generateRequestId } from '@/lib/core/utils/request'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getWorkflowAccessContext } from '@/lib/workflows/utils'
|
||||
|
||||
const logger = createLogger('MemoryAPI')
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
export const runtime = 'nodejs'
|
||||
|
||||
/**
|
||||
* Parse memory key into conversationId and blockId
|
||||
* Key format: conversationId:blockId
|
||||
* @param key The memory key to parse
|
||||
* @returns Object with conversationId and blockId, or null if invalid
|
||||
*/
|
||||
function parseMemoryKey(key: string): { conversationId: string; blockId: string } | null {
|
||||
const parts = key.split(':')
|
||||
if (parts.length !== 2) {
|
||||
return null
|
||||
async function checkWorkspaceAccess(
|
||||
workspaceId: string,
|
||||
userId: string
|
||||
): Promise<{ hasAccess: boolean; canWrite: boolean }> {
|
||||
const [workspaceRow] = await db
|
||||
.select({ ownerId: workspace.ownerId })
|
||||
.from(workspace)
|
||||
.where(eq(workspace.id, workspaceId))
|
||||
.limit(1)
|
||||
|
||||
if (!workspaceRow) {
|
||||
return { hasAccess: false, canWrite: false }
|
||||
}
|
||||
|
||||
if (workspaceRow.ownerId === userId) {
|
||||
return { hasAccess: true, canWrite: true }
|
||||
}
|
||||
|
||||
const [permissionRow] = await db
|
||||
.select({ permissionType: permissions.permissionType })
|
||||
.from(permissions)
|
||||
.where(
|
||||
and(
|
||||
eq(permissions.userId, userId),
|
||||
eq(permissions.entityType, 'workspace'),
|
||||
eq(permissions.entityId, workspaceId)
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (!permissionRow) {
|
||||
return { hasAccess: false, canWrite: false }
|
||||
}
|
||||
|
||||
return {
|
||||
conversationId: parts[0],
|
||||
blockId: parts[1],
|
||||
hasAccess: true,
|
||||
canWrite: permissionRow.permissionType === 'write' || permissionRow.permissionType === 'admin',
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* GET handler for searching and retrieving memories
|
||||
* Supports query parameters:
|
||||
* - query: Search string for memory keys
|
||||
* - type: Filter by memory type
|
||||
* - limit: Maximum number of results (default: 50)
|
||||
* - workflowId: Filter by workflow ID (required)
|
||||
*/
|
||||
export async function GET(request: NextRequest) {
|
||||
const requestId = generateRequestId()
|
||||
|
||||
@@ -45,102 +59,32 @@ export async function GET(request: NextRequest) {
|
||||
if (!authResult.success || !authResult.userId) {
|
||||
logger.warn(`[${requestId}] Unauthorized memory access attempt`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: authResult.error || 'Authentication required',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: authResult.error || 'Authentication required' } },
|
||||
{ status: 401 }
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Processing memory search request`)
|
||||
|
||||
const url = new URL(request.url)
|
||||
const workflowId = url.searchParams.get('workflowId')
|
||||
const workspaceId = url.searchParams.get('workspaceId')
|
||||
const searchQuery = url.searchParams.get('query')
|
||||
const blockNameFilter = url.searchParams.get('blockName')
|
||||
const limit = Number.parseInt(url.searchParams.get('limit') || '50')
|
||||
|
||||
if (!workflowId) {
|
||||
logger.warn(`[${requestId}] Missing required parameter: workflowId`)
|
||||
if (!workspaceId) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'workflowId parameter is required',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'workspaceId parameter is required' } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const accessContext = await getWorkflowAccessContext(workflowId, authResult.userId)
|
||||
if (!accessContext) {
|
||||
logger.warn(`[${requestId}] Workflow ${workflowId} not found for user ${authResult.userId}`)
|
||||
const { hasAccess } = await checkWorkspaceAccess(workspaceId, authResult.userId)
|
||||
if (!hasAccess) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Workflow not found',
|
||||
},
|
||||
},
|
||||
{ status: 404 }
|
||||
)
|
||||
}
|
||||
|
||||
const { workspacePermission, isOwner } = accessContext
|
||||
|
||||
if (!isOwner && !workspacePermission) {
|
||||
logger.warn(
|
||||
`[${requestId}] User ${authResult.userId} denied access to workflow ${workflowId}`
|
||||
)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Access denied to this workflow',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Access denied to this workspace' } },
|
||||
{ status: 403 }
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] User ${authResult.userId} (${authResult.authType}) accessing memories for workflow ${workflowId}`
|
||||
)
|
||||
|
||||
const conditions = []
|
||||
|
||||
conditions.push(isNull(memory.deletedAt))
|
||||
|
||||
conditions.push(eq(memory.workflowId, workflowId))
|
||||
|
||||
let blockIdsToFilter: string[] | null = null
|
||||
if (blockNameFilter) {
|
||||
const blocks = await db
|
||||
.select({ id: workflowBlocks.id })
|
||||
.from(workflowBlocks)
|
||||
.where(
|
||||
and(eq(workflowBlocks.workflowId, workflowId), eq(workflowBlocks.name, blockNameFilter))
|
||||
)
|
||||
|
||||
if (blocks.length === 0) {
|
||||
logger.info(
|
||||
`[${requestId}] No blocks found with name "${blockNameFilter}" for workflow: ${workflowId}`
|
||||
)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
data: { memories: [] },
|
||||
},
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
|
||||
blockIdsToFilter = blocks.map((b) => b.id)
|
||||
}
|
||||
const conditions = [isNull(memory.deletedAt), eq(memory.workspaceId, workspaceId)]
|
||||
|
||||
if (searchQuery) {
|
||||
conditions.push(like(memory.key, `%${searchQuery}%`))
|
||||
@@ -153,95 +97,27 @@ export async function GET(request: NextRequest) {
|
||||
.orderBy(memory.createdAt)
|
||||
.limit(limit)
|
||||
|
||||
const filteredMemories = blockIdsToFilter
|
||||
? rawMemories.filter((mem) => {
|
||||
const parsed = parseMemoryKey(mem.key)
|
||||
return parsed && blockIdsToFilter.includes(parsed.blockId)
|
||||
})
|
||||
: rawMemories
|
||||
|
||||
const blockIds = new Set<string>()
|
||||
const parsedKeys = new Map<string, { conversationId: string; blockId: string }>()
|
||||
|
||||
for (const mem of filteredMemories) {
|
||||
const parsed = parseMemoryKey(mem.key)
|
||||
if (parsed) {
|
||||
blockIds.add(parsed.blockId)
|
||||
parsedKeys.set(mem.key, parsed)
|
||||
}
|
||||
}
|
||||
|
||||
const blockNameMap = new Map<string, string>()
|
||||
if (blockIds.size > 0) {
|
||||
const blocks = await db
|
||||
.select({ id: workflowBlocks.id, name: workflowBlocks.name })
|
||||
.from(workflowBlocks)
|
||||
.where(
|
||||
and(
|
||||
eq(workflowBlocks.workflowId, workflowId),
|
||||
inArray(workflowBlocks.id, Array.from(blockIds))
|
||||
)
|
||||
)
|
||||
|
||||
for (const block of blocks) {
|
||||
blockNameMap.set(block.id, block.name)
|
||||
}
|
||||
}
|
||||
|
||||
const enrichedMemories = filteredMemories.map((mem) => {
|
||||
const parsed = parsedKeys.get(mem.key)
|
||||
|
||||
if (!parsed) {
|
||||
return {
|
||||
conversationId: mem.key,
|
||||
blockId: 'unknown',
|
||||
blockName: 'unknown',
|
||||
data: mem.data,
|
||||
}
|
||||
}
|
||||
|
||||
const { conversationId, blockId } = parsed
|
||||
const blockName = blockNameMap.get(blockId) || 'unknown'
|
||||
|
||||
return {
|
||||
conversationId,
|
||||
blockId,
|
||||
blockName,
|
||||
data: mem.data,
|
||||
}
|
||||
})
|
||||
const enrichedMemories = rawMemories.map((mem) => ({
|
||||
conversationId: mem.key,
|
||||
data: mem.data,
|
||||
}))
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Found ${enrichedMemories.length} memories for workflow: ${workflowId}`
|
||||
`[${requestId}] Found ${enrichedMemories.length} memories for workspace: ${workspaceId}`
|
||||
)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
data: { memories: enrichedMemories },
|
||||
},
|
||||
{ success: true, data: { memories: enrichedMemories } },
|
||||
{ status: 200 }
|
||||
)
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Error searching memories`, { error })
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: error.message || 'Failed to search memories',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: error.message || 'Failed to search memories' } },
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* POST handler for creating new memories
|
||||
* Requires:
|
||||
* - key: Unique identifier for the memory (within workflow scope)
|
||||
* - type: Memory type ('agent')
|
||||
* - data: Memory content (agent message with role and content)
|
||||
* - workflowId: ID of the workflow this memory belongs to
|
||||
*/
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = generateRequestId()
|
||||
|
||||
@@ -250,123 +126,63 @@ export async function POST(request: NextRequest) {
|
||||
if (!authResult.success || !authResult.userId) {
|
||||
logger.warn(`[${requestId}] Unauthorized memory creation attempt`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: authResult.error || 'Authentication required',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: authResult.error || 'Authentication required' } },
|
||||
{ status: 401 }
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Processing memory creation request`)
|
||||
|
||||
const body = await request.json()
|
||||
const { key, data, workflowId } = body
|
||||
const { key, data, workspaceId } = body
|
||||
|
||||
if (!key) {
|
||||
logger.warn(`[${requestId}] Missing required field: key`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Memory key is required',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Memory key is required' } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (!data) {
|
||||
logger.warn(`[${requestId}] Missing required field: data`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Memory data is required',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Memory data is required' } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (!workflowId) {
|
||||
logger.warn(`[${requestId}] Missing required field: workflowId`)
|
||||
if (!workspaceId) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'workflowId is required',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'workspaceId is required' } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const accessContext = await getWorkflowAccessContext(workflowId, authResult.userId)
|
||||
if (!accessContext) {
|
||||
logger.warn(`[${requestId}] Workflow ${workflowId} not found for user ${authResult.userId}`)
|
||||
const { hasAccess, canWrite } = await checkWorkspaceAccess(workspaceId, authResult.userId)
|
||||
if (!hasAccess) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Workflow not found',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Workspace not found' } },
|
||||
{ status: 404 }
|
||||
)
|
||||
}
|
||||
|
||||
const { workspacePermission, isOwner } = accessContext
|
||||
|
||||
const hasWritePermission =
|
||||
isOwner || workspacePermission === 'write' || workspacePermission === 'admin'
|
||||
|
||||
if (!hasWritePermission) {
|
||||
logger.warn(
|
||||
`[${requestId}] User ${authResult.userId} denied write access to workflow ${workflowId}`
|
||||
)
|
||||
if (!canWrite) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Write access denied to this workflow',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Write access denied to this workspace' } },
|
||||
{ status: 403 }
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] User ${authResult.userId} (${authResult.authType}) creating memory for workflow ${workflowId}`
|
||||
)
|
||||
|
||||
const dataToValidate = Array.isArray(data) ? data : [data]
|
||||
|
||||
for (const msg of dataToValidate) {
|
||||
if (!msg || typeof msg !== 'object' || !msg.role || !msg.content) {
|
||||
logger.warn(`[${requestId}] Missing required message fields`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Memory requires messages with role and content',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Memory requires messages with role and content' } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (!['user', 'assistant', 'system'].includes(msg.role)) {
|
||||
logger.warn(`[${requestId}] Invalid message role: ${msg.role}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Message role must be user, assistant, or system',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Message role must be user, assistant, or system' } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
@@ -382,114 +198,59 @@ export async function POST(request: NextRequest) {
|
||||
.insert(memory)
|
||||
.values({
|
||||
id,
|
||||
workflowId,
|
||||
workspaceId,
|
||||
key,
|
||||
data: initialData,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
})
|
||||
.onConflictDoUpdate({
|
||||
target: [memory.workflowId, memory.key],
|
||||
target: [memory.workspaceId, memory.key],
|
||||
set: {
|
||||
data: sql`${memory.data} || ${JSON.stringify(initialData)}::jsonb`,
|
||||
updatedAt: now,
|
||||
},
|
||||
})
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Memory operation successful (atomic): ${key} for workflow: ${workflowId}`
|
||||
)
|
||||
logger.info(`[${requestId}] Memory operation successful: ${key} for workspace: ${workspaceId}`)
|
||||
|
||||
const allMemories = await db
|
||||
.select()
|
||||
.from(memory)
|
||||
.where(and(eq(memory.key, key), eq(memory.workflowId, workflowId), isNull(memory.deletedAt)))
|
||||
.where(
|
||||
and(eq(memory.key, key), eq(memory.workspaceId, workspaceId), isNull(memory.deletedAt))
|
||||
)
|
||||
.orderBy(memory.createdAt)
|
||||
|
||||
if (allMemories.length === 0) {
|
||||
logger.warn(`[${requestId}] No memories found after creating/updating memory: ${key}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Failed to retrieve memory after creation/update',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Failed to retrieve memory after creation/update' } },
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
|
||||
const memoryRecord = allMemories[0]
|
||||
const parsed = parseMemoryKey(memoryRecord.key)
|
||||
|
||||
let enrichedMemory
|
||||
if (!parsed) {
|
||||
enrichedMemory = {
|
||||
conversationId: memoryRecord.key,
|
||||
blockId: 'unknown',
|
||||
blockName: 'unknown',
|
||||
data: memoryRecord.data,
|
||||
}
|
||||
} else {
|
||||
const { conversationId, blockId } = parsed
|
||||
const blockName = await (async () => {
|
||||
const blocks = await db
|
||||
.select({ name: workflowBlocks.name })
|
||||
.from(workflowBlocks)
|
||||
.where(and(eq(workflowBlocks.id, blockId), eq(workflowBlocks.workflowId, workflowId)))
|
||||
.limit(1)
|
||||
return blocks.length > 0 ? blocks[0].name : 'unknown'
|
||||
})()
|
||||
|
||||
enrichedMemory = {
|
||||
conversationId,
|
||||
blockId,
|
||||
blockName,
|
||||
data: memoryRecord.data,
|
||||
}
|
||||
}
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
data: enrichedMemory,
|
||||
},
|
||||
{ success: true, data: { conversationId: memoryRecord.key, data: memoryRecord.data } },
|
||||
{ status: 200 }
|
||||
)
|
||||
} catch (error: any) {
|
||||
if (error.code === '23505') {
|
||||
logger.warn(`[${requestId}] Duplicate key violation`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Memory with this key already exists',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Memory with this key already exists' } },
|
||||
{ status: 409 }
|
||||
)
|
||||
}
|
||||
|
||||
logger.error(`[${requestId}] Error creating memory`, { error })
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: error.message || 'Failed to create memory',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: error.message || 'Failed to create memory' } },
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* DELETE handler for pattern-based memory deletion
|
||||
* Supports query parameters:
|
||||
* - workflowId: Required
|
||||
* - conversationId: Optional - delete all memories for this conversation
|
||||
* - blockId: Optional - delete all memories for this block
|
||||
* - blockName: Optional - delete all memories for blocks with this name
|
||||
*/
|
||||
export async function DELETE(request: NextRequest) {
|
||||
const requestId = generateRequestId()
|
||||
|
||||
@@ -498,175 +259,52 @@ export async function DELETE(request: NextRequest) {
|
||||
if (!authResult.success || !authResult.userId) {
|
||||
logger.warn(`[${requestId}] Unauthorized memory deletion attempt`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: authResult.error || 'Authentication required',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: authResult.error || 'Authentication required' } },
|
||||
{ status: 401 }
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Processing memory deletion request`)
|
||||
|
||||
const url = new URL(request.url)
|
||||
const workflowId = url.searchParams.get('workflowId')
|
||||
const workspaceId = url.searchParams.get('workspaceId')
|
||||
const conversationId = url.searchParams.get('conversationId')
|
||||
const blockId = url.searchParams.get('blockId')
|
||||
const blockName = url.searchParams.get('blockName')
|
||||
|
||||
if (!workflowId) {
|
||||
logger.warn(`[${requestId}] Missing required parameter: workflowId`)
|
||||
if (!workspaceId) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'workflowId parameter is required',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'workspaceId parameter is required' } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (!conversationId && !blockId && !blockName) {
|
||||
logger.warn(`[${requestId}] No filter parameters provided`)
|
||||
if (!conversationId) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'At least one of conversationId, blockId, or blockName must be provided',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'conversationId must be provided' } },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const accessContext = await getWorkflowAccessContext(workflowId, authResult.userId)
|
||||
if (!accessContext) {
|
||||
logger.warn(`[${requestId}] Workflow ${workflowId} not found for user ${authResult.userId}`)
|
||||
const { hasAccess, canWrite } = await checkWorkspaceAccess(workspaceId, authResult.userId)
|
||||
if (!hasAccess) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Workflow not found',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Workspace not found' } },
|
||||
{ status: 404 }
|
||||
)
|
||||
}
|
||||
|
||||
const { workspacePermission, isOwner } = accessContext
|
||||
|
||||
const hasWritePermission =
|
||||
isOwner || workspacePermission === 'write' || workspacePermission === 'admin'
|
||||
|
||||
if (!hasWritePermission) {
|
||||
logger.warn(
|
||||
`[${requestId}] User ${authResult.userId} denied delete access to workflow ${workflowId}`
|
||||
)
|
||||
if (!canWrite) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: 'Write access denied to this workflow',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: 'Write access denied to this workspace' } },
|
||||
{ status: 403 }
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] User ${authResult.userId} (${authResult.authType}) deleting memories for workflow ${workflowId}`
|
||||
)
|
||||
const result = await db
|
||||
.delete(memory)
|
||||
.where(and(eq(memory.key, conversationId), eq(memory.workspaceId, workspaceId)))
|
||||
.returning({ id: memory.id })
|
||||
|
||||
let deletedCount = 0
|
||||
const deletedCount = result.length
|
||||
|
||||
if (conversationId && blockId) {
|
||||
const key = `${conversationId}:${blockId}`
|
||||
const result = await db
|
||||
.delete(memory)
|
||||
.where(and(eq(memory.key, key), eq(memory.workflowId, workflowId)))
|
||||
.returning({ id: memory.id })
|
||||
|
||||
deletedCount = result.length
|
||||
} else if (conversationId && blockName) {
|
||||
const blocks = await db
|
||||
.select({ id: workflowBlocks.id })
|
||||
.from(workflowBlocks)
|
||||
.where(and(eq(workflowBlocks.workflowId, workflowId), eq(workflowBlocks.name, blockName)))
|
||||
|
||||
if (blocks.length === 0) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
data: {
|
||||
message: `No blocks found with name "${blockName}"`,
|
||||
deletedCount: 0,
|
||||
},
|
||||
},
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
|
||||
for (const block of blocks) {
|
||||
const key = `${conversationId}:${block.id}`
|
||||
const result = await db
|
||||
.delete(memory)
|
||||
.where(and(eq(memory.key, key), eq(memory.workflowId, workflowId)))
|
||||
.returning({ id: memory.id })
|
||||
|
||||
deletedCount += result.length
|
||||
}
|
||||
} else if (conversationId) {
|
||||
const pattern = `${conversationId}:%`
|
||||
const result = await db
|
||||
.delete(memory)
|
||||
.where(and(like(memory.key, pattern), eq(memory.workflowId, workflowId)))
|
||||
.returning({ id: memory.id })
|
||||
|
||||
deletedCount = result.length
|
||||
} else if (blockId) {
|
||||
const pattern = `%:${blockId}`
|
||||
const result = await db
|
||||
.delete(memory)
|
||||
.where(and(like(memory.key, pattern), eq(memory.workflowId, workflowId)))
|
||||
.returning({ id: memory.id })
|
||||
|
||||
deletedCount = result.length
|
||||
} else if (blockName) {
|
||||
const blocks = await db
|
||||
.select({ id: workflowBlocks.id })
|
||||
.from(workflowBlocks)
|
||||
.where(and(eq(workflowBlocks.workflowId, workflowId), eq(workflowBlocks.name, blockName)))
|
||||
|
||||
if (blocks.length === 0) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
data: {
|
||||
message: `No blocks found with name "${blockName}"`,
|
||||
deletedCount: 0,
|
||||
},
|
||||
},
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
|
||||
for (const block of blocks) {
|
||||
const pattern = `%:${block.id}`
|
||||
const result = await db
|
||||
.delete(memory)
|
||||
.where(and(like(memory.key, pattern), eq(memory.workflowId, workflowId)))
|
||||
.returning({ id: memory.id })
|
||||
|
||||
deletedCount += result.length
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Successfully deleted ${deletedCount} memories for workflow: ${workflowId}`
|
||||
)
|
||||
logger.info(`[${requestId}] Deleted ${deletedCount} memories for workspace: ${workspaceId}`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
@@ -683,12 +321,7 @@ export async function DELETE(request: NextRequest) {
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Error deleting memories`, { error })
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: {
|
||||
message: error.message || 'Failed to delete memories',
|
||||
},
|
||||
},
|
||||
{ success: false, error: { message: error.message || 'Failed to delete memories' } },
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,29 @@ const logger = createLogger('OpenRouterModelsAPI')

interface OpenRouterModel {
  id: string
  context_length?: number
  supported_parameters?: string[]
  pricing?: {
    prompt?: string
    completion?: string
  }
}

interface OpenRouterResponse {
  data: OpenRouterModel[]
}

export interface OpenRouterModelInfo {
  id: string
  contextLength?: number
  supportsStructuredOutputs?: boolean
  supportsTools?: boolean
  pricing?: {
    input: number
    output: number
  }
}

export async function GET(_request: NextRequest) {
  try {
    const response = await fetch('https://openrouter.ai/api/v1/models', {
@@ -24,23 +41,51 @@ export async function GET(_request: NextRequest) {
        status: response.status,
        statusText: response.statusText,
      })
      return NextResponse.json({ models: [] })
      return NextResponse.json({ models: [], modelInfo: {} })
    }

    const data = (await response.json()) as OpenRouterResponse
    const allModels = Array.from(new Set(data.data?.map((model) => `openrouter/${model.id}`) ?? []))
    const models = filterBlacklistedModels(allModels)

    const modelInfo: Record<string, OpenRouterModelInfo> = {}
    const allModels: string[] = []

    for (const model of data.data ?? []) {
      const modelId = `openrouter/${model.id}`
      allModels.push(modelId)

      const supportedParams = model.supported_parameters ?? []
      modelInfo[modelId] = {
        id: modelId,
        contextLength: model.context_length,
        supportsStructuredOutputs: supportedParams.includes('structured_outputs'),
        supportsTools: supportedParams.includes('tools'),
        pricing: model.pricing
          ? {
              input: Number.parseFloat(model.pricing.prompt ?? '0') * 1000000,
              output: Number.parseFloat(model.pricing.completion ?? '0') * 1000000,
            }
          : undefined,
      }
    }

    const uniqueModels = Array.from(new Set(allModels))
    const models = filterBlacklistedModels(uniqueModels)

    const structuredOutputCount = Object.values(modelInfo).filter(
      (m) => m.supportsStructuredOutputs
    ).length

    logger.info('Successfully fetched OpenRouter models', {
      count: models.length,
      filtered: allModels.length - models.length,
      filtered: uniqueModels.length - models.length,
      withStructuredOutputs: structuredOutputCount,
    })

    return NextResponse.json({ models })
    return NextResponse.json({ models, modelInfo })
  } catch (error) {
    logger.error('Error fetching OpenRouter models', {
      error: error instanceof Error ? error.message : 'Unknown error',
    })
    return NextResponse.json({ models: [] })
    return NextResponse.json({ models: [], modelInfo: {} })
  }
}

@@ -49,126 +49,6 @@ const ExecuteWorkflowSchema = z.object({
|
||||
export const runtime = 'nodejs'
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
/**
|
||||
* Execute workflow with streaming support - used by chat and other streaming endpoints
|
||||
*
|
||||
* This function assumes preprocessing has already been completed.
|
||||
* Callers must run preprocessExecution() first to validate workflow, check usage limits,
|
||||
* and resolve actor before calling this function.
|
||||
*
|
||||
* This is a wrapper function that:
|
||||
* - Supports streaming callbacks (onStream, onBlockComplete)
|
||||
* - Returns ExecutionResult instead of NextResponse
|
||||
* - Handles pause/resume logic
|
||||
*
|
||||
* Used by:
|
||||
* - Chat execution (/api/chat/[identifier]/route.ts)
|
||||
* - Streaming responses (lib/workflows/streaming.ts)
|
||||
*/
|
||||
export async function executeWorkflow(
|
||||
workflow: any,
|
||||
requestId: string,
|
||||
input: any | undefined,
|
||||
actorUserId: string,
|
||||
streamConfig?: {
|
||||
enabled: boolean
|
||||
selectedOutputs?: string[]
|
||||
isSecureMode?: boolean
|
||||
workflowTriggerType?: 'api' | 'chat'
|
||||
onStream?: (streamingExec: any) => Promise<void>
|
||||
onBlockComplete?: (blockId: string, output: any) => Promise<void>
|
||||
skipLoggingComplete?: boolean
|
||||
},
|
||||
providedExecutionId?: string
|
||||
): Promise<any> {
|
||||
const workflowId = workflow.id
|
||||
const executionId = providedExecutionId || uuidv4()
|
||||
const triggerType = streamConfig?.workflowTriggerType || 'api'
|
||||
const loggingSession = new LoggingSession(workflowId, executionId, triggerType, requestId)
|
||||
|
||||
try {
|
||||
const metadata: ExecutionMetadata = {
|
||||
requestId,
|
||||
executionId,
|
||||
workflowId,
|
||||
workspaceId: workflow.workspaceId,
|
||||
userId: actorUserId,
|
||||
workflowUserId: workflow.userId,
|
||||
triggerType,
|
||||
useDraftState: false,
|
||||
startTime: new Date().toISOString(),
|
||||
isClientSession: false,
|
||||
}
|
||||
|
||||
const snapshot = new ExecutionSnapshot(
|
||||
metadata,
|
||||
workflow,
|
||||
input,
|
||||
workflow.variables || {},
|
||||
streamConfig?.selectedOutputs || []
|
||||
)
|
||||
|
||||
const result = await executeWorkflowCore({
|
||||
snapshot,
|
||||
callbacks: {
|
||||
onStream: streamConfig?.onStream,
|
||||
onBlockComplete: streamConfig?.onBlockComplete
|
||||
? async (blockId: string, _blockName: string, _blockType: string, output: any) => {
|
||||
await streamConfig.onBlockComplete!(blockId, output)
|
||||
}
|
||||
: undefined,
|
||||
},
|
||||
loggingSession,
|
||||
})
|
||||
|
||||
if (result.status === 'paused') {
|
||||
if (!result.snapshotSeed) {
|
||||
logger.error(`[${requestId}] Missing snapshot seed for paused execution`, {
|
||||
executionId,
|
||||
})
|
||||
} else {
|
||||
await PauseResumeManager.persistPauseResult({
|
||||
workflowId,
|
||||
executionId,
|
||||
pausePoints: result.pausePoints || [],
|
||||
snapshotSeed: result.snapshotSeed,
|
||||
executorUserId: result.metadata?.userId,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
await PauseResumeManager.processQueuedResumes(executionId)
|
||||
}
|
||||
|
||||
if (streamConfig?.skipLoggingComplete) {
|
||||
return {
|
||||
...result,
|
||||
_streamingMetadata: {
|
||||
loggingSession,
|
||||
processedInput: input,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Workflow execution failed:`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
export function createFilteredResult(result: any) {
|
||||
return {
|
||||
...result,
|
||||
logs: undefined,
|
||||
metadata: result.metadata
|
||||
? {
|
||||
...result.metadata,
|
||||
workflowConnections: undefined,
|
||||
}
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
function resolveOutputIds(
|
||||
selectedOutputs: string[] | undefined,
|
||||
blocks: Record<string, any>
|
||||
@@ -606,7 +486,6 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
isSecureMode: false,
|
||||
workflowTriggerType: triggerType === 'chat' ? 'chat' : 'api',
|
||||
},
|
||||
createFilteredResult,
|
||||
executionId,
|
||||
})
|
||||
|
||||
@@ -743,14 +622,17 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:

const onStream = async (streamingExec: StreamingExecution) => {
const blockId = (streamingExec.execution as any).blockId

const reader = streamingExec.stream.getReader()
const decoder = new TextDecoder()
let chunkCount = 0

try {
while (true) {
const { done, value } = await reader.read()
if (done) break

chunkCount++
const chunk = decoder.decode(value, { stream: true })
sendEvent({
type: 'stream:chunk',

@@ -314,22 +314,9 @@ export function useChatStreaming() {
let finalContent = accumulatedText

if (formattedOutputs.length > 0) {
const trimmedStreamingContent = accumulatedText.trim()

const uniqueOutputs = formattedOutputs.filter((output) => {
const trimmedOutput = output.trim()
if (!trimmedOutput) return false

// Skip outputs that exactly match the streamed content to avoid duplication
if (trimmedStreamingContent && trimmedOutput === trimmedStreamingContent) {
return false
}

return true
})

if (uniqueOutputs.length > 0) {
const combinedOutputs = uniqueOutputs.join('\n\n')
const nonEmptyOutputs = formattedOutputs.filter((output) => output.trim())
if (nonEmptyOutputs.length > 0) {
const combinedOutputs = nonEmptyOutputs.join('\n\n')
finalContent = finalContent
? `${finalContent.trim()}\n\n${combinedOutputs}`
: combinedOutputs

@@ -16,6 +16,7 @@ const logger = createLogger('ProviderModelsLoader')
function useSyncProvider(provider: ProviderName) {
const setProviderModels = useProvidersStore((state) => state.setProviderModels)
const setProviderLoading = useProvidersStore((state) => state.setProviderLoading)
const setOpenRouterModelInfo = useProvidersStore((state) => state.setOpenRouterModelInfo)
const { data, isLoading, isFetching, error } = useProviderModels(provider)

useEffect(() => {
@@ -27,18 +28,21 @@ function useSyncProvider(provider: ProviderName) {

try {
if (provider === 'ollama') {
updateOllamaProviderModels(data)
updateOllamaProviderModels(data.models)
} else if (provider === 'vllm') {
updateVLLMProviderModels(data)
updateVLLMProviderModels(data.models)
} else if (provider === 'openrouter') {
void updateOpenRouterProviderModels(data)
void updateOpenRouterProviderModels(data.models)
if (data.modelInfo) {
setOpenRouterModelInfo(data.modelInfo)
}
}
} catch (syncError) {
logger.warn(`Failed to sync provider definitions for ${provider}`, syncError as Error)
}

setProviderModels(provider, data)
}, [provider, data, setProviderModels])
setProviderModels(provider, data.models)
}, [provider, data, setProviderModels, setOpenRouterModelInfo])

useEffect(() => {
if (error) {

@@ -398,6 +398,7 @@ export function useWorkflowExecution() {
|
||||
}
|
||||
|
||||
const streamCompletionTimes = new Map<string, number>()
|
||||
const processedFirstChunk = new Set<string>()
|
||||
|
||||
const onStream = async (streamingExecution: StreamingExecution) => {
|
||||
const promise = (async () => {
|
||||
@@ -405,16 +406,14 @@ export function useWorkflowExecution() {
|
||||
const reader = streamingExecution.stream.getReader()
|
||||
const blockId = (streamingExecution.execution as any)?.blockId
|
||||
|
||||
let isFirstChunk = true
|
||||
|
||||
if (blockId) {
|
||||
if (blockId && !streamedContent.has(blockId)) {
|
||||
streamedContent.set(blockId, '')
|
||||
}
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
// Record when this stream completed
|
||||
if (blockId) {
|
||||
streamCompletionTimes.set(blockId, Date.now())
|
||||
}
|
||||
@@ -425,13 +424,12 @@ export function useWorkflowExecution() {
|
||||
streamedContent.set(blockId, (streamedContent.get(blockId) || '') + chunk)
|
||||
}
|
||||
|
||||
// Add separator before first chunk if this isn't the first block
|
||||
let chunkToSend = chunk
|
||||
if (isFirstChunk && streamedContent.size > 1) {
|
||||
chunkToSend = `\n\n${chunk}`
|
||||
isFirstChunk = false
|
||||
} else if (isFirstChunk) {
|
||||
isFirstChunk = false
|
||||
if (blockId && !processedFirstChunk.has(blockId)) {
|
||||
processedFirstChunk.add(blockId)
|
||||
if (streamedContent.size > 1) {
|
||||
chunkToSend = `\n\n${chunk}`
|
||||
}
|
||||
}
|
||||
|
||||
controller.enqueue(encodeSSE({ blockId, chunk: chunkToSend }))
|
||||
|
||||
@@ -41,17 +41,6 @@ export const MemoryBlock: BlockConfig = {
|
||||
},
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
id: 'blockId',
|
||||
title: 'Block ID',
|
||||
type: 'short-input',
|
||||
placeholder: 'Enter block ID (optional, defaults to current block)',
|
||||
condition: {
|
||||
field: 'operation',
|
||||
value: 'add',
|
||||
},
|
||||
required: false,
|
||||
},
|
||||
{
|
||||
id: 'id',
|
||||
title: 'Conversation ID',
|
||||
@@ -61,29 +50,7 @@ export const MemoryBlock: BlockConfig = {
|
||||
field: 'operation',
|
||||
value: 'get',
|
||||
},
|
||||
required: false,
|
||||
},
|
||||
{
|
||||
id: 'blockId',
|
||||
title: 'Block ID',
|
||||
type: 'short-input',
|
||||
placeholder: 'Enter block ID (optional)',
|
||||
condition: {
|
||||
field: 'operation',
|
||||
value: 'get',
|
||||
},
|
||||
required: false,
|
||||
},
|
||||
{
|
||||
id: 'blockName',
|
||||
title: 'Block Name',
|
||||
type: 'short-input',
|
||||
placeholder: 'Enter block name (optional)',
|
||||
condition: {
|
||||
field: 'operation',
|
||||
value: 'get',
|
||||
},
|
||||
required: false,
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
id: 'id',
|
||||
@@ -94,29 +61,7 @@ export const MemoryBlock: BlockConfig = {
|
||||
field: 'operation',
|
||||
value: 'delete',
|
||||
},
|
||||
required: false,
|
||||
},
|
||||
{
|
||||
id: 'blockId',
|
||||
title: 'Block ID',
|
||||
type: 'short-input',
|
||||
placeholder: 'Enter block ID (optional)',
|
||||
condition: {
|
||||
field: 'operation',
|
||||
value: 'delete',
|
||||
},
|
||||
required: false,
|
||||
},
|
||||
{
|
||||
id: 'blockName',
|
||||
title: 'Block Name',
|
||||
type: 'short-input',
|
||||
placeholder: 'Enter block name (optional)',
|
||||
condition: {
|
||||
field: 'operation',
|
||||
value: 'delete',
|
||||
},
|
||||
required: false,
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
id: 'role',
|
||||
@@ -186,10 +131,8 @@ export const MemoryBlock: BlockConfig = {
|
||||
}
|
||||
|
||||
if (params.operation === 'get' || params.operation === 'delete') {
|
||||
if (!conversationId && !params.blockId && !params.blockName) {
|
||||
errors.push(
|
||||
`At least one of ID, blockId, or blockName is required for ${params.operation} operation`
|
||||
)
|
||||
if (!conversationId) {
|
||||
errors.push(`Conversation ID is required for ${params.operation} operation`)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,33 +143,26 @@ export const MemoryBlock: BlockConfig = {
|
||||
const baseResult: Record<string, any> = {}
|
||||
|
||||
if (params.operation === 'add') {
|
||||
const result: Record<string, any> = {
|
||||
return {
|
||||
...baseResult,
|
||||
conversationId: conversationId,
|
||||
role: params.role,
|
||||
content: params.content,
|
||||
}
|
||||
if (params.blockId) {
|
||||
result.blockId = params.blockId
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
if (params.operation === 'get') {
|
||||
const result: Record<string, any> = { ...baseResult }
|
||||
if (conversationId) result.conversationId = conversationId
|
||||
if (params.blockId) result.blockId = params.blockId
|
||||
if (params.blockName) result.blockName = params.blockName
|
||||
return result
|
||||
return {
|
||||
...baseResult,
|
||||
conversationId: conversationId,
|
||||
}
|
||||
}
|
||||
|
||||
if (params.operation === 'delete') {
|
||||
const result: Record<string, any> = { ...baseResult }
|
||||
if (conversationId) result.conversationId = conversationId
|
||||
if (params.blockId) result.blockId = params.blockId
|
||||
if (params.blockName) result.blockName = params.blockName
|
||||
return result
|
||||
return {
|
||||
...baseResult,
|
||||
conversationId: conversationId,
|
||||
}
|
||||
}
|
||||
|
||||
return baseResult
|
||||
@@ -235,10 +171,8 @@ export const MemoryBlock: BlockConfig = {
|
||||
},
|
||||
inputs: {
|
||||
operation: { type: 'string', description: 'Operation to perform' },
|
||||
id: { type: 'string', description: 'Memory identifier (for add operation)' },
|
||||
id: { type: 'string', description: 'Memory identifier (conversation ID)' },
|
||||
conversationId: { type: 'string', description: 'Conversation identifier' },
|
||||
blockId: { type: 'string', description: 'Block identifier' },
|
||||
blockName: { type: 'string', description: 'Block name' },
|
||||
role: { type: 'string', description: 'Agent role' },
|
||||
content: { type: 'string', description: 'Memory content' },
|
||||
},
|
||||
|
||||
@@ -158,11 +158,19 @@ export const HTTP = {

export const AGENT = {
DEFAULT_MODEL: 'claude-sonnet-4-5',
DEFAULT_FUNCTION_TIMEOUT: 600000, // 10 minutes for custom tool code execution
REQUEST_TIMEOUT: 600000, // 10 minutes for LLM API requests
DEFAULT_FUNCTION_TIMEOUT: 600000,
REQUEST_TIMEOUT: 600000,
CUSTOM_TOOL_PREFIX: 'custom_',
} as const

export const MEMORY = {
DEFAULT_SLIDING_WINDOW_SIZE: 10,
DEFAULT_SLIDING_WINDOW_TOKENS: 4000,
CONTEXT_WINDOW_UTILIZATION: 0.9,
MAX_CONVERSATION_ID_LENGTH: 255,
MAX_MESSAGE_CONTENT_BYTES: 100 * 1024,
} as const

export const ROUTER = {
DEFAULT_MODEL: 'gpt-4o',
DEFAULT_TEMPERATURE: 0,

@@ -1140,7 +1140,7 @@ describe('AgentBlockHandler', () => {
|
||||
expect(systemMessages[0].content).toBe('You are a helpful assistant.')
|
||||
})
|
||||
|
||||
it('should prioritize messages array system message over system messages in memories', async () => {
|
||||
it('should prefix agent system message before legacy memories', async () => {
|
||||
const inputs = {
|
||||
model: 'gpt-4o',
|
||||
messages: [
|
||||
@@ -1163,25 +1163,26 @@ describe('AgentBlockHandler', () => {
|
||||
const requestBody = JSON.parse(fetchCall[1].body)
|
||||
|
||||
// Verify messages were built correctly
|
||||
// Agent system (1) + legacy memories (3) + user from messages (1) = 5
|
||||
expect(requestBody.messages).toBeDefined()
|
||||
expect(requestBody.messages.length).toBe(5) // memory system + 2 non-system memories + 2 from messages array
|
||||
expect(requestBody.messages.length).toBe(5)
|
||||
|
||||
// All messages should be present (memories first, then messages array)
|
||||
// Memory messages come first
|
||||
// Agent's system message is prefixed first
|
||||
expect(requestBody.messages[0].role).toBe('system')
|
||||
expect(requestBody.messages[0].content).toBe('Old system message from memories.')
|
||||
expect(requestBody.messages[1].role).toBe('user')
|
||||
expect(requestBody.messages[1].content).toBe('Hello!')
|
||||
expect(requestBody.messages[2].role).toBe('assistant')
|
||||
expect(requestBody.messages[2].content).toBe('Hi there!')
|
||||
// Then messages array
|
||||
expect(requestBody.messages[3].role).toBe('system')
|
||||
expect(requestBody.messages[3].content).toBe('You are a helpful assistant.')
|
||||
expect(requestBody.messages[0].content).toBe('You are a helpful assistant.')
|
||||
// Then legacy memories (with their system message preserved)
|
||||
expect(requestBody.messages[1].role).toBe('system')
|
||||
expect(requestBody.messages[1].content).toBe('Old system message from memories.')
|
||||
expect(requestBody.messages[2].role).toBe('user')
|
||||
expect(requestBody.messages[2].content).toBe('Hello!')
|
||||
expect(requestBody.messages[3].role).toBe('assistant')
|
||||
expect(requestBody.messages[3].content).toBe('Hi there!')
|
||||
// Then user message from messages array
|
||||
expect(requestBody.messages[4].role).toBe('user')
|
||||
expect(requestBody.messages[4].content).toBe('What should I do?')
|
||||
})
|
||||
|
||||
it('should handle multiple system messages in memories with messages array', async () => {
|
||||
it('should prefix agent system message and preserve legacy memory system messages', async () => {
|
||||
const inputs = {
|
||||
model: 'gpt-4o',
|
||||
messages: [
|
||||
@@ -1207,21 +1208,23 @@ describe('AgentBlockHandler', () => {
|
||||
|
||||
// Verify messages were built correctly
|
||||
expect(requestBody.messages).toBeDefined()
|
||||
expect(requestBody.messages.length).toBe(7) // 5 memory messages (3 system + 2 conversation) + 2 from messages array
|
||||
expect(requestBody.messages.length).toBe(7)
|
||||
|
||||
// All messages should be present in order
|
||||
// Agent's system message prefixed first
|
||||
expect(requestBody.messages[0].role).toBe('system')
|
||||
expect(requestBody.messages[0].content).toBe('First system message.')
|
||||
expect(requestBody.messages[1].role).toBe('user')
|
||||
expect(requestBody.messages[1].content).toBe('Hello!')
|
||||
expect(requestBody.messages[2].role).toBe('system')
|
||||
expect(requestBody.messages[2].content).toBe('Second system message.')
|
||||
expect(requestBody.messages[3].role).toBe('assistant')
|
||||
expect(requestBody.messages[3].content).toBe('Hi there!')
|
||||
expect(requestBody.messages[4].role).toBe('system')
|
||||
expect(requestBody.messages[4].content).toBe('Third system message.')
|
||||
expect(requestBody.messages[0].content).toBe('You are a helpful assistant.')
|
||||
// Then legacy memories with their system messages preserved in order
|
||||
expect(requestBody.messages[1].role).toBe('system')
|
||||
expect(requestBody.messages[1].content).toBe('First system message.')
|
||||
expect(requestBody.messages[2].role).toBe('user')
|
||||
expect(requestBody.messages[2].content).toBe('Hello!')
|
||||
expect(requestBody.messages[3].role).toBe('system')
|
||||
expect(requestBody.messages[3].content).toBe('Second system message.')
|
||||
expect(requestBody.messages[4].role).toBe('assistant')
|
||||
expect(requestBody.messages[4].content).toBe('Hi there!')
|
||||
expect(requestBody.messages[5].role).toBe('system')
|
||||
expect(requestBody.messages[5].content).toBe('You are a helpful assistant.')
|
||||
expect(requestBody.messages[5].content).toBe('Third system message.')
|
||||
// Then user message from messages array
|
||||
expect(requestBody.messages[6].role).toBe('user')
|
||||
expect(requestBody.messages[6].content).toBe('Continue our conversation.')
|
||||
})
|
||||
|
||||
@@ -47,7 +47,7 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
const providerId = getProviderFromModel(model)
|
||||
const formattedTools = await this.formatTools(ctx, filteredInputs.tools || [])
|
||||
const streamingConfig = this.getStreamingConfig(ctx, block)
|
||||
const messages = await this.buildMessages(ctx, filteredInputs, block.id)
|
||||
const messages = await this.buildMessages(ctx, filteredInputs)
|
||||
|
||||
const providerRequest = this.buildProviderRequest({
|
||||
ctx,
|
||||
@@ -68,7 +68,20 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
filteredInputs
|
||||
)
|
||||
|
||||
await this.persistResponseToMemory(ctx, filteredInputs, result, block.id)
|
||||
if (this.isStreamingExecution(result)) {
|
||||
if (filteredInputs.memoryType && filteredInputs.memoryType !== 'none') {
|
||||
return this.wrapStreamForMemoryPersistence(
|
||||
ctx,
|
||||
filteredInputs,
|
||||
result as StreamingExecution
|
||||
)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
if (filteredInputs.memoryType && filteredInputs.memoryType !== 'none') {
|
||||
await this.persistResponseToMemory(ctx, filteredInputs, result as BlockOutput)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -686,81 +699,102 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
|
||||
private async buildMessages(
|
||||
ctx: ExecutionContext,
|
||||
inputs: AgentInputs,
|
||||
blockId: string
|
||||
inputs: AgentInputs
|
||||
): Promise<Message[] | undefined> {
|
||||
const messages: Message[] = []
|
||||
const memoryEnabled = inputs.memoryType && inputs.memoryType !== 'none'
|
||||
|
||||
// 1. Fetch memory history if configured (industry standard: chronological order)
|
||||
if (inputs.memoryType && inputs.memoryType !== 'none') {
|
||||
const memoryMessages = await memoryService.fetchMemoryMessages(ctx, inputs, blockId)
|
||||
messages.push(...memoryMessages)
|
||||
// 1. Extract and validate messages from messages-input subblock
|
||||
const inputMessages = this.extractValidMessages(inputs.messages)
|
||||
const systemMessages = inputMessages.filter((m) => m.role === 'system')
|
||||
const conversationMessages = inputMessages.filter((m) => m.role !== 'system')
|
||||
|
||||
// 2. Handle native memory: seed on first run, then fetch and append new user input
|
||||
if (memoryEnabled && ctx.workspaceId) {
|
||||
const memoryMessages = await memoryService.fetchMemoryMessages(ctx, inputs)
|
||||
const hasExisting = memoryMessages.length > 0
|
||||
|
||||
if (!hasExisting && conversationMessages.length > 0) {
|
||||
const taggedMessages = conversationMessages.map((m) =>
|
||||
m.role === 'user' ? { ...m, executionId: ctx.executionId } : m
|
||||
)
|
||||
await memoryService.seedMemory(ctx, inputs, taggedMessages)
|
||||
messages.push(...taggedMessages)
|
||||
} else {
|
||||
messages.push(...memoryMessages)
|
||||
|
||||
if (hasExisting && conversationMessages.length > 0) {
|
||||
const latestUserFromInput = conversationMessages.filter((m) => m.role === 'user').pop()
|
||||
if (latestUserFromInput) {
|
||||
const userMessageInThisRun = memoryMessages.some(
|
||||
(m) => m.role === 'user' && m.executionId === ctx.executionId
|
||||
)
|
||||
if (!userMessageInThisRun) {
|
||||
const taggedMessage = { ...latestUserFromInput, executionId: ctx.executionId }
|
||||
messages.push(taggedMessage)
|
||||
await memoryService.appendToMemory(ctx, inputs, taggedMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Process legacy memories (backward compatibility - from Memory block)
|
||||
// 3. Process legacy memories (backward compatibility - from Memory block)
|
||||
// These may include system messages which are preserved in their position
|
||||
if (inputs.memories) {
|
||||
messages.push(...this.processMemories(inputs.memories))
|
||||
}
|
||||
|
||||
// 3. Add messages array (new approach - from messages-input subblock)
|
||||
if (inputs.messages && Array.isArray(inputs.messages)) {
|
||||
const validMessages = inputs.messages.filter(
|
||||
(msg) =>
|
||||
msg &&
|
||||
typeof msg === 'object' &&
|
||||
'role' in msg &&
|
||||
'content' in msg &&
|
||||
['system', 'user', 'assistant'].includes(msg.role)
|
||||
)
|
||||
messages.push(...validMessages)
|
||||
// 4. Add conversation messages from inputs.messages (if not using native memory)
|
||||
// When memory is enabled, these are already seeded/fetched above
|
||||
if (!memoryEnabled && conversationMessages.length > 0) {
|
||||
messages.push(...conversationMessages)
|
||||
}
|
||||
|
||||
// Warn if using both new and legacy input formats
|
||||
if (
|
||||
inputs.messages &&
|
||||
inputs.messages.length > 0 &&
|
||||
(inputs.systemPrompt || inputs.userPrompt)
|
||||
) {
|
||||
logger.warn('Agent block using both messages array and legacy prompts', {
|
||||
hasMessages: true,
|
||||
hasSystemPrompt: !!inputs.systemPrompt,
|
||||
hasUserPrompt: !!inputs.userPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
// 4. Handle legacy systemPrompt (backward compatibility)
|
||||
// Only add if no system message exists yet
|
||||
if (inputs.systemPrompt && !messages.some((m) => m.role === 'system')) {
|
||||
this.addSystemPrompt(messages, inputs.systemPrompt)
|
||||
}
|
||||
|
||||
// 5. Handle legacy userPrompt (backward compatibility)
|
||||
if (inputs.userPrompt) {
|
||||
this.addUserPrompt(messages, inputs.userPrompt)
|
||||
}
|
||||
|
||||
// 6. Persist user message(s) to memory if configured
|
||||
// This ensures conversation history is complete before agent execution
|
||||
if (inputs.memoryType && inputs.memoryType !== 'none' && messages.length > 0) {
|
||||
// Find new user messages that need to be persisted
|
||||
// (messages added via messages array or userPrompt)
|
||||
const userMessages = messages.filter((m) => m.role === 'user')
|
||||
const lastUserMessage = userMessages[userMessages.length - 1]
|
||||
|
||||
// Only persist if there's a user message AND it's from userPrompt or messages input
|
||||
// (not from memory history which was already persisted)
|
||||
if (
|
||||
lastUserMessage &&
|
||||
(inputs.userPrompt || (inputs.messages && inputs.messages.length > 0))
|
||||
) {
|
||||
await memoryService.persistUserMessage(ctx, inputs, lastUserMessage, blockId)
|
||||
// 5. Handle legacy systemPrompt (backward compatibility)
|
||||
// Only add if no system message exists from any source
|
||||
if (inputs.systemPrompt) {
|
||||
const hasSystem = systemMessages.length > 0 || messages.some((m) => m.role === 'system')
|
||||
if (!hasSystem) {
|
||||
this.addSystemPrompt(messages, inputs.systemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// Return messages or undefined if empty (maintains API compatibility)
|
||||
// 6. Handle legacy userPrompt - this is NEW input each run
|
||||
if (inputs.userPrompt) {
|
||||
this.addUserPrompt(messages, inputs.userPrompt)
|
||||
|
||||
if (memoryEnabled) {
|
||||
const userMessages = messages.filter((m) => m.role === 'user')
|
||||
const lastUserMessage = userMessages[userMessages.length - 1]
|
||||
if (lastUserMessage) {
|
||||
await memoryService.appendToMemory(ctx, inputs, lastUserMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Prefix system messages from inputs.messages at the start (runtime only)
|
||||
// These are the agent's configured system prompts
|
||||
if (systemMessages.length > 0) {
|
||||
messages.unshift(...systemMessages)
|
||||
}
|
||||
|
||||
return messages.length > 0 ? messages : undefined
|
||||
}
|
||||
|
||||
private extractValidMessages(messages?: Message[]): Message[] {
|
||||
if (!messages || !Array.isArray(messages)) return []
|
||||
|
||||
return messages.filter(
|
||||
(msg): msg is Message =>
|
||||
msg &&
|
||||
typeof msg === 'object' &&
|
||||
'role' in msg &&
|
||||
'content' in msg &&
|
||||
['system', 'user', 'assistant'].includes(msg.role)
|
||||
)
|
||||
}
|
||||
|
||||
private processMemories(memories: any): Message[] {
|
||||
if (!memories) return []
|
||||
|
||||
@@ -1036,29 +1070,14 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
private async handleStreamingResponse(
|
||||
response: Response,
|
||||
block: SerializedBlock,
|
||||
ctx?: ExecutionContext,
|
||||
inputs?: AgentInputs
|
||||
_ctx?: ExecutionContext,
|
||||
_inputs?: AgentInputs
|
||||
): Promise<StreamingExecution> {
|
||||
const executionDataHeader = response.headers.get('X-Execution-Data')
|
||||
|
||||
if (executionDataHeader) {
|
||||
try {
|
||||
const executionData = JSON.parse(executionDataHeader)
|
||||
|
||||
// If execution data contains full content, persist to memory
|
||||
if (ctx && inputs && executionData.output?.content) {
|
||||
const assistantMessage: Message = {
|
||||
role: 'assistant',
|
||||
content: executionData.output.content,
|
||||
}
|
||||
// Fire and forget - don't await
|
||||
memoryService
|
||||
.persistMemoryMessage(ctx, inputs, assistantMessage, block.id)
|
||||
.catch((error) =>
|
||||
logger.error('Failed to persist streaming response to memory:', error)
|
||||
)
|
||||
}
|
||||
|
||||
return {
|
||||
stream: response.body!,
|
||||
execution: {
|
||||
@@ -1158,46 +1177,35 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
}
|
||||
}
|
||||
|
||||
private wrapStreamForMemoryPersistence(
|
||||
ctx: ExecutionContext,
|
||||
inputs: AgentInputs,
|
||||
streamingExec: StreamingExecution
|
||||
): StreamingExecution {
|
||||
return {
|
||||
stream: memoryService.wrapStreamForPersistence(streamingExec.stream, ctx, inputs),
|
||||
execution: streamingExec.execution,
|
||||
}
|
||||
}
|
||||
|
||||
private async persistResponseToMemory(
|
||||
ctx: ExecutionContext,
|
||||
inputs: AgentInputs,
|
||||
result: BlockOutput | StreamingExecution,
|
||||
blockId: string
|
||||
result: BlockOutput
|
||||
): Promise<void> {
|
||||
// Only persist if memoryType is configured
|
||||
if (!inputs.memoryType || inputs.memoryType === 'none') {
|
||||
const content = (result as any)?.content
|
||||
if (!content || typeof content !== 'string') {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
// Don't persist streaming responses here - they're handled separately
|
||||
if (this.isStreamingExecution(result)) {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract content from regular response
|
||||
const blockOutput = result as any
|
||||
const content = blockOutput?.content
|
||||
|
||||
if (!content || typeof content !== 'string') {
|
||||
return
|
||||
}
|
||||
|
||||
const assistantMessage: Message = {
|
||||
role: 'assistant',
|
||||
content,
|
||||
}
|
||||
|
||||
await memoryService.persistMemoryMessage(ctx, inputs, assistantMessage, blockId)
|
||||
|
||||
await memoryService.appendToMemory(ctx, inputs, { role: 'assistant', content })
|
||||
logger.debug('Persisted assistant response to memory', {
|
||||
workflowId: ctx.workflowId,
|
||||
memoryType: inputs.memoryType,
|
||||
conversationId: inputs.conversationId,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to persist response to memory:', error)
|
||||
// Don't throw - memory persistence failure shouldn't break workflow execution
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { MEMORY } from '@/executor/constants'
|
||||
import { Memory } from '@/executor/handlers/agent/memory'
|
||||
import type { AgentInputs, Message } from '@/executor/handlers/agent/types'
|
||||
import type { ExecutionContext } from '@/executor/types'
|
||||
import type { Message } from '@/executor/handlers/agent/types'
|
||||
|
||||
vi.mock('@/lib/logs/console/logger', () => ({
|
||||
createLogger: () => ({
|
||||
@@ -20,21 +20,14 @@ vi.mock('@/lib/tokenization/estimators', () => ({
|
||||
|
||||
describe('Memory', () => {
|
||||
let memoryService: Memory
|
||||
let mockContext: ExecutionContext
|
||||
|
||||
beforeEach(() => {
|
||||
memoryService = new Memory()
|
||||
mockContext = {
|
||||
workflowId: 'test-workflow-id',
|
||||
executionId: 'test-execution-id',
|
||||
workspaceId: 'test-workspace-id',
|
||||
} as ExecutionContext
|
||||
})
|
||||
|
||||
describe('applySlidingWindow (message-based)', () => {
|
||||
it('should keep last N conversation messages', () => {
|
||||
describe('applyWindow (message-based)', () => {
|
||||
it('should keep last N messages', () => {
|
||||
const messages: Message[] = [
|
||||
{ role: 'system', content: 'System prompt' },
|
||||
{ role: 'user', content: 'Message 1' },
|
||||
{ role: 'assistant', content: 'Response 1' },
|
||||
{ role: 'user', content: 'Message 2' },
|
||||
@@ -43,55 +36,51 @@ describe('Memory', () => {
|
||||
{ role: 'assistant', content: 'Response 3' },
|
||||
]
|
||||
|
||||
const result = (memoryService as any).applySlidingWindow(messages, '4')
|
||||
const result = (memoryService as any).applyWindow(messages, 4)
|
||||
|
||||
expect(result.length).toBe(5)
|
||||
expect(result[0].role).toBe('system')
|
||||
expect(result[0].content).toBe('System prompt')
|
||||
expect(result[1].content).toBe('Message 2')
|
||||
expect(result[4].content).toBe('Response 3')
|
||||
expect(result.length).toBe(4)
|
||||
expect(result[0].content).toBe('Message 2')
|
||||
expect(result[3].content).toBe('Response 3')
|
||||
})
|
||||
|
||||
it('should preserve only first system message', () => {
|
||||
it('should return all messages if limit exceeds array length', () => {
|
||||
const messages: Message[] = [
|
||||
{ role: 'system', content: 'First system' },
|
||||
{ role: 'user', content: 'User message' },
|
||||
{ role: 'system', content: 'Second system' },
|
||||
{ role: 'assistant', content: 'Assistant message' },
|
||||
{ role: 'user', content: 'Test' },
|
||||
{ role: 'assistant', content: 'Response' },
|
||||
]
|
||||
|
||||
const result = (memoryService as any).applySlidingWindow(messages, '10')
|
||||
|
||||
const systemMessages = result.filter((m: Message) => m.role === 'system')
|
||||
expect(systemMessages.length).toBe(1)
|
||||
expect(systemMessages[0].content).toBe('First system')
|
||||
const result = (memoryService as any).applyWindow(messages, 10)
|
||||
expect(result.length).toBe(2)
|
||||
})
|
||||
|
||||
it('should handle invalid window size', () => {
|
||||
const messages: Message[] = [{ role: 'user', content: 'Test' }]
|
||||
|
||||
const result = (memoryService as any).applySlidingWindow(messages, 'invalid')
|
||||
const result = (memoryService as any).applyWindow(messages, Number.NaN)
|
||||
expect(result).toEqual(messages)
|
||||
})
|
||||
|
||||
it('should handle zero limit', () => {
|
||||
const messages: Message[] = [{ role: 'user', content: 'Test' }]
|
||||
|
||||
const result = (memoryService as any).applyWindow(messages, 0)
|
||||
expect(result).toEqual(messages)
|
||||
})
|
||||
})
|
||||
|
||||
describe('applySlidingWindowByTokens (token-based)', () => {
|
||||
describe('applyTokenWindow (token-based)', () => {
|
||||
it('should keep messages within token limit', () => {
|
||||
const messages: Message[] = [
|
||||
{ role: 'system', content: 'This is a system message' }, // ~6 tokens
|
||||
{ role: 'user', content: 'Short' }, // ~2 tokens
|
||||
{ role: 'assistant', content: 'This is a longer response message' }, // ~8 tokens
|
||||
{ role: 'user', content: 'Another user message here' }, // ~6 tokens
|
||||
{ role: 'assistant', content: 'Final response' }, // ~3 tokens
|
||||
{ role: 'user', content: 'Short' },
|
||||
{ role: 'assistant', content: 'This is a longer response message' },
|
||||
{ role: 'user', content: 'Another user message here' },
|
||||
{ role: 'assistant', content: 'Final response' },
|
||||
]
|
||||
|
||||
// Set limit to ~15 tokens - should include last 2-3 messages
|
||||
const result = (memoryService as any).applySlidingWindowByTokens(messages, '15', 'gpt-4o')
|
||||
const result = (memoryService as any).applyTokenWindow(messages, 15, 'gpt-4o')
|
||||
|
||||
expect(result.length).toBeGreaterThan(0)
|
||||
expect(result.length).toBeLessThan(messages.length)
|
||||
|
||||
// Should include newest messages
|
||||
expect(result[result.length - 1].content).toBe('Final response')
|
||||
})
|
||||
|
||||
@@ -104,30 +93,12 @@ describe('Memory', () => {
|
||||
},
|
||||
]
|
||||
|
||||
const result = (memoryService as any).applySlidingWindowByTokens(messages, '5', 'gpt-4o')
|
||||
const result = (memoryService as any).applyTokenWindow(messages, 5, 'gpt-4o')
|
||||
|
||||
expect(result.length).toBe(1)
|
||||
expect(result[0].content).toBe(messages[0].content)
|
||||
})
|
||||
|
||||
it('should preserve first system message and exclude it from token count', () => {
|
||||
const messages: Message[] = [
|
||||
{ role: 'system', content: 'A' }, // System message - always preserved
|
||||
{ role: 'user', content: 'B' }, // ~1 token
|
||||
{ role: 'assistant', content: 'C' }, // ~1 token
|
||||
{ role: 'user', content: 'D' }, // ~1 token
|
||||
]
|
||||
|
||||
// Limit to 2 tokens - should fit system message + last 2 conversation messages (D, C)
|
||||
const result = (memoryService as any).applySlidingWindowByTokens(messages, '2', 'gpt-4o')
|
||||
|
||||
// Should have: system message + 2 conversation messages = 3 total
|
||||
expect(result.length).toBe(3)
|
||||
expect(result[0].role).toBe('system') // First system message preserved
|
||||
expect(result[1].content).toBe('C') // Second most recent conversation message
|
||||
expect(result[2].content).toBe('D') // Most recent conversation message
|
||||
})
|
||||
|
||||
it('should process messages from newest to oldest', () => {
|
||||
const messages: Message[] = [
|
||||
{ role: 'user', content: 'Old message' },
|
||||
@@ -136,141 +107,101 @@ describe('Memory', () => {
|
||||
{ role: 'assistant', content: 'New response' },
|
||||
]
|
||||
|
||||
const result = (memoryService as any).applySlidingWindowByTokens(messages, '10', 'gpt-4o')
|
||||
const result = (memoryService as any).applyTokenWindow(messages, 10, 'gpt-4o')
|
||||
|
||||
// Should prioritize newer messages
|
||||
expect(result[result.length - 1].content).toBe('New response')
|
||||
})
|
||||
|
||||
it('should handle invalid token limit', () => {
|
||||
const messages: Message[] = [{ role: 'user', content: 'Test' }]
|
||||
|
||||
const result = (memoryService as any).applySlidingWindowByTokens(
|
||||
messages,
|
||||
'invalid',
|
||||
'gpt-4o'
|
||||
)
|
||||
expect(result).toEqual(messages) // Should return all messages
|
||||
const result = (memoryService as any).applyTokenWindow(messages, Number.NaN, 'gpt-4o')
|
||||
expect(result).toEqual(messages)
|
||||
})
|
||||
|
||||
it('should handle zero or negative token limit', () => {
|
||||
const messages: Message[] = [{ role: 'user', content: 'Test' }]
|
||||
|
||||
const result1 = (memoryService as any).applySlidingWindowByTokens(messages, '0', 'gpt-4o')
|
||||
const result1 = (memoryService as any).applyTokenWindow(messages, 0, 'gpt-4o')
|
||||
expect(result1).toEqual(messages)
|
||||
|
||||
const result2 = (memoryService as any).applySlidingWindowByTokens(messages, '-5', 'gpt-4o')
|
||||
const result2 = (memoryService as any).applyTokenWindow(messages, -5, 'gpt-4o')
|
||||
expect(result2).toEqual(messages)
|
||||
})
|
||||
|
||||
it('should work with different model names', () => {
|
||||
it('should work without model specified', () => {
|
||||
const messages: Message[] = [{ role: 'user', content: 'Test message' }]
|
||||
|
||||
const result1 = (memoryService as any).applySlidingWindowByTokens(messages, '100', 'gpt-4o')
|
||||
expect(result1.length).toBe(1)
|
||||
|
||||
const result2 = (memoryService as any).applySlidingWindowByTokens(
|
||||
messages,
|
||||
'100',
|
||||
'claude-3-5-sonnet-20241022'
|
||||
)
|
||||
expect(result2.length).toBe(1)
|
||||
|
||||
const result3 = (memoryService as any).applySlidingWindowByTokens(messages, '100', undefined)
|
||||
expect(result3.length).toBe(1)
|
||||
const result = (memoryService as any).applyTokenWindow(messages, 100, undefined)
|
||||
expect(result.length).toBe(1)
|
||||
})
|
||||
|
||||
it('should handle empty messages array', () => {
|
||||
const messages: Message[] = []
|
||||
|
||||
const result = (memoryService as any).applySlidingWindowByTokens(messages, '100', 'gpt-4o')
|
||||
const result = (memoryService as any).applyTokenWindow(messages, 100, 'gpt-4o')
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('buildMemoryKey', () => {
|
||||
it('should build correct key with conversationId:blockId format', () => {
|
||||
const inputs: AgentInputs = {
|
||||
memoryType: 'conversation',
|
||||
conversationId: 'emir',
|
||||
}
|
||||
|
||||
const key = (memoryService as any).buildMemoryKey(mockContext, inputs, 'test-block-id')
|
||||
expect(key).toBe('emir:test-block-id')
|
||||
})
|
||||
|
||||
it('should use same key format regardless of memory type', () => {
|
||||
const conversationId = 'user-123'
|
||||
const blockId = 'block-abc'
|
||||
|
||||
const conversationKey = (memoryService as any).buildMemoryKey(
|
||||
mockContext,
|
||||
{ memoryType: 'conversation', conversationId },
|
||||
blockId
|
||||
)
|
||||
const slidingWindowKey = (memoryService as any).buildMemoryKey(
|
||||
mockContext,
|
||||
{ memoryType: 'sliding_window', conversationId },
|
||||
blockId
|
||||
)
|
||||
const slidingTokensKey = (memoryService as any).buildMemoryKey(
|
||||
mockContext,
|
||||
{ memoryType: 'sliding_window_tokens', conversationId },
|
||||
blockId
|
||||
)
|
||||
|
||||
// All should produce the same key - memory type only affects processing
|
||||
expect(conversationKey).toBe('user-123:block-abc')
|
||||
expect(slidingWindowKey).toBe('user-123:block-abc')
|
||||
expect(slidingTokensKey).toBe('user-123:block-abc')
|
||||
})
|
||||
|
||||
describe('validateConversationId', () => {
|
||||
it('should throw error for missing conversationId', () => {
|
||||
const inputs: AgentInputs = {
|
||||
memoryType: 'conversation',
|
||||
// conversationId missing
|
||||
}
|
||||
|
||||
expect(() => {
|
||||
;(memoryService as any).buildMemoryKey(mockContext, inputs, 'test-block-id')
|
||||
}).toThrow('Conversation ID is required for all memory types')
|
||||
;(memoryService as any).validateConversationId(undefined)
|
||||
}).toThrow('Conversation ID is required')
|
||||
})
|
||||
|
||||
it('should throw error for empty conversationId', () => {
|
||||
const inputs: AgentInputs = {
|
||||
memoryType: 'conversation',
|
||||
conversationId: ' ', // Only whitespace
|
||||
}
|
||||
|
||||
expect(() => {
|
||||
;(memoryService as any).buildMemoryKey(mockContext, inputs, 'test-block-id')
|
||||
}).toThrow('Conversation ID is required for all memory types')
|
||||
;(memoryService as any).validateConversationId(' ')
|
||||
}).toThrow('Conversation ID is required')
|
||||
})
|
||||
|
||||
it('should throw error for too long conversationId', () => {
|
||||
const longId = 'a'.repeat(MEMORY.MAX_CONVERSATION_ID_LENGTH + 1)
|
||||
expect(() => {
|
||||
;(memoryService as any).validateConversationId(longId)
|
||||
}).toThrow('Conversation ID too long')
|
||||
})
|
||||
|
||||
it('should accept valid conversationId', () => {
|
||||
expect(() => {
|
||||
;(memoryService as any).validateConversationId('user-123')
|
||||
}).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('validateContent', () => {
|
||||
it('should throw error for content exceeding max size', () => {
|
||||
const largeContent = 'x'.repeat(MEMORY.MAX_MESSAGE_CONTENT_BYTES + 1)
|
||||
expect(() => {
|
||||
;(memoryService as any).validateContent(largeContent)
|
||||
}).toThrow('Message content too large')
|
||||
})
|
||||
|
||||
it('should accept content within limit', () => {
|
||||
const content = 'Normal sized content'
|
||||
expect(() => {
|
||||
;(memoryService as any).validateContent(content)
|
||||
}).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Token-based vs Message-based comparison', () => {
|
||||
it('should produce different results for same message count limit', () => {
|
||||
it('should produce different results for same limit concept', () => {
|
||||
const messages: Message[] = [
|
||||
{ role: 'user', content: 'A' }, // Short message (~1 token)
|
||||
{ role: 'user', content: 'A' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'This is a much longer response that takes many more tokens',
|
||||
}, // Long message (~15 tokens)
|
||||
{ role: 'user', content: 'B' }, // Short message (~1 token)
|
||||
},
|
||||
{ role: 'user', content: 'B' },
|
||||
]
|
||||
|
||||
// Message-based: last 2 messages
|
||||
const messageResult = (memoryService as any).applySlidingWindow(messages, '2')
|
||||
const messageResult = (memoryService as any).applyWindow(messages, 2)
|
||||
expect(messageResult.length).toBe(2)
|
||||
|
||||
// Token-based: with limit of 10 tokens, might fit all 3 messages or just last 2
|
||||
const tokenResult = (memoryService as any).applySlidingWindowByTokens(
|
||||
messages,
|
||||
'10',
|
||||
'gpt-4o'
|
||||
)
|
||||
|
||||
// The long message should affect what fits
|
||||
const tokenResult = (memoryService as any).applyTokenWindow(messages, 10, 'gpt-4o')
|
||||
expect(tokenResult.length).toBeGreaterThanOrEqual(1)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,661 +1,281 @@
|
||||
import { randomUUID } from 'node:crypto'
|
||||
import { db } from '@sim/db'
|
||||
import { memory } from '@sim/db/schema'
|
||||
import { and, eq, sql } from 'drizzle-orm'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getAccurateTokenCount } from '@/lib/tokenization/estimators'
|
||||
import { MEMORY } from '@/executor/constants'
|
||||
import type { AgentInputs, Message } from '@/executor/handlers/agent/types'
|
||||
import type { ExecutionContext } from '@/executor/types'
|
||||
import { buildAPIUrl, buildAuthHeaders } from '@/executor/utils/http'
|
||||
import { stringifyJSON } from '@/executor/utils/json'
|
||||
import { PROVIDER_DEFINITIONS } from '@/providers/models'
|
||||
|
||||
const logger = createLogger('Memory')
|
||||
|
||||
/**
|
||||
* Class for managing agent conversation memory
|
||||
* Handles fetching and persisting messages to the memory table
|
||||
*/
|
||||
export class Memory {
|
||||
/**
|
||||
* Fetch messages from memory based on memoryType configuration
|
||||
*/
|
||||
async fetchMemoryMessages(
|
||||
ctx: ExecutionContext,
|
||||
inputs: AgentInputs,
|
||||
blockId: string
|
||||
): Promise<Message[]> {
|
||||
async fetchMemoryMessages(ctx: ExecutionContext, inputs: AgentInputs): Promise<Message[]> {
|
||||
if (!inputs.memoryType || inputs.memoryType === 'none') {
|
||||
return []
|
||||
}
|
||||
|
||||
if (!ctx.workflowId) {
|
||||
logger.warn('Cannot fetch memory without workflowId')
|
||||
return []
|
||||
}
|
||||
const workspaceId = this.requireWorkspaceId(ctx)
|
||||
this.validateConversationId(inputs.conversationId)
|
||||
|
||||
try {
|
||||
this.validateInputs(inputs.conversationId)
|
||||
const messages = await this.fetchMemory(workspaceId, inputs.conversationId!)
|
||||
|
||||
const memoryKey = this.buildMemoryKey(ctx, inputs, blockId)
|
||||
let messages = await this.fetchFromMemoryAPI(ctx.workflowId, memoryKey)
|
||||
switch (inputs.memoryType) {
|
||||
case 'conversation':
|
||||
return this.applyContextWindowLimit(messages, inputs.model)
|
||||
|
||||
switch (inputs.memoryType) {
|
||||
case 'conversation':
|
||||
messages = this.applyContextWindowLimit(messages, inputs.model)
|
||||
break
|
||||
|
||||
case 'sliding_window': {
|
||||
// Default to 10 messages if not specified (matches agent block default)
|
||||
const windowSize = inputs.slidingWindowSize || '10'
|
||||
messages = this.applySlidingWindow(messages, windowSize)
|
||||
break
|
||||
}
|
||||
|
||||
case 'sliding_window_tokens': {
|
||||
// Default to 4000 tokens if not specified (matches agent block default)
|
||||
const maxTokens = inputs.slidingWindowTokens || '4000'
|
||||
messages = this.applySlidingWindowByTokens(messages, maxTokens, inputs.model)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
} catch (error) {
|
||||
logger.error('Failed to fetch memory messages:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Persist assistant response to memory
|
||||
* Uses atomic append operations to prevent race conditions
|
||||
*/
|
||||
async persistMemoryMessage(
|
||||
ctx: ExecutionContext,
|
||||
inputs: AgentInputs,
|
||||
assistantMessage: Message,
|
||||
blockId: string
|
||||
): Promise<void> {
|
||||
if (!inputs.memoryType || inputs.memoryType === 'none') {
|
||||
return
|
||||
}
|
||||
|
||||
if (!ctx.workflowId) {
|
||||
logger.warn('Cannot persist memory without workflowId')
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
this.validateInputs(inputs.conversationId, assistantMessage.content)
|
||||
|
||||
const memoryKey = this.buildMemoryKey(ctx, inputs, blockId)
|
||||
|
||||
if (inputs.memoryType === 'sliding_window') {
|
||||
// Default to 10 messages if not specified (matches agent block default)
|
||||
const windowSize = inputs.slidingWindowSize || '10'
|
||||
|
||||
const existingMessages = await this.fetchFromMemoryAPI(ctx.workflowId, memoryKey)
|
||||
const updatedMessages = [...existingMessages, assistantMessage]
|
||||
const messagesToPersist = this.applySlidingWindow(updatedMessages, windowSize)
|
||||
|
||||
await this.persistToMemoryAPI(ctx.workflowId, memoryKey, messagesToPersist)
|
||||
} else if (inputs.memoryType === 'sliding_window_tokens') {
|
||||
// Default to 4000 tokens if not specified (matches agent block default)
|
||||
const maxTokens = inputs.slidingWindowTokens || '4000'
|
||||
|
||||
const existingMessages = await this.fetchFromMemoryAPI(ctx.workflowId, memoryKey)
|
||||
const updatedMessages = [...existingMessages, assistantMessage]
|
||||
const messagesToPersist = this.applySlidingWindowByTokens(
|
||||
updatedMessages,
|
||||
maxTokens,
|
||||
inputs.model
|
||||
case 'sliding_window': {
|
||||
const limit = this.parsePositiveInt(
|
||||
inputs.slidingWindowSize,
|
||||
MEMORY.DEFAULT_SLIDING_WINDOW_SIZE
|
||||
)
|
||||
|
||||
await this.persistToMemoryAPI(ctx.workflowId, memoryKey, messagesToPersist)
|
||||
} else {
|
||||
// Conversation mode: use atomic append for better concurrency
|
||||
await this.atomicAppendToMemory(ctx.workflowId, memoryKey, assistantMessage)
|
||||
return this.applyWindow(messages, limit)
|
||||
}
|
||||
|
||||
logger.debug('Successfully persisted memory message', {
|
||||
workflowId: ctx.workflowId,
|
||||
key: memoryKey,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to persist memory message:', error)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Persist user message to memory before agent execution
|
||||
*/
|
||||
async persistUserMessage(
|
||||
ctx: ExecutionContext,
|
||||
inputs: AgentInputs,
|
||||
userMessage: Message,
|
||||
blockId: string
|
||||
): Promise<void> {
|
||||
if (!inputs.memoryType || inputs.memoryType === 'none') {
|
||||
return
|
||||
}
|
||||
|
||||
if (!ctx.workflowId) {
|
||||
logger.warn('Cannot persist user message without workflowId')
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const memoryKey = this.buildMemoryKey(ctx, inputs, blockId)
|
||||
|
||||
if (inputs.slidingWindowSize && inputs.memoryType === 'sliding_window') {
|
||||
const existingMessages = await this.fetchFromMemoryAPI(ctx.workflowId, memoryKey)
|
||||
const updatedMessages = [...existingMessages, userMessage]
|
||||
const messagesToPersist = this.applySlidingWindow(updatedMessages, inputs.slidingWindowSize)
|
||||
await this.persistToMemoryAPI(ctx.workflowId, memoryKey, messagesToPersist)
|
||||
} else if (inputs.slidingWindowTokens && inputs.memoryType === 'sliding_window_tokens') {
|
||||
const existingMessages = await this.fetchFromMemoryAPI(ctx.workflowId, memoryKey)
|
||||
const updatedMessages = [...existingMessages, userMessage]
|
||||
const messagesToPersist = this.applySlidingWindowByTokens(
|
||||
updatedMessages,
|
||||
case 'sliding_window_tokens': {
|
||||
const maxTokens = this.parsePositiveInt(
|
||||
inputs.slidingWindowTokens,
|
||||
inputs.model
|
||||
MEMORY.DEFAULT_SLIDING_WINDOW_TOKENS
|
||||
)
|
||||
await this.persistToMemoryAPI(ctx.workflowId, memoryKey, messagesToPersist)
|
||||
} else {
|
||||
await this.atomicAppendToMemory(ctx.workflowId, memoryKey, userMessage)
|
||||
return this.applyTokenWindow(messages, maxTokens, inputs.model)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to persist user message:', error)
|
||||
|
||||
default:
|
||||
return messages
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build memory key based on conversationId and blockId
|
||||
* BlockId provides block-level memory isolation
|
||||
*/
|
||||
private buildMemoryKey(_ctx: ExecutionContext, inputs: AgentInputs, blockId: string): string {
|
||||
const { conversationId } = inputs
|
||||
async appendToMemory(
|
||||
ctx: ExecutionContext,
|
||||
inputs: AgentInputs,
|
||||
message: Message
|
||||
): Promise<void> {
|
||||
if (!inputs.memoryType || inputs.memoryType === 'none') {
|
||||
return
|
||||
}
|
||||
|
||||
if (!conversationId || conversationId.trim() === '') {
|
||||
throw new Error(
|
||||
'Conversation ID is required for all memory types. ' +
|
||||
'Please provide a unique identifier (e.g., user-123, session-abc, customer-456).'
|
||||
const workspaceId = this.requireWorkspaceId(ctx)
|
||||
this.validateConversationId(inputs.conversationId)
|
||||
this.validateContent(message.content)
|
||||
|
||||
const key = inputs.conversationId!
|
||||
|
||||
await this.appendMessage(workspaceId, key, message)
|
||||
|
||||
logger.debug('Appended message to memory', {
|
||||
workspaceId,
|
||||
key,
|
||||
role: message.role,
|
||||
})
|
||||
}
|
||||
|
||||
async seedMemory(ctx: ExecutionContext, inputs: AgentInputs, messages: Message[]): Promise<void> {
|
||||
if (!inputs.memoryType || inputs.memoryType === 'none') {
|
||||
return
|
||||
}
|
||||
|
||||
const workspaceId = this.requireWorkspaceId(ctx)
|
||||
|
||||
const conversationMessages = messages.filter((m) => m.role !== 'system')
|
||||
if (conversationMessages.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
this.validateConversationId(inputs.conversationId)
|
||||
|
||||
const key = inputs.conversationId!
|
||||
|
||||
let messagesToStore = conversationMessages
|
||||
if (inputs.memoryType === 'sliding_window') {
|
||||
const limit = this.parsePositiveInt(
|
||||
inputs.slidingWindowSize,
|
||||
MEMORY.DEFAULT_SLIDING_WINDOW_SIZE
|
||||
)
|
||||
messagesToStore = this.applyWindow(conversationMessages, limit)
|
||||
} else if (inputs.memoryType === 'sliding_window_tokens') {
|
||||
const maxTokens = this.parsePositiveInt(
|
||||
inputs.slidingWindowTokens,
|
||||
MEMORY.DEFAULT_SLIDING_WINDOW_TOKENS
|
||||
)
|
||||
messagesToStore = this.applyTokenWindow(conversationMessages, maxTokens, inputs.model)
|
||||
}
|
||||
|
||||
return `${conversationId}:${blockId}`
|
||||
await this.seedMemoryRecord(workspaceId, key, messagesToStore)
|
||||
|
||||
logger.debug('Seeded memory', {
|
||||
workspaceId,
|
||||
key,
|
||||
count: messagesToStore.length,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply sliding window to limit number of conversation messages
|
||||
*
|
||||
* System message handling:
|
||||
* - System messages are excluded from the sliding window count
|
||||
* - Only the first system message is preserved and placed at the start
|
||||
* - This ensures system prompts remain available while limiting conversation history
|
||||
*/
|
||||
private applySlidingWindow(messages: Message[], windowSize: string): Message[] {
|
||||
const limit = Number.parseInt(windowSize, 10)
|
||||
wrapStreamForPersistence(
|
||||
stream: ReadableStream<Uint8Array>,
|
||||
ctx: ExecutionContext,
|
||||
inputs: AgentInputs
|
||||
): ReadableStream<Uint8Array> {
|
||||
let accumulatedContent = ''
|
||||
const decoder = new TextDecoder()
|
||||
|
||||
if (Number.isNaN(limit) || limit <= 0) {
|
||||
logger.warn('Invalid sliding window size, returning all messages', { windowSize })
|
||||
return messages
|
||||
}
|
||||
const transformStream = new TransformStream<Uint8Array, Uint8Array>({
|
||||
transform: (chunk, controller) => {
|
||||
controller.enqueue(chunk)
|
||||
const decoded = decoder.decode(chunk, { stream: true })
|
||||
accumulatedContent += decoded
|
||||
},
|
||||
|
||||
const systemMessages = messages.filter((msg) => msg.role === 'system')
|
||||
const conversationMessages = messages.filter((msg) => msg.role !== 'system')
|
||||
|
||||
const recentMessages = conversationMessages.slice(-limit)
|
||||
|
||||
const firstSystemMessage = systemMessages.length > 0 ? [systemMessages[0]] : []
|
||||
|
||||
return [...firstSystemMessage, ...recentMessages]
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply token-based sliding window to limit conversation by token count
|
||||
*
|
||||
* System message handling:
|
||||
* - For consistency with message-based sliding window, the first system message is preserved
|
||||
* - System messages are excluded from the token count
|
||||
* - This ensures system prompts are always available while limiting conversation history
|
||||
*/
|
||||
private applySlidingWindowByTokens(
|
||||
messages: Message[],
|
||||
maxTokens: string,
|
||||
model?: string
|
||||
): Message[] {
|
||||
const tokenLimit = Number.parseInt(maxTokens, 10)
|
||||
|
||||
if (Number.isNaN(tokenLimit) || tokenLimit <= 0) {
|
||||
logger.warn('Invalid token limit, returning all messages', { maxTokens })
|
||||
return messages
|
||||
}
|
||||
|
||||
// Separate system messages from conversation messages for consistent handling
|
||||
const systemMessages = messages.filter((msg) => msg.role === 'system')
|
||||
const conversationMessages = messages.filter((msg) => msg.role !== 'system')
|
||||
|
||||
const result: Message[] = []
|
||||
let currentTokenCount = 0
|
||||
|
||||
// Add conversation messages from most recent backwards
|
||||
for (let i = conversationMessages.length - 1; i >= 0; i--) {
|
||||
const message = conversationMessages[i]
|
||||
const messageTokens = getAccurateTokenCount(message.content, model)
|
||||
|
||||
if (currentTokenCount + messageTokens <= tokenLimit) {
|
||||
result.unshift(message)
|
||||
currentTokenCount += messageTokens
|
||||
} else if (result.length === 0) {
|
||||
logger.warn('Single message exceeds token limit, including anyway', {
|
||||
messageTokens,
|
||||
tokenLimit,
|
||||
messageRole: message.role,
|
||||
})
|
||||
result.unshift(message)
|
||||
currentTokenCount += messageTokens
|
||||
break
|
||||
} else {
|
||||
// Token limit reached, stop processing
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('Applied token-based sliding window', {
|
||||
totalMessages: messages.length,
|
||||
conversationMessages: conversationMessages.length,
|
||||
includedMessages: result.length,
|
||||
totalTokens: currentTokenCount,
|
||||
tokenLimit,
|
||||
flush: () => {
|
||||
if (accumulatedContent.trim()) {
|
||||
this.appendToMemory(ctx, inputs, {
|
||||
role: 'assistant',
|
||||
content: accumulatedContent,
|
||||
}).catch((error) => logger.error('Failed to persist streaming response:', error))
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
// Preserve first system message and prepend to results (consistent with message-based window)
|
||||
const firstSystemMessage = systemMessages.length > 0 ? [systemMessages[0]] : []
|
||||
return [...firstSystemMessage, ...result]
|
||||
return stream.pipeThrough(transformStream)
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply context window limit based on model's maximum context window
|
||||
* Auto-trims oldest conversation messages when approaching the model's context limit
|
||||
* Uses 90% of context window (10% buffer for response)
|
||||
* Only applies if model has contextWindow defined and contextInformationAvailable !== false
|
||||
*/
|
||||
private applyContextWindowLimit(messages: Message[], model?: string): Message[] {
|
||||
if (!model) {
|
||||
return messages
|
||||
private requireWorkspaceId(ctx: ExecutionContext): string {
|
||||
if (!ctx.workspaceId) {
|
||||
throw new Error('workspaceId is required for memory operations')
|
||||
}
|
||||
return ctx.workspaceId
|
||||
}
|
||||
|
||||
  private applyWindow(messages: Message[], limit: number): Message[] {
    return messages.slice(-limit)
  }

  private applyTokenWindow(messages: Message[], maxTokens: number, model?: string): Message[] {
    const result: Message[] = []
    let tokenCount = 0

    for (let i = messages.length - 1; i >= 0; i--) {
      const msg = messages[i]
      const msgTokens = getAccurateTokenCount(msg.content, model)

      if (tokenCount + msgTokens <= maxTokens) {
        result.unshift(msg)
        tokenCount += msgTokens
      } else if (result.length === 0) {
        result.unshift(msg)
        break
      } else {
        break
      }
    }
|
||||
|
||||
let contextWindow: number | undefined
|
||||
return result
|
||||
}
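A worked trace of the token window above, with made-up per-message costs and a 50-token budget; this is an illustrative sketch only, not part of the diff.

// Messages are listed oldest -> newest; the window walks from the newest backwards.
const costs = [30, 40, 10]
const budget = 50
const kept: number[] = []
let used = 0
for (let i = costs.length - 1; i >= 0; i--) {
  if (used + costs[i] <= budget) {
    kept.unshift(costs[i])
    used += costs[i]
  } else if (kept.length === 0) {
    kept.unshift(costs[i]) // a single oversized message is still kept
    break
  } else {
    break
  }
}
// kept === [40, 10], used === 50: the oldest 30-token message is trimmed first.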
|
||||
|
||||
private applyContextWindowLimit(messages: Message[], model?: string): Message[] {
|
||||
if (!model) return messages
|
||||
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
if (provider.contextInformationAvailable === false) {
|
||||
continue
|
||||
}
|
||||
if (provider.contextInformationAvailable === false) continue
|
||||
|
||||
const matchesPattern = provider.modelPatterns?.some((pattern) => pattern.test(model))
|
||||
const matchesPattern = provider.modelPatterns?.some((p) => p.test(model))
|
||||
const matchesModel = provider.models.some((m) => m.id === model)
|
||||
|
||||
if (matchesPattern || matchesModel) {
|
||||
const modelDef = provider.models.find((m) => m.id === model)
|
||||
if (modelDef?.contextWindow) {
|
||||
contextWindow = modelDef.contextWindow
|
||||
break
|
||||
const maxTokens = Math.floor(modelDef.contextWindow * MEMORY.CONTEXT_WINDOW_UTILIZATION)
|
||||
return this.applyTokenWindow(messages, maxTokens, model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!contextWindow) {
|
||||
logger.debug('No context window information available for model, skipping auto-trim', {
|
||||
model,
|
||||
})
|
||||
return messages
|
||||
}
|
||||
|
||||
const maxTokens = Math.floor(contextWindow * 0.9)
|
||||
|
||||
logger.debug('Applying context window limit', {
|
||||
model,
|
||||
contextWindow,
|
||||
maxTokens,
|
||||
totalMessages: messages.length,
|
||||
})
|
||||
|
||||
const systemMessages = messages.filter((msg) => msg.role === 'system')
|
||||
const conversationMessages = messages.filter((msg) => msg.role !== 'system')
|
||||
|
||||
// Count tokens used by system messages first
|
||||
let systemTokenCount = 0
|
||||
for (const msg of systemMessages) {
|
||||
systemTokenCount += getAccurateTokenCount(msg.content, model)
|
||||
}
|
||||
|
||||
// Calculate remaining tokens available for conversation messages
|
||||
const remainingTokens = Math.max(0, maxTokens - systemTokenCount)
|
||||
|
||||
if (systemTokenCount >= maxTokens) {
|
||||
logger.warn('System messages exceed context window limit, including anyway', {
|
||||
systemTokenCount,
|
||||
maxTokens,
|
||||
systemMessageCount: systemMessages.length,
|
||||
})
|
||||
return systemMessages
|
||||
}
|
||||
|
||||
const result: Message[] = []
|
||||
let currentTokenCount = 0
|
||||
|
||||
for (let i = conversationMessages.length - 1; i >= 0; i--) {
|
||||
const message = conversationMessages[i]
|
||||
const messageTokens = getAccurateTokenCount(message.content, model)
|
||||
|
||||
if (currentTokenCount + messageTokens <= remainingTokens) {
|
||||
result.unshift(message)
|
||||
currentTokenCount += messageTokens
|
||||
} else if (result.length === 0) {
|
||||
logger.warn('Single message exceeds remaining context window, including anyway', {
|
||||
messageTokens,
|
||||
remainingTokens,
|
||||
systemTokenCount,
|
||||
messageRole: message.role,
|
||||
})
|
||||
result.unshift(message)
|
||||
currentTokenCount += messageTokens
|
||||
break
|
||||
} else {
|
||||
logger.info('Auto-trimmed conversation history to fit context window', {
|
||||
originalMessages: conversationMessages.length,
|
||||
trimmedMessages: result.length,
|
||||
conversationTokens: currentTokenCount,
|
||||
systemTokens: systemTokenCount,
|
||||
totalTokens: currentTokenCount + systemTokenCount,
|
||||
maxTokens,
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return [...systemMessages, ...result]
|
||||
return messages
|
||||
}
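A quick worked example of the 90% utilization rule above (the context-window size is illustrative): system-message tokens are subtracted from this budget before any conversation messages are counted.

const contextWindow = 128_000 // illustrative model limit
const maxTokens = Math.floor(contextWindow * 0.9) // 115_200 tokens available for history
const responseBuffer = contextWindow - maxTokens // 12_800 tokens reserved for the model's reply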
|
||||
|
||||
/**
|
||||
* Fetch messages from memory API
|
||||
*/
|
||||
private async fetchFromMemoryAPI(workflowId: string, key: string): Promise<Message[]> {
|
||||
try {
|
||||
const isBrowser = typeof window !== 'undefined'
|
||||
private async fetchMemory(workspaceId: string, key: string): Promise<Message[]> {
|
||||
const result = await db
|
||||
.select({ data: memory.data })
|
||||
.from(memory)
|
||||
.where(and(eq(memory.workspaceId, workspaceId), eq(memory.key, key)))
|
||||
.limit(1)
|
||||
|
||||
if (!isBrowser) {
|
||||
return await this.fetchFromMemoryDirect(workflowId, key)
|
||||
}
|
||||
if (result.length === 0) return []
|
||||
|
||||
const headers = await buildAuthHeaders()
|
||||
const url = buildAPIUrl(`/api/memory/${encodeURIComponent(key)}`, { workflowId })
|
||||
const data = result[0].data
|
||||
if (!Array.isArray(data)) return []
|
||||
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'GET',
|
||||
headers,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
if (response.status === 404) {
|
||||
return []
|
||||
}
|
||||
throw new Error(`Failed to fetch memory: ${response.status} ${response.statusText}`)
|
||||
}
|
||||
|
||||
const result = await response.json()
|
||||
|
||||
if (!result.success) {
|
||||
throw new Error(result.error || 'Failed to fetch memory')
|
||||
}
|
||||
|
||||
const memoryData = result.data?.data || result.data
|
||||
if (Array.isArray(memoryData)) {
|
||||
return memoryData.filter(
|
||||
(msg) => msg && typeof msg === 'object' && 'role' in msg && 'content' in msg
|
||||
)
|
||||
}
|
||||
|
||||
return []
|
||||
} catch (error) {
|
||||
logger.error('Error fetching from memory API:', error)
|
||||
return []
|
||||
}
|
||||
return data.filter(
|
||||
(msg): msg is Message => msg && typeof msg === 'object' && 'role' in msg && 'content' in msg
|
||||
)
|
||||
}
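The filter above uses a type predicate so the stored JSONB rows narrow to Message[]; a small self-contained sketch of the same pattern, with the shape check written out (assumed, not part of this diff):

type Message = { role: 'system' | 'user' | 'assistant'; content: string }

function toMessages(data: unknown): Message[] {
  if (!Array.isArray(data)) return []
  // The predicate return type tells TypeScript each surviving element is a Message.
  return data.filter(
    (msg): msg is Message =>
      typeof msg === 'object' && msg !== null && 'role' in msg && 'content' in msg
  )
}

// toMessages([{ role: 'user', content: 'hi' }, 'garbage', null]) -> [{ role: 'user', content: 'hi' }]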
|
||||
|
||||
/**
|
||||
* Direct database access
|
||||
*/
|
||||
private async fetchFromMemoryDirect(workflowId: string, key: string): Promise<Message[]> {
|
||||
try {
|
||||
const { db } = await import('@sim/db')
|
||||
const { memory } = await import('@sim/db/schema')
|
||||
const { and, eq } = await import('drizzle-orm')
|
||||
|
||||
const result = await db
|
||||
.select({
|
||||
data: memory.data,
|
||||
})
|
||||
.from(memory)
|
||||
.where(and(eq(memory.workflowId, workflowId), eq(memory.key, key)))
|
||||
.limit(1)
|
||||
|
||||
if (result.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
const memoryData = result[0].data as any
|
||||
if (Array.isArray(memoryData)) {
|
||||
return memoryData.filter(
|
||||
(msg) => msg && typeof msg === 'object' && 'role' in msg && 'content' in msg
|
||||
)
|
||||
}
|
||||
|
||||
return []
|
||||
} catch (error) {
|
||||
logger.error('Error fetching from memory database:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Persist messages to memory API
|
||||
*/
|
||||
private async persistToMemoryAPI(
|
||||
workflowId: string,
|
||||
private async seedMemoryRecord(
|
||||
workspaceId: string,
|
||||
key: string,
|
||||
messages: Message[]
|
||||
): Promise<void> {
|
||||
try {
|
||||
const isBrowser = typeof window !== 'undefined'
|
||||
const now = new Date()
|
||||
|
||||
if (!isBrowser) {
|
||||
await this.persistToMemoryDirect(workflowId, key, messages)
|
||||
return
|
||||
}
|
||||
|
||||
const headers = await buildAuthHeaders()
|
||||
const url = buildAPIUrl('/api/memory')
|
||||
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
...headers,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: stringifyJSON({
|
||||
workflowId,
|
||||
key,
|
||||
data: messages,
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to persist memory: ${response.status} ${response.statusText}`)
|
||||
}
|
||||
|
||||
const result = await response.json()
|
||||
|
||||
if (!result.success) {
|
||||
throw new Error(result.error || 'Failed to persist memory')
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error persisting to memory API:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Atomically append a message to memory
|
||||
*/
|
||||
private async atomicAppendToMemory(
|
||||
workflowId: string,
|
||||
key: string,
|
||||
message: Message
|
||||
): Promise<void> {
|
||||
try {
|
||||
const isBrowser = typeof window !== 'undefined'
|
||||
|
||||
if (!isBrowser) {
|
||||
await this.atomicAppendToMemoryDirect(workflowId, key, message)
|
||||
} else {
|
||||
const headers = await buildAuthHeaders()
|
||||
const url = buildAPIUrl('/api/memory')
|
||||
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
...headers,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: stringifyJSON({
|
||||
workflowId,
|
||||
key,
|
||||
data: message,
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to append memory: ${response.status} ${response.statusText}`)
|
||||
}
|
||||
|
||||
const result = await response.json()
|
||||
|
||||
if (!result.success) {
|
||||
throw new Error(result.error || 'Failed to append memory')
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error appending to memory:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct database atomic append for server-side
|
||||
* Uses PostgreSQL JSONB concatenation operator for atomic operations
|
||||
*/
|
||||
private async atomicAppendToMemoryDirect(
|
||||
workflowId: string,
|
||||
key: string,
|
||||
message: Message
|
||||
): Promise<void> {
|
||||
try {
|
||||
const { db } = await import('@sim/db')
|
||||
const { memory } = await import('@sim/db/schema')
|
||||
const { sql } = await import('drizzle-orm')
|
||||
const { randomUUID } = await import('node:crypto')
|
||||
|
||||
const now = new Date()
|
||||
const id = randomUUID()
|
||||
|
||||
await db
|
||||
.insert(memory)
|
||||
.values({
|
||||
id,
|
||||
workflowId,
|
||||
key,
|
||||
data: [message],
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
})
|
||||
.onConflictDoUpdate({
|
||||
target: [memory.workflowId, memory.key],
|
||||
set: {
|
||||
data: sql`${memory.data} || ${JSON.stringify([message])}::jsonb`,
|
||||
updatedAt: now,
|
||||
},
|
||||
})
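The || in the set clause above is PostgreSQL's JSONB concatenation operator, which turns the append into a single atomic upsert instead of a read-modify-write. A minimal hedged sketch of that shape in isolation; `db` and `memory` are assumed to be the drizzle client and table used above, with a unique index on the conflict target.

import { sql } from 'drizzle-orm'

async function appendAtomically(db: any, memory: any, workflowId: string, key: string, message: object) {
  const now = new Date()
  await db
    .insert(memory)
    .values({ id: crypto.randomUUID(), workflowId, key, data: [message], createdAt: now, updatedAt: now })
    .onConflictDoUpdate({
      target: [memory.workflowId, memory.key],
      // data || '[...]'::jsonb concatenates two JSON arrays inside a single UPDATE,
      // so concurrent appends never clobber each other's messages.
      set: { data: sql`${memory.data} || ${JSON.stringify([message])}::jsonb`, updatedAt: now },
    })
}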
|
||||
|
||||
logger.debug('Atomically appended message to memory', {
|
||||
workflowId,
|
||||
await db
|
||||
.insert(memory)
|
||||
.values({
|
||||
id: randomUUID(),
|
||||
workspaceId,
|
||||
key,
|
||||
data: messages,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error in atomic append to memory database:', error)
|
||||
throw error
|
||||
}
|
||||
.onConflictDoNothing()
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct database access for server-side persistence
|
||||
* Uses UPSERT to handle race conditions atomically
|
||||
*/
|
||||
private async persistToMemoryDirect(
|
||||
workflowId: string,
|
||||
key: string,
|
||||
messages: Message[]
|
||||
): Promise<void> {
|
||||
try {
|
||||
const { db } = await import('@sim/db')
|
||||
const { memory } = await import('@sim/db/schema')
|
||||
const { randomUUID } = await import('node:crypto')
|
||||
private async appendMessage(workspaceId: string, key: string, message: Message): Promise<void> {
|
||||
const now = new Date()
|
||||
|
||||
const now = new Date()
|
||||
const id = randomUUID()
|
||||
|
||||
await db
|
||||
.insert(memory)
|
||||
.values({
|
||||
id,
|
||||
workflowId,
|
||||
key,
|
||||
data: messages,
|
||||
createdAt: now,
|
||||
await db
|
||||
.insert(memory)
|
||||
.values({
|
||||
id: randomUUID(),
|
||||
workspaceId,
|
||||
key,
|
||||
data: [message],
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
})
|
||||
.onConflictDoUpdate({
|
||||
target: [memory.workspaceId, memory.key],
|
||||
set: {
|
||||
data: sql`${memory.data} || ${JSON.stringify([message])}::jsonb`,
|
||||
updatedAt: now,
|
||||
})
|
||||
.onConflictDoUpdate({
|
||||
target: [memory.workflowId, memory.key],
|
||||
set: {
|
||||
data: messages,
|
||||
updatedAt: now,
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error persisting to memory database:', error)
|
||||
throw error
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
  private parsePositiveInt(value: string | undefined, defaultValue: number): number {
    if (!value) return defaultValue
    const parsed = Number.parseInt(value, 10)
    if (Number.isNaN(parsed) || parsed <= 0) return defaultValue
    return parsed
  }

  private validateConversationId(conversationId?: string): void {
    if (!conversationId || conversationId.trim() === '') {
      throw new Error('Conversation ID is required')
    }
    if (conversationId.length > MEMORY.MAX_CONVERSATION_ID_LENGTH) {
      throw new Error(
        `Conversation ID too long (max ${MEMORY.MAX_CONVERSATION_ID_LENGTH} characters)`
      )
    }
  }
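A standalone copy of parsePositiveInt with a few assertions showing its fallback behaviour; the literal 10 stands in for MEMORY.DEFAULT_SLIDING_WINDOW_SIZE and the snippet is illustrative only.

function parsePositiveInt(value: string | undefined, defaultValue: number): number {
  if (!value) return defaultValue
  const parsed = Number.parseInt(value, 10)
  if (Number.isNaN(parsed) || parsed <= 0) return defaultValue
  return parsed
}

console.assert(parsePositiveInt('25', 10) === 25)      // valid positive integer passes through
console.assert(parsePositiveInt('0', 10) === 10)       // zero falls back to the default
console.assert(parsePositiveInt('abc', 10) === 10)     // non-numeric falls back to the default
console.assert(parsePositiveInt(undefined, 10) === 10) // missing input falls back to the default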
|
||||
|
||||
/**
|
||||
* Validate inputs to prevent malicious data or performance issues
|
||||
*/
|
||||
private validateInputs(conversationId?: string, content?: string): void {
|
||||
if (conversationId) {
|
||||
if (conversationId.length > 255) {
|
||||
throw new Error('Conversation ID too long (max 255 characters)')
|
||||
}
|
||||
|
||||
if (!/^[a-zA-Z0-9_\-:.@]+$/.test(conversationId)) {
|
||||
logger.warn('Conversation ID contains special characters', { conversationId })
|
||||
}
|
||||
}
|
||||
|
||||
if (content) {
|
||||
const contentSize = Buffer.byteLength(content, 'utf8')
|
||||
const MAX_CONTENT_SIZE = 100 * 1024 // 100KB
|
||||
|
||||
if (contentSize > MAX_CONTENT_SIZE) {
|
||||
throw new Error(`Message content too large (${contentSize} bytes, max ${MAX_CONTENT_SIZE})`)
|
||||
}
|
||||
private validateContent(content: string): void {
|
||||
const size = Buffer.byteLength(content, 'utf8')
|
||||
if (size > MEMORY.MAX_MESSAGE_CONTENT_BYTES) {
|
||||
throw new Error(
|
||||
`Message content too large (${size} bytes, max ${MEMORY.MAX_MESSAGE_CONTENT_BYTES})`
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ export interface ToolInput {
export interface Message {
  role: 'system' | 'user' | 'assistant'
  content: string
  executionId?: string
  function_call?: any
  tool_calls?: any[]
}
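One plausible use of the optional executionId field is deduplicating messages written by the same execution; the helper below is purely illustrative and is not part of this change.

interface Message {
  role: 'system' | 'user' | 'assistant'
  content: string
  executionId?: string
  function_call?: any
  tool_calls?: any[]
}

// Keep the first occurrence of each (executionId, role, content) triple; messages without an
// executionId are never collapsed, since there is nothing safe to key them on.
function dedupeByExecution(messages: Message[]): Message[] {
  const seen = new Set<string>()
  return messages.filter((msg) => {
    if (!msg.executionId) return true
    const key = `${msg.executionId}:${msg.role}:${msg.content}`
    if (seen.has(key)) return false
    seen.add(key)
    return true
  })
}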
@@ -1,6 +1,6 @@
 import { useQuery } from '@tanstack/react-query'
 import { createLogger } from '@/lib/logs/console/logger'
-import type { ProviderName } from '@/stores/providers/types'
+import type { OpenRouterModelInfo, ProviderName } from '@/stores/providers/types'

 const logger = createLogger('ProviderModelsQuery')

@@ -11,7 +11,12 @@ const providerEndpoints: Record<ProviderName, string> = {
   openrouter: '/api/providers/openrouter/models',
 }

-async function fetchProviderModels(provider: ProviderName): Promise<string[]> {
+interface ProviderModelsResponse {
+  models: string[]
+  modelInfo?: Record<string, OpenRouterModelInfo>
+}
+
+async function fetchProviderModels(provider: ProviderName): Promise<ProviderModelsResponse> {
   const response = await fetch(providerEndpoints[provider])

   if (!response.ok) {
@@ -24,8 +29,12 @@ async function fetchProviderModels(provider: ProviderName): Promise<string[]> {

   const data = await response.json()
   const models: string[] = Array.isArray(data.models) ? data.models : []
-  return provider === 'openrouter' ? Array.from(new Set(models)) : models
+  const uniqueModels = provider === 'openrouter' ? Array.from(new Set(models)) : models
+
+  return {
+    models: uniqueModels,
+    modelInfo: data.modelInfo,
+  }
 }

 export function useProviderModels(provider: ProviderName) {
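A hedged sketch of consuming the new return shape; it assumes fetchProviderModels and its response type are reachable from the caller, and the logging is illustrative only.

async function logOpenRouterModels(): Promise<void> {
  const { models, modelInfo } = await fetchProviderModels('openrouter')
  for (const id of models) {
    // modelInfo is optional and, per the shape above, keyed by model id.
    console.log(id, modelInfo?.[id] ?? '(no extra info)')
  }
}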
|
||||
@@ -24,20 +24,6 @@ describe('Email Validation', () => {
     expect(result.checks.disposable).toBe(false)
   })

   it.concurrent('should accept legitimate business emails', async () => {
     const legitimateEmails = [
       'test@gmail.com',
       'no-reply@yahoo.com',
       'user12345@outlook.com',
       'longusernamehere@gmail.com',
     ]

     for (const email of legitimateEmails) {
       const result = await validateEmail(email)
       expect(result.isValid).toBe(true)
     }
   })

   it.concurrent('should reject consecutive dots (RFC violation)', async () => {
     const result = await validateEmail('user..name@example.com')
     expect(result.isValid).toBe(false)
|
||||
|
||||
apps/sim/lib/workflows/executor/execute-workflow.ts (new file, 114 lines)
@@ -0,0 +1,114 @@
import { v4 as uuidv4 } from 'uuid'
import { createLogger } from '@/lib/logs/console/logger'
import { LoggingSession } from '@/lib/logs/execution/logging-session'
import { executeWorkflowCore } from '@/lib/workflows/executor/execution-core'
import { PauseResumeManager } from '@/lib/workflows/executor/human-in-the-loop-manager'
import { type ExecutionMetadata, ExecutionSnapshot } from '@/executor/execution/snapshot'

const logger = createLogger('WorkflowExecution')

export interface ExecuteWorkflowOptions {
  enabled: boolean
  selectedOutputs?: string[]
  isSecureMode?: boolean
  workflowTriggerType?: 'api' | 'chat'
  onStream?: (streamingExec: any) => Promise<void>
  onBlockComplete?: (blockId: string, output: any) => Promise<void>
  skipLoggingComplete?: boolean
}

export interface WorkflowInfo {
  id: string
  userId: string
  workspaceId?: string | null
  isDeployed?: boolean
  variables?: Record<string, any>
}

export async function executeWorkflow(
  workflow: WorkflowInfo,
  requestId: string,
  input: any | undefined,
  actorUserId: string,
  streamConfig?: ExecuteWorkflowOptions,
  providedExecutionId?: string
): Promise<any> {
  if (!workflow.workspaceId) {
    throw new Error(`Workflow ${workflow.id} has no workspaceId`)
  }

  const workflowId = workflow.id
  const workspaceId = workflow.workspaceId
  const executionId = providedExecutionId || uuidv4()
  const triggerType = streamConfig?.workflowTriggerType || 'api'
  const loggingSession = new LoggingSession(workflowId, executionId, triggerType, requestId)

  try {
    const metadata: ExecutionMetadata = {
      requestId,
      executionId,
      workflowId,
      workspaceId,
      userId: actorUserId,
      workflowUserId: workflow.userId,
      triggerType,
      useDraftState: false,
      startTime: new Date().toISOString(),
      isClientSession: false,
    }

    const snapshot = new ExecutionSnapshot(
      metadata,
      workflow,
      input,
      workflow.variables || {},
      streamConfig?.selectedOutputs || []
    )

    const result = await executeWorkflowCore({
      snapshot,
      callbacks: {
        onStream: streamConfig?.onStream,
        onBlockComplete: streamConfig?.onBlockComplete
          ? async (blockId: string, _blockName: string, _blockType: string, output: any) => {
              await streamConfig.onBlockComplete!(blockId, output)
            }
          : undefined,
      },
      loggingSession,
    })

    if (result.status === 'paused') {
      if (!result.snapshotSeed) {
        logger.error(`[${requestId}] Missing snapshot seed for paused execution`, {
          executionId,
        })
      } else {
        await PauseResumeManager.persistPauseResult({
          workflowId,
          executionId,
          pausePoints: result.pausePoints || [],
          snapshotSeed: result.snapshotSeed,
          executorUserId: result.metadata?.userId,
        })
      }
    } else {
      await PauseResumeManager.processQueuedResumes(executionId)
    }

    if (streamConfig?.skipLoggingComplete) {
      return {
        ...result,
        _streamingMetadata: {
          loggingSession,
          processedInput: input,
        },
      }
    }

    return result
  } catch (error: any) {
    logger.error(`[${requestId}] Workflow execution failed:`, error)
    throw error
  }
}
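A hedged example of calling the new helper; every id and the input payload below are placeholders, and in the app the streaming layer further down supplies its own callbacks.

async function runOnce() {
  const result = await executeWorkflow(
    { id: 'wf_123', userId: 'user_abc', workspaceId: 'ws_xyz', variables: {} },
    'req_001',            // requestId used for log correlation
    { message: 'hello' }, // workflow input
    'user_abc',           // actor user id
    { enabled: true, workflowTriggerType: 'api' }
  )
  console.log(result.status ?? 'completed')
}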
@@ -1,17 +1,23 @@
|
||||
import {
|
||||
extractBlockIdFromOutputId,
|
||||
extractPathFromOutputId,
|
||||
traverseObjectPath,
|
||||
} from '@/lib/core/utils/response-format'
|
||||
import { encodeSSE } from '@/lib/core/utils/sse'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { buildTraceSpans } from '@/lib/logs/execution/trace-spans/trace-spans'
|
||||
import { processStreamingBlockLogs } from '@/lib/tokenization'
|
||||
import { executeWorkflow } from '@/lib/workflows/executor/execute-workflow'
|
||||
import type { ExecutionResult } from '@/executor/types'
|
||||
|
||||
const logger = createLogger('WorkflowStreaming')
|
||||
|
||||
const DANGEROUS_KEYS = ['__proto__', 'constructor', 'prototype']
|
||||
|
||||
export interface StreamingConfig {
|
||||
selectedOutputs?: string[]
|
||||
isSecureMode?: boolean
|
||||
workflowTriggerType?: 'api' | 'chat'
|
||||
onStream?: (streamingExec: {
|
||||
stream: ReadableStream
|
||||
execution?: { blockId?: string }
|
||||
}) => Promise<void>
|
||||
}
|
||||
|
||||
export interface StreamingResponseOptions {
|
||||
@@ -26,109 +32,219 @@ export interface StreamingResponseOptions {
|
||||
input: any
|
||||
executingUserId: string
|
||||
streamConfig: StreamingConfig
|
||||
createFilteredResult: (result: ExecutionResult) => any
|
||||
executionId?: string
|
||||
}
|
||||
|
||||
interface StreamingState {
|
||||
streamedContent: Map<string, string>
|
||||
processedOutputs: Set<string>
|
||||
streamCompletionTimes: Map<string, number>
|
||||
}
|
||||
|
||||
function extractOutputValue(output: any, path: string): any {
|
||||
let value = traverseObjectPath(output, path)
|
||||
if (value === undefined && output?.response) {
|
||||
value = traverseObjectPath(output.response, path)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
function isDangerousKey(key: string): boolean {
|
||||
return DANGEROUS_KEYS.includes(key)
|
||||
}
|
||||
|
||||
function buildMinimalResult(
|
||||
result: ExecutionResult,
|
||||
selectedOutputs: string[] | undefined,
|
||||
streamedContent: Map<string, string>,
|
||||
requestId: string
|
||||
): { success: boolean; error?: string; output: Record<string, any> } {
|
||||
const minimalResult = {
|
||||
success: result.success,
|
||||
error: result.error,
|
||||
output: {} as Record<string, any>,
|
||||
}
|
||||
|
||||
if (!selectedOutputs?.length) {
|
||||
minimalResult.output = result.output || {}
|
||||
return minimalResult
|
||||
}
|
||||
|
||||
if (!result.output || !result.logs) {
|
||||
return minimalResult
|
||||
}
|
||||
|
||||
for (const outputId of selectedOutputs) {
|
||||
const blockId = extractBlockIdFromOutputId(outputId)
|
||||
|
||||
if (streamedContent.has(blockId)) {
|
||||
continue
|
||||
}
|
||||
|
||||
if (isDangerousKey(blockId)) {
|
||||
logger.warn(`[${requestId}] Blocked dangerous blockId: ${blockId}`)
|
||||
continue
|
||||
}
|
||||
|
||||
const path = extractPathFromOutputId(outputId, blockId)
|
||||
if (isDangerousKey(path)) {
|
||||
logger.warn(`[${requestId}] Blocked dangerous path: ${path}`)
|
||||
continue
|
||||
}
|
||||
|
||||
const blockLog = result.logs.find((log: any) => log.blockId === blockId)
|
||||
if (!blockLog?.output) {
|
||||
continue
|
||||
}
|
||||
|
||||
const value = extractOutputValue(blockLog.output, path)
|
||||
if (value === undefined) {
|
||||
continue
|
||||
}
|
||||
|
||||
if (!minimalResult.output[blockId]) {
|
||||
minimalResult.output[blockId] = Object.create(null)
|
||||
}
|
||||
minimalResult.output[blockId][path] = value
|
||||
}
|
||||
|
||||
return minimalResult
|
||||
}
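For context, a selected output id is split into a block id and a path by the helpers imported above; the small sketch below resolves one path against a block log output, first directly and then via the output.response fallback. The id format and the data are invented for illustration.

const blockLogOutput = { content: 'final answer', response: { tokens: { total: 123 } } }

// traverseObjectPath-style lookup: try the output itself, then fall back to output.response.
function resolve(output: any, path: string): any {
  const direct = path.split('.').reduce((acc, key) => acc?.[key], output)
  if (direct !== undefined) return direct
  return path.split('.').reduce((acc, key) => acc?.[key], output?.response)
}

resolve(blockLogOutput, 'content')      // -> 'final answer'
resolve(blockLogOutput, 'tokens.total') // -> 123 (found via the response fallback)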
|
||||
|
||||
function updateLogsWithStreamedContent(logs: any[], state: StreamingState): any[] {
|
||||
return logs.map((log: any) => {
|
||||
if (!state.streamedContent.has(log.blockId)) {
|
||||
return log
|
||||
}
|
||||
|
||||
const content = state.streamedContent.get(log.blockId)
|
||||
const updatedLog = { ...log }
|
||||
|
||||
if (state.streamCompletionTimes.has(log.blockId)) {
|
||||
const completionTime = state.streamCompletionTimes.get(log.blockId)!
|
||||
const startTime = new Date(log.startedAt).getTime()
|
||||
updatedLog.endedAt = new Date(completionTime).toISOString()
|
||||
updatedLog.durationMs = completionTime - startTime
|
||||
}
|
||||
|
||||
if (log.output && content) {
|
||||
updatedLog.output = { ...log.output, content }
|
||||
}
|
||||
|
||||
return updatedLog
|
||||
})
|
||||
}
|
||||
|
||||
async function completeLoggingSession(result: ExecutionResult): Promise<void> {
|
||||
if (!result._streamingMetadata?.loggingSession) {
|
||||
return
|
||||
}
|
||||
|
||||
const { traceSpans, totalDuration } = buildTraceSpans(result)
|
||||
|
||||
await result._streamingMetadata.loggingSession.safeComplete({
|
||||
endedAt: new Date().toISOString(),
|
||||
totalDurationMs: totalDuration || 0,
|
||||
finalOutput: result.output || {},
|
||||
traceSpans: (traceSpans || []) as any,
|
||||
workflowInput: result._streamingMetadata.processedInput,
|
||||
})
|
||||
|
||||
result._streamingMetadata = undefined
|
||||
}
|
||||
|
||||
export async function createStreamingResponse(
|
||||
options: StreamingResponseOptions
|
||||
): Promise<ReadableStream> {
|
||||
const {
|
||||
requestId,
|
||||
workflow,
|
||||
input,
|
||||
executingUserId,
|
||||
streamConfig,
|
||||
createFilteredResult,
|
||||
executionId,
|
||||
} = options
|
||||
|
||||
const { executeWorkflow, createFilteredResult: defaultFilteredResult } = await import(
|
||||
'@/app/api/workflows/[id]/execute/route'
|
||||
)
|
||||
const filterResultFn = createFilteredResult || defaultFilteredResult
|
||||
const { requestId, workflow, input, executingUserId, streamConfig, executionId } = options
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
const streamedContent = new Map<string, string>()
|
||||
const processedOutputs = new Set<string>()
|
||||
const streamCompletionTimes = new Map<string, number>()
|
||||
const state: StreamingState = {
|
||||
streamedContent: new Map(),
|
||||
processedOutputs: new Set(),
|
||||
streamCompletionTimes: new Map(),
|
||||
}
|
||||
|
||||
const sendChunk = (blockId: string, content: string) => {
|
||||
const separator = processedOutputs.size > 0 ? '\n\n' : ''
|
||||
controller.enqueue(encodeSSE({ blockId, chunk: separator + content }))
|
||||
processedOutputs.add(blockId)
|
||||
const sendChunk = (blockId: string, content: string) => {
|
||||
const separator = state.processedOutputs.size > 0 ? '\n\n' : ''
|
||||
controller.enqueue(encodeSSE({ blockId, chunk: separator + content }))
|
||||
state.processedOutputs.add(blockId)
|
||||
}
|
||||
|
||||
const onStreamCallback = async (streamingExec: {
|
||||
stream: ReadableStream
|
||||
execution?: { blockId?: string }
|
||||
}) => {
|
||||
const blockId = streamingExec.execution?.blockId
|
||||
if (!blockId) {
|
||||
logger.warn(`[${requestId}] Streaming execution missing blockId`)
|
||||
return
|
||||
}
|
||||
|
||||
const onStreamCallback = async (streamingExec: {
|
||||
stream: ReadableStream
|
||||
execution?: { blockId?: string }
|
||||
}) => {
|
||||
const blockId = streamingExec.execution?.blockId || 'unknown'
|
||||
const reader = streamingExec.stream.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let isFirstChunk = true
|
||||
const reader = streamingExec.stream.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let isFirstChunk = true
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
// Record when this stream completed
|
||||
streamCompletionTimes.set(blockId, Date.now())
|
||||
break
|
||||
}
|
||||
|
||||
const textChunk = decoder.decode(value, { stream: true })
|
||||
streamedContent.set(blockId, (streamedContent.get(blockId) || '') + textChunk)
|
||||
|
||||
if (isFirstChunk) {
|
||||
sendChunk(blockId, textChunk)
|
||||
isFirstChunk = false
|
||||
} else {
|
||||
controller.enqueue(encodeSSE({ blockId, chunk: textChunk }))
|
||||
}
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
state.streamCompletionTimes.set(blockId, Date.now())
|
||||
break
|
||||
}
|
||||
} catch (streamError) {
|
||||
logger.error(`[${requestId}] Error reading agent stream:`, streamError)
|
||||
controller.enqueue(
|
||||
encodeSSE({
|
||||
event: 'stream_error',
|
||||
blockId,
|
||||
error: streamError instanceof Error ? streamError.message : 'Stream reading error',
|
||||
})
|
||||
|
||||
const textChunk = decoder.decode(value, { stream: true })
|
||||
state.streamedContent.set(
|
||||
blockId,
|
||||
(state.streamedContent.get(blockId) || '') + textChunk
|
||||
)
|
||||
|
||||
if (isFirstChunk) {
|
||||
sendChunk(blockId, textChunk)
|
||||
isFirstChunk = false
|
||||
} else {
|
||||
controller.enqueue(encodeSSE({ blockId, chunk: textChunk }))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const onBlockCompleteCallback = async (blockId: string, output: any) => {
|
||||
if (!streamConfig.selectedOutputs?.length) return
|
||||
|
||||
const { extractBlockIdFromOutputId, extractPathFromOutputId, traverseObjectPath } =
|
||||
await import('@/lib/core/utils/response-format')
|
||||
|
||||
const matchingOutputs = streamConfig.selectedOutputs.filter(
|
||||
(outputId) => extractBlockIdFromOutputId(outputId) === blockId
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error reading stream for block ${blockId}:`, error)
|
||||
controller.enqueue(
|
||||
encodeSSE({
|
||||
event: 'stream_error',
|
||||
blockId,
|
||||
error: error instanceof Error ? error.message : 'Stream reading error',
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (!matchingOutputs.length) return
|
||||
|
||||
for (const outputId of matchingOutputs) {
|
||||
const path = extractPathFromOutputId(outputId, blockId)
|
||||
|
||||
let outputValue = traverseObjectPath(output, path)
|
||||
if (outputValue === undefined && output.response) {
|
||||
outputValue = traverseObjectPath(output.response, path)
|
||||
}
|
||||
|
||||
if (outputValue !== undefined) {
|
||||
const formattedOutput =
|
||||
typeof outputValue === 'string' ? outputValue : JSON.stringify(outputValue, null, 2)
|
||||
sendChunk(blockId, formattedOutput)
|
||||
}
|
||||
}
|
||||
const onBlockCompleteCallback = async (blockId: string, output: any) => {
|
||||
if (!streamConfig.selectedOutputs?.length) {
|
||||
return
|
||||
}
|
||||
|
||||
if (state.streamedContent.has(blockId)) {
|
||||
return
|
||||
}
|
||||
|
||||
const matchingOutputs = streamConfig.selectedOutputs.filter(
|
||||
(outputId) => extractBlockIdFromOutputId(outputId) === blockId
|
||||
)
|
||||
|
||||
for (const outputId of matchingOutputs) {
|
||||
const path = extractPathFromOutputId(outputId, blockId)
|
||||
const outputValue = extractOutputValue(output, path)
|
||||
|
||||
if (outputValue !== undefined) {
|
||||
const formattedOutput =
|
||||
typeof outputValue === 'string' ? outputValue : JSON.stringify(outputValue, null, 2)
|
||||
sendChunk(blockId, formattedOutput)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await executeWorkflow(
|
||||
workflow,
|
||||
requestId,
|
||||
@@ -141,97 +257,24 @@ export async function createStreamingResponse(
|
||||
workflowTriggerType: streamConfig.workflowTriggerType,
|
||||
onStream: onStreamCallback,
|
||||
onBlockComplete: onBlockCompleteCallback,
|
||||
skipLoggingComplete: true, // We'll complete logging after tokenization
|
||||
skipLoggingComplete: true,
|
||||
},
|
||||
executionId
|
||||
)
|
||||
|
||||
if (result.logs && streamedContent.size > 0) {
|
||||
result.logs = result.logs.map((log: any) => {
|
||||
if (streamedContent.has(log.blockId)) {
|
||||
const content = streamedContent.get(log.blockId)
|
||||
|
||||
// Update timing to reflect actual stream completion
|
||||
if (streamCompletionTimes.has(log.blockId)) {
|
||||
const completionTime = streamCompletionTimes.get(log.blockId)!
|
||||
const startTime = new Date(log.startedAt).getTime()
|
||||
log.endedAt = new Date(completionTime).toISOString()
|
||||
log.durationMs = completionTime - startTime
|
||||
}
|
||||
|
||||
if (log.output && content) {
|
||||
return { ...log, output: { ...log.output, content } }
|
||||
}
|
||||
}
|
||||
return log
|
||||
})
|
||||
|
||||
const { processStreamingBlockLogs } = await import('@/lib/tokenization')
|
||||
processStreamingBlockLogs(result.logs, streamedContent)
|
||||
if (result.logs && state.streamedContent.size > 0) {
|
||||
result.logs = updateLogsWithStreamedContent(result.logs, state)
|
||||
processStreamingBlockLogs(result.logs, state.streamedContent)
|
||||
}
|
||||
|
||||
// Complete the logging session with updated trace spans that include cost data
|
||||
if (result._streamingMetadata?.loggingSession) {
|
||||
const { buildTraceSpans } = await import('@/lib/logs/execution/trace-spans/trace-spans')
|
||||
const { traceSpans, totalDuration } = buildTraceSpans(result)
|
||||
await completeLoggingSession(result)
|
||||
|
||||
await result._streamingMetadata.loggingSession.safeComplete({
|
||||
endedAt: new Date().toISOString(),
|
||||
totalDurationMs: totalDuration || 0,
|
||||
finalOutput: result.output || {},
|
||||
traceSpans: (traceSpans || []) as any,
|
||||
workflowInput: result._streamingMetadata.processedInput,
|
||||
})
|
||||
|
||||
result._streamingMetadata = undefined
|
||||
}
|
||||
|
||||
// Create a minimal result with only selected outputs
|
||||
const minimalResult = {
|
||||
success: result.success,
|
||||
error: result.error,
|
||||
output: {} as any,
|
||||
}
|
||||
|
||||
if (streamConfig.selectedOutputs?.length && result.output) {
|
||||
const { extractBlockIdFromOutputId, extractPathFromOutputId, traverseObjectPath } =
|
||||
await import('@/lib/core/utils/response-format')
|
||||
|
||||
for (const outputId of streamConfig.selectedOutputs) {
|
||||
const blockId = extractBlockIdFromOutputId(outputId)
|
||||
const path = extractPathFromOutputId(outputId, blockId)
|
||||
|
||||
if (result.logs) {
|
||||
const blockLog = result.logs.find((log: any) => log.blockId === blockId)
|
||||
if (blockLog?.output) {
|
||||
let value = traverseObjectPath(blockLog.output, path)
|
||||
if (value === undefined && blockLog.output.response) {
|
||||
value = traverseObjectPath(blockLog.output.response, path)
|
||||
}
|
||||
if (value !== undefined) {
|
||||
const dangerousKeys = ['__proto__', 'constructor', 'prototype']
|
||||
if (dangerousKeys.includes(blockId) || dangerousKeys.includes(path)) {
|
||||
logger.warn(
|
||||
`[${requestId}] Blocked potentially dangerous property assignment`,
|
||||
{
|
||||
blockId,
|
||||
path,
|
||||
}
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if (!minimalResult.output[blockId]) {
|
||||
minimalResult.output[blockId] = Object.create(null)
|
||||
}
|
||||
minimalResult.output[blockId][path] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (!streamConfig.selectedOutputs?.length) {
|
||||
minimalResult.output = result.output
|
||||
}
|
||||
const minimalResult = buildMinimalResult(
|
||||
result,
|
||||
streamConfig.selectedOutputs,
|
||||
state.streamedContent,
|
||||
requestId
|
||||
)
|
||||
|
||||
controller.enqueue(encodeSSE({ event: 'final', data: minimalResult }))
|
||||
controller.enqueue(encodeSSE('[DONE]'))
|
||||
|
||||
@@ -37,6 +37,7 @@
    "@browserbasehq/stagehand": "^3.0.5",
    "@cerebras/cerebras_cloud_sdk": "^1.23.0",
    "@e2b/code-interpreter": "^2.0.0",
    "@google/genai": "1.34.0",
    "@hookform/resolvers": "^4.1.3",
    "@opentelemetry/api": "^1.9.0",
    "@opentelemetry/exporter-jaeger": "2.1.0",
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,22 +1,49 @@
|
||||
import type {
|
||||
RawMessageDeltaEvent,
|
||||
RawMessageStartEvent,
|
||||
RawMessageStreamEvent,
|
||||
Usage,
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { trackForcedToolUsage } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('AnthropicUtils')
|
||||
|
||||
/**
|
||||
* Helper to wrap Anthropic streaming into a browser-friendly ReadableStream
|
||||
*/
|
||||
export interface AnthropicStreamUsage {
|
||||
input_tokens: number
|
||||
output_tokens: number
|
||||
}
|
||||
|
||||
export function createReadableStreamFromAnthropicStream(
|
||||
anthropicStream: AsyncIterable<any>
|
||||
): ReadableStream {
|
||||
anthropicStream: AsyncIterable<RawMessageStreamEvent>,
|
||||
onComplete?: (content: string, usage: AnthropicStreamUsage) => void
|
||||
): ReadableStream<Uint8Array> {
|
||||
let fullContent = ''
|
||||
let inputTokens = 0
|
||||
let outputTokens = 0
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === 'content_block_delta' && event.delta?.text) {
|
||||
controller.enqueue(new TextEncoder().encode(event.delta.text))
|
||||
if (event.type === 'message_start') {
|
||||
const startEvent = event as RawMessageStartEvent
|
||||
const usage: Usage = startEvent.message.usage
|
||||
inputTokens = usage.input_tokens
|
||||
} else if (event.type === 'message_delta') {
|
||||
const deltaEvent = event as RawMessageDeltaEvent
|
||||
outputTokens = deltaEvent.usage.output_tokens
|
||||
} else if (event.type === 'content_block_delta' && event.delta.type === 'text_delta') {
|
||||
const text = event.delta.text
|
||||
fullContent += text
|
||||
controller.enqueue(new TextEncoder().encode(text))
|
||||
}
|
||||
}
|
||||
|
||||
if (onComplete) {
|
||||
onComplete(fullContent, { input_tokens: inputTokens, output_tokens: outputTokens })
|
||||
}
|
||||
|
||||
controller.close()
|
||||
} catch (err) {
|
||||
controller.error(err)
|
||||
@@ -25,16 +52,10 @@ export function createReadableStreamFromAnthropicStream(
|
||||
})
|
||||
}
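A hedged usage sketch of the updated wrapper: the second argument now receives the accumulated text plus the input/output token counts once the Anthropic stream finishes. The anthropicStream variable stands in for the AsyncIterable returned by the SDK's streaming call.

const readable = createReadableStreamFromAnthropicStream(anthropicStream, (content, usage) => {
  // Called once the stream ends; usage carries the counts captured from the
  // message_start and message_delta events.
  console.log('final content length:', content.length)
  console.log('tokens:', usage.input_tokens, usage.output_tokens)
})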
|
||||
|
||||
/**
|
||||
* Helper function to generate a simple unique ID for tool uses
|
||||
*/
|
||||
export function generateToolUseId(toolName: string): string {
|
||||
return `${toolName}-${Date.now()}-${Math.random().toString(36).substring(2, 7)}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to check for forced tool usage in Anthropic responses
|
||||
*/
|
||||
export function checkForForcedToolUsage(
|
||||
response: any,
|
||||
toolChoice: any,
|
||||
@@ -45,16 +66,11 @@ export function checkForForcedToolUsage(
|
||||
const toolUses = response.content.filter((item: any) => item.type === 'tool_use')
|
||||
|
||||
if (toolUses.length > 0) {
|
||||
// Convert Anthropic tool_use format to a format trackForcedToolUsage can understand
|
||||
const adaptedToolCalls = toolUses.map((tool: any) => ({
|
||||
name: tool.name,
|
||||
}))
|
||||
|
||||
// Convert Anthropic tool_choice format to match OpenAI format for tracking
|
||||
const adaptedToolCalls = toolUses.map((tool: any) => ({ name: tool.name }))
|
||||
const adaptedToolChoice =
|
||||
toolChoice.type === 'tool' ? { function: { name: toolChoice.name } } : toolChoice
|
||||
|
||||
const result = trackForcedToolUsage(
|
||||
return trackForcedToolUsage(
|
||||
adaptedToolCalls,
|
||||
adaptedToolChoice,
|
||||
logger,
|
||||
@@ -62,8 +78,6 @@ export function checkForForcedToolUsage(
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
return null
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { AzureOpenAI } from 'openai'
|
||||
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
@@ -14,7 +15,11 @@ import type {
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('AzureOpenAIProvider')
|
||||
@@ -43,8 +48,6 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
stream: !!request.stream,
|
||||
})
|
||||
|
||||
// Extract Azure-specific configuration from request or environment
|
||||
// Priority: request parameters > environment variables
|
||||
const azureEndpoint = request.azureEndpoint || env.AZURE_OPENAI_ENDPOINT
|
||||
const azureApiVersion =
|
||||
request.azureApiVersion || env.AZURE_OPENAI_API_VERSION || '2024-07-01-preview'
|
||||
@@ -55,17 +58,14 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
)
|
||||
}
|
||||
|
||||
// API key is now handled server-side before this function is called
|
||||
const azureOpenAI = new AzureOpenAI({
|
||||
apiKey: request.apiKey,
|
||||
apiVersion: azureApiVersion,
|
||||
endpoint: azureEndpoint,
|
||||
})
|
||||
|
||||
// Start with an empty array for all messages
|
||||
const allMessages = []
|
||||
|
||||
// Add system prompt if present
|
||||
if (request.systemPrompt) {
|
||||
allMessages.push({
|
||||
role: 'system',
|
||||
@@ -73,7 +73,6 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add context if present
|
||||
if (request.context) {
|
||||
allMessages.push({
|
||||
role: 'user',
|
||||
@@ -81,12 +80,10 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
}
|
||||
|
||||
// Transform tools to Azure OpenAI format if provided
|
||||
const tools = request.tools?.length
|
||||
? request.tools.map((tool) => ({
|
||||
type: 'function',
|
||||
@@ -98,24 +95,19 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
}))
|
||||
: undefined
|
||||
|
||||
// Build the request payload - use deployment name instead of model name
|
||||
const deploymentName = (request.model || 'azure/gpt-4o').replace('azure/', '')
|
||||
const payload: any = {
|
||||
model: deploymentName, // Azure OpenAI uses deployment name
|
||||
model: deploymentName,
|
||||
messages: allMessages,
|
||||
}
|
||||
|
||||
// Add optional parameters
|
||||
if (request.temperature !== undefined) payload.temperature = request.temperature
|
||||
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
|
||||
|
||||
// Add GPT-5 specific parameters
|
||||
if (request.reasoningEffort !== undefined) payload.reasoning_effort = request.reasoningEffort
|
||||
if (request.verbosity !== undefined) payload.verbosity = request.verbosity
|
||||
|
||||
// Add response format for structured output if specified
|
||||
if (request.responseFormat) {
|
||||
// Use Azure OpenAI's JSON schema format
|
||||
payload.response_format = {
|
||||
type: 'json_schema',
|
||||
json_schema: {
|
||||
@@ -128,7 +120,6 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
logger.info('Added JSON schema response format to Azure OpenAI request')
|
||||
}
|
||||
|
||||
// Handle tools and tool usage control
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
|
||||
if (tools?.length) {
|
||||
@@ -156,39 +147,40 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
// Start execution timer for the entire provider execution
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
// Check if we can stream directly (no tools required)
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info('Using streaming response for Azure OpenAI request')
|
||||
|
||||
// Create a streaming request with token usage tracking
|
||||
const streamResponse = await azureOpenAI.chat.completions.create({
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
})
|
||||
|
||||
// Start collecting token usage from the stream
|
||||
const tokenUsage = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
const streamResponse = await azureOpenAI.chat.completions.create(streamingParams)
|
||||
|
||||
let _streamContent = ''
|
||||
|
||||
// Create a StreamingExecution response with a callback to update content and tokens
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromAzureOpenAIStream(streamResponse, (content, usage) => {
|
||||
// Update the execution data with the final content and token usage
|
||||
_streamContent = content
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.prompt_tokens,
|
||||
completion: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
|
||||
// Update the timing information with the actual completion time
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
@@ -197,7 +189,6 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
|
||||
// Update the time segment as well
|
||||
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
|
||||
streamEndTime
|
||||
@@ -205,25 +196,13 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
|
||||
// Update token usage if available from the stream
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokenUsage.prompt,
|
||||
completion: usage.completion_tokens || tokenUsage.completion,
|
||||
total: usage.total_tokens || tokenUsage.total,
|
||||
}
|
||||
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
}
|
||||
// We don't need to estimate tokens here as logger.ts will handle that
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by the stream completion callback
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: tokenUsage,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
@@ -239,9 +218,9 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
// Cost will be calculated in logger
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [], // No block logs for direct streaming
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -250,17 +229,11 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
},
|
||||
} as StreamingExecution
|
||||
|
||||
// Return the streaming execution object with explicit casting
|
||||
return streamingResult as StreamingExecution
|
||||
}
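The cost block above is filled in by calculateCost once the stream reports usage; a minimal sketch of that call in isolation, with made-up token counts (the per-model pricing itself lives in the provider utils).

const cost = calculateCost('azure/gpt-4o', 1200, 350)
// cost.input and cost.output are the per-direction amounts, cost.total their sum.
console.log(cost.input, cost.output, cost.total)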
|
||||
|
||||
// Make the initial API request
|
||||
const initialCallTime = Date.now()
|
||||
|
||||
// Track the original tool_choice for forced tool tracking
|
||||
const originalToolChoice = payload.tool_choice
|
||||
|
||||
// Track forced tools and their usage
|
||||
const forcedTools = preparedTools?.forcedTools || []
|
||||
let usedForcedTools: string[] = []
|
||||
|
||||
@@ -268,7 +241,6 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
let content = currentResponse.choices[0]?.message?.content || ''
|
||||
// Collect token information but don't calculate costs - that will be done in logger.ts
|
||||
const tokens = {
|
||||
prompt: currentResponse.usage?.prompt_tokens || 0,
|
||||
completion: currentResponse.usage?.completion_tokens || 0,
|
||||
@@ -278,15 +250,10 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
|
||||
// Track if a forced tool has been used
|
||||
let hasUsedForcedTool = false
|
||||
|
||||
// Track each model and tool call segment with timestamps
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
@@ -297,7 +264,6 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
},
|
||||
]
|
||||
|
||||
// Check if a forced tool was used in the first response
|
||||
const firstCheckResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
originalToolChoice,
|
||||
@@ -309,7 +275,10 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
usedForcedTools = firstCheckResult.usedForcedTools
|
||||
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
@@ -319,126 +288,135 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
// Track time for tool calls in this batch
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
// Process each tool call
|
||||
for (const toolCall of toolCallsInResponse) {
|
||||
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
try {
|
||||
const toolName = toolCall.function.name
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
// Get the tool from the tools registry
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) continue
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
if (!tool) return null
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
// Add to time segments for both success and failure
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams,
|
||||
result,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
// Prepare result content for the LLM
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
// Include error information so LLM can respond appropriately
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
// Add the tool call and result to messages (both success and failure)
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: [
|
||||
{
|
||||
id: toolCall.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: toolName,
|
||||
arguments: toolCall.function.arguments,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error processing tool call:', {
|
||||
error,
|
||||
toolName: toolCall?.function?.name,
|
||||
})
|
||||
const toolCallEndTime = Date.now()
|
||||
logger.error('Error processing tool call:', { error, toolName })
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams: {},
|
||||
result: {
|
||||
success: false,
|
||||
output: undefined,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
},
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const executionResults = await Promise.allSettled(toolExecutionPromises)
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: toolCallsInResponse.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
},
|
||||
})),
|
||||
})
|
||||
|
||||
for (const settledResult of executionResults) {
|
||||
if (settledResult.status === 'rejected' || !settledResult.value) continue
|
||||
|
||||
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
|
||||
settledResult.value
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: startTime,
|
||||
endTime: endTime,
|
||||
duration: duration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(startTime).toISOString(),
|
||||
endTime: new Date(endTime).toISOString(),
|
||||
duration: duration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
}
|
||||
|
||||
// Calculate tool call time for this iteration
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
// Make the next request with updated messages
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
}
|
||||
|
||||
// Update tool_choice based on which forced tools have been used
|
||||
if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) {
|
||||
// If we have remaining forced tools, get the next one to force
|
||||
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
|
||||
|
||||
if (remainingTools.length > 0) {
|
||||
// Force the next tool
|
||||
nextPayload.tool_choice = {
|
||||
type: 'function',
|
||||
function: { name: remainingTools[0] },
|
||||
}
|
||||
logger.info(`Forcing next tool: ${remainingTools[0]}`)
|
||||
} else {
|
||||
// All forced tools have been used, switch to auto
|
||||
nextPayload.tool_choice = 'auto'
|
||||
logger.info('All forced tools have been used, switching to auto tool_choice')
|
||||
}
|
||||
}
|
||||
|
||||
// Time the next model call
|
||||
const nextModelStartTime = Date.now()
|
||||
|
||||
// Make the next request
|
||||
currentResponse = await azureOpenAI.chat.completions.create(nextPayload)
|
||||
|
||||
// Check if any forced tools were used in this response
|
||||
const nextCheckResult = checkForForcedToolUsage(
|
||||
currentResponse,
|
||||
nextPayload.tool_choice,
|
||||
@@ -452,7 +430,6 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
@@ -461,15 +438,12 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
// Add to model time
|
||||
modelTime += thisModelTime
|
||||
|
||||
// Update content if we have a text response
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
// Update token counts
|
||||
if (currentResponse.usage) {
|
||||
tokens.prompt += currentResponse.usage.prompt_tokens || 0
|
||||
tokens.completion += currentResponse.usage.completion_tokens || 0
|
||||
@@ -479,46 +453,44 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
iterationCount++
|
||||
}
|
||||
|
||||
// After all tool processing complete, if streaming was requested, use streaming for the final response
|
||||
if (request.stream) {
|
||||
logger.info('Using streaming for final response after tool processing')
|
||||
|
||||
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
|
||||
// This prevents Azure OpenAI API from trying to force tool usage again in the final streaming response
|
||||
const streamingPayload = {
|
||||
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
|
||||
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto', // Always use 'auto' for the streaming response after tool calls
|
||||
tool_choice: 'auto',
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
|
||||
const streamResponse = await azureOpenAI.chat.completions.create(streamingPayload)
|
||||
|
||||
// Create the StreamingExecution object with all collected data
|
||||
let _streamContent = ''
|
||||
const streamResponse = await azureOpenAI.chat.completions.create(streamingParams)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromAzureOpenAIStream(streamResponse, (content, usage) => {
|
||||
// Update the execution data with the final content and token usage
|
||||
_streamContent = content
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: tokens.prompt + usage.prompt_tokens,
|
||||
completion: tokens.completion + usage.completion_tokens,
|
||||
total: tokens.total + usage.total_tokens,
|
||||
}
|
||||
|
||||
// Update token usage if available from the stream
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokens.prompt,
|
||||
completion: usage.completion_tokens || tokens.completion,
|
||||
total: usage.total_tokens || tokens.total,
|
||||
}
|
||||
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by the callback
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
@@ -542,9 +514,13 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
// Cost will be calculated in logger
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [], // No block logs at provider level
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -553,11 +529,9 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
},
|
||||
} as StreamingExecution
|
||||
|
||||
// Return the streaming execution object with explicit casting
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
// Calculate overall timing
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -578,10 +552,8 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
// We're not calculating cost here as it will be handled in logger.ts
|
||||
}
|
||||
} catch (error) {
|
||||
// Include timing information even for errors
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -591,9 +563,8 @@ export const azureOpenAIProvider: ProviderConfig = {
|
||||
duration: totalDuration,
|
||||
})
|
||||
|
||||
// Create a new error with timing information
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
// @ts-ignore
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
|
||||
@@ -1,70 +1,37 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import type { Stream } from 'openai/streaming'
import type { Logger } from '@/lib/logs/console/logger'
import { trackForcedToolUsage } from '@/providers/utils'
import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils'

/**
 * Helper function to convert an Azure OpenAI stream to a standard ReadableStream
 * and collect completion metrics
 * Creates a ReadableStream from an Azure OpenAI streaming response.
 * Uses the shared OpenAI-compatible streaming utility.
 */
export function createReadableStreamFromAzureOpenAIStream(
  azureOpenAIStream: any,
  onComplete?: (content: string, usage?: any) => void
  azureOpenAIStream: Stream<ChatCompletionChunk>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream {
  let fullContent = ''
  let usageData: any = null

  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of azureOpenAIStream) {
          if (chunk.usage) {
            usageData = chunk.usage
          }

          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }

        if (onComplete) {
          onComplete(fullContent, usageData)
        }

        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
  return createOpenAICompatibleStream(azureOpenAIStream, 'Azure OpenAI', onComplete)
}

/**
 * Helper function to check for forced tool usage in responses
 * Checks if a forced tool was used in an Azure OpenAI response.
 * Uses the shared OpenAI-compatible forced tool usage helper.
 */
export function checkForForcedToolUsage(
  response: any,
  toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
  logger: Logger,
  _logger: Logger,
  forcedTools: string[],
  usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
  let hasUsedForcedTool = false
  let updatedUsedForcedTools = [...usedForcedTools]

  if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
    const toolCallsResponse = response.choices[0].message.tool_calls
    const result = trackForcedToolUsage(
      toolCallsResponse,
      toolChoice,
      logger,
      'azure-openai',
      forcedTools,
      updatedUsedForcedTools
    )
    hasUsedForcedTool = result.hasUsedForcedTool
    updatedUsedForcedTools = result.usedForcedTools
  }

  return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
  return checkForForcedToolUsageOpenAI(
    response,
    toolChoice,
    'Azure OpenAI',
    forcedTools,
    usedForcedTools,
    _logger
  )
}

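// For reference: a minimal sketch of the shared createOpenAICompatibleStream
// utility from '@/providers/utils' that the providers above now delegate to.
// Its implementation is not shown in this diff, so the exact signature and the
// providerName handling are assumptions inferred from the call sites and from
// the inlined stream handling it replaces.
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'

export function createOpenAICompatibleStream(
  stream: AsyncIterable<ChatCompletionChunk>,
  providerName: string, // presumably used for logging in the real utility
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  let fullContent = ''
  let usage: CompletionUsage = { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }

  return new ReadableStream<Uint8Array>({
    async start(controller) {
      try {
        for await (const chunk of stream) {
          // OpenAI-compatible providers attach usage to the final chunk when
          // stream_options.include_usage is set
          if (chunk.usage) usage = chunk.usage
          const delta = chunk.choices[0]?.delta?.content || ''
          if (delta) {
            fullContent += delta
            controller.enqueue(new TextEncoder().encode(delta))
          }
        }
        onComplete?.(fullContent, usage)
        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
}
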
@@ -12,6 +12,7 @@ import type {
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
@@ -35,7 +36,6 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
throw new Error('API key is required for Cerebras')
|
||||
}
|
||||
|
||||
// Start execution timer for the entire provider execution
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
@@ -44,31 +44,23 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
apiKey: request.apiKey,
|
||||
})
|
||||
|
||||
// Start with an empty array for all messages
|
||||
const allMessages = []
|
||||
|
||||
// Add system prompt if present
|
||||
if (request.systemPrompt) {
|
||||
allMessages.push({
|
||||
role: 'system',
|
||||
content: request.systemPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
// Add context if present
|
||||
if (request.context) {
|
||||
allMessages.push({
|
||||
role: 'user',
|
||||
content: request.context,
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
}
|
||||
|
||||
// Transform tools to Cerebras format if provided
|
||||
const tools = request.tools?.length
|
||||
? request.tools.map((tool) => ({
|
||||
type: 'function',
|
||||
@@ -80,26 +72,23 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
}))
|
||||
: undefined
|
||||
|
||||
// Build the request payload
|
||||
const payload: any = {
|
||||
model: (request.model || 'cerebras/llama-3.3-70b').replace('cerebras/', ''),
|
||||
messages: allMessages,
|
||||
}
|
||||
|
||||
// Add optional parameters
|
||||
if (request.temperature !== undefined) payload.temperature = request.temperature
|
||||
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
|
||||
|
||||
    // Add response format for structured output if specified
    if (request.responseFormat) {
      payload.response_format = {
        type: 'json_schema',
        schema: request.responseFormat.schema || request.responseFormat,
        json_schema: {
          name: request.responseFormat.name || 'response_schema',
          schema: request.responseFormat.schema || request.responseFormat,
          strict: request.responseFormat.strict !== false,
        },
      }
    }

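    // Illustration (hypothetical values, not part of this change): given a
    // structured-output request like the one below, the mapping above now
    // produces an OpenAI-style json_schema response_format for Cerebras.
    // const exampleResponseFormat = {
    //   name: 'weather_report',
    //   schema: {
    //     type: 'object',
    //     properties: { city: { type: 'string' }, tempC: { type: 'number' } },
    //     required: ['city', 'tempC'],
    //   },
    //   strict: true,
    // }
    // payload.response_format would then be:
    // {
    //   type: 'json_schema',
    //   json_schema: { name: 'weather_report', schema: exampleResponseFormat.schema, strict: true },
    // }
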
// Handle tools and tool usage control
|
||||
// Cerebras supports full OpenAI-compatible tool_choice including forcing specific tools
|
||||
let originalToolChoice: any
|
||||
let forcedTools: string[] = []
|
||||
let hasFilteredTools = false
|
||||
@@ -124,30 +113,40 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
// EARLY STREAMING: if streaming requested and no tools to execute, stream directly
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info('Using streaming response for Cerebras request (no tools)')
|
||||
|
||||
const streamResponse: any = await client.chat.completions.create({
|
||||
...payload,
|
||||
stream: true,
|
||||
})
|
||||
|
||||
// Start collecting token usage
|
||||
const tokenUsage = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
|
||||
// Create a StreamingExecution response with a readable stream
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromCerebrasStream(streamResponse),
|
||||
stream: createReadableStreamFromCerebrasStream(streamResponse, (content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.prompt_tokens,
|
||||
completion: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by streaming content in chat component
|
||||
content: '',
|
||||
model: request.model || 'cerebras/llama-3.3-70b',
|
||||
tokens: tokenUsage,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
@@ -163,14 +162,9 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
// Estimate token cost
|
||||
cost: {
|
||||
total: 0.0,
|
||||
input: 0.0,
|
||||
output: 0.0,
|
||||
},
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [], // No block logs for direct streaming
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -180,11 +174,8 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
|
||||
// Return the streaming execution object
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
// Make the initial API request
|
||||
const initialCallTime = Date.now()
|
||||
|
||||
let currentResponse = (await client.chat.completions.create(payload)) as CerebrasResponse
|
||||
@@ -201,11 +192,8 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
|
||||
// Track each model and tool call segment with timestamps
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
@@ -216,17 +204,12 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
},
|
||||
]
|
||||
|
||||
// Keep track of processed tool calls to avoid duplicates
|
||||
const processedToolCallIds = new Set()
|
||||
// Keep track of tool call signatures to detect repeats
|
||||
const toolCallSignatures = new Set()
|
||||
|
||||
try {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
|
||||
// Break if no tool calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
@@ -234,111 +217,124 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
break
|
||||
}
|
||||
|
||||
// Track time for tool calls in this batch
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
// Process each tool call
|
||||
let processedAnyToolCall = false
|
||||
let hasRepeatedToolCalls = false
|
||||
|
||||
for (const toolCall of toolCallsInResponse) {
|
||||
// Skip if we've already processed this tool call
|
||||
const filteredToolCalls = toolCallsInResponse.filter((toolCall) => {
|
||||
if (processedToolCallIds.has(toolCall.id)) {
|
||||
continue
|
||||
return false
|
||||
}
|
||||
|
||||
// Create a signature for this tool call to detect repeats
|
||||
const toolCallSignature = `${toolCall.function.name}-${toolCall.function.arguments}`
|
||||
if (toolCallSignatures.has(toolCallSignature)) {
|
||||
hasRepeatedToolCalls = true
|
||||
continue
|
||||
return false
|
||||
}
|
||||
processedToolCallIds.add(toolCall.id)
|
||||
toolCallSignatures.add(toolCallSignature)
|
||||
return true
|
||||
})
|
||||
|
||||
const processedAnyToolCall = filteredToolCalls.length > 0
|
||||
const toolExecutionPromises = filteredToolCalls.map(async (toolCall) => {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
try {
|
||||
processedToolCallIds.add(toolCall.id)
|
||||
toolCallSignatures.add(toolCallSignature)
|
||||
processedAnyToolCall = true
|
||||
|
||||
const toolName = toolCall.function.name
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
// Get the tool from the tools registry
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) continue
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
if (!tool) return null
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
// Add to time segments for both success and failure
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams,
|
||||
result,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
// Prepare result content for the LLM
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
// Include error information so LLM can respond appropriately
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
// Add the tool call and result to messages (both success and failure)
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: [
|
||||
{
|
||||
id: toolCall.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: toolName,
|
||||
arguments: toolCall.function.arguments,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error processing tool call:', { error })
|
||||
const toolCallEndTime = Date.now()
|
||||
logger.error('Error processing tool call (Cerebras):', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
toolName,
|
||||
})
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams: {},
|
||||
result: {
|
||||
success: false,
|
||||
output: undefined,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
},
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const executionResults = await Promise.allSettled(toolExecutionPromises)
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: filteredToolCalls.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
},
|
||||
})),
|
||||
})
|
||||
|
||||
for (const settledResult of executionResults) {
|
||||
if (settledResult.status === 'rejected' || !settledResult.value) continue
|
||||
|
||||
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
|
||||
settledResult.value
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: startTime,
|
||||
endTime: endTime,
|
||||
duration: duration,
|
||||
})
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(startTime).toISOString(),
|
||||
endTime: new Date(endTime).toISOString(),
|
||||
duration: duration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
}
|
||||
|
||||
// Calculate tool call time for this iteration
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
// Check if we used any forced tools and update tool_choice for the next iteration
|
||||
let usedForcedTools: string[] = []
|
||||
if (typeof originalToolChoice === 'object' && forcedTools.length > 0) {
|
||||
const toolTracking = trackForcedToolUsage(
|
||||
@@ -351,28 +347,20 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
)
|
||||
usedForcedTools = toolTracking.usedForcedTools
|
||||
const nextToolChoice = toolTracking.nextToolChoice
|
||||
|
||||
// Update tool_choice for next iteration if we're still forcing tools
|
||||
if (nextToolChoice && typeof nextToolChoice === 'object') {
|
||||
payload.tool_choice = nextToolChoice
|
||||
} else if (nextToolChoice === 'auto' || !nextToolChoice) {
|
||||
// All forced tools have been used, switch to auto
|
||||
payload.tool_choice = 'auto'
|
||||
}
|
||||
}
|
||||
|
||||
// After processing tool calls, get a final response
|
||||
if (processedAnyToolCall || hasRepeatedToolCalls) {
|
||||
// Time the next model call
|
||||
const nextModelStartTime = Date.now()
|
||||
|
||||
// Make the final request
|
||||
const finalPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
}
|
||||
|
||||
// Use tool_choice: 'none' for the final response to avoid an infinite loop
|
||||
finalPayload.tool_choice = 'none'
|
||||
|
||||
const finalResponse = (await client.chat.completions.create(
|
||||
@@ -382,7 +370,6 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: 'Final response',
|
||||
@@ -391,14 +378,11 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
// Add to model time
|
||||
modelTime += thisModelTime
|
||||
|
||||
if (finalResponse.choices[0]?.message?.content) {
|
||||
content = finalResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
// Update final token counts
|
||||
if (finalResponse.usage) {
|
||||
tokens.prompt += finalResponse.usage.prompt_tokens || 0
|
||||
tokens.completion += finalResponse.usage.completion_tokens || 0
|
||||
@@ -408,18 +392,13 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
break
|
||||
}
|
||||
|
||||
// Only continue if we haven't processed any tool calls and haven't seen repeats
|
||||
if (!processedAnyToolCall && !hasRepeatedToolCalls) {
|
||||
// Make the next request with updated messages
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
}
|
||||
|
||||
// Time the next model call
|
||||
const nextModelStartTime = Date.now()
|
||||
|
||||
// Make the next request
|
||||
currentResponse = (await client.chat.completions.create(
|
||||
nextPayload
|
||||
)) as CerebrasResponse
|
||||
@@ -427,7 +406,6 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
@@ -436,10 +414,7 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
// Add to model time
|
||||
modelTime += thisModelTime
|
||||
|
||||
// Update token counts
|
||||
if (currentResponse.usage) {
|
||||
tokens.prompt += currentResponse.usage.prompt_tokens || 0
|
||||
tokens.completion += currentResponse.usage.completion_tokens || 0
|
||||
@@ -453,33 +428,48 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
logger.error('Error in Cerebras tool processing:', { error })
|
||||
}
|
||||
|
||||
// Calculate overall timing
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
|
||||
// POST-TOOL-STREAMING: stream after tool calls if requested
|
||||
if (request.stream) {
|
||||
logger.info('Using streaming for final Cerebras response after tool processing')
|
||||
|
||||
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
|
||||
// This prevents the API from trying to force tool usage again in the final streaming response
|
||||
const streamingPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto', // Always use 'auto' for the streaming response after tool calls
|
||||
tool_choice: 'auto',
|
||||
stream: true,
|
||||
}
|
||||
|
||||
const streamResponse: any = await client.chat.completions.create(streamingPayload)
|
||||
|
||||
// Create a StreamingExecution response with all collected data
|
||||
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromCerebrasStream(streamResponse),
|
||||
stream: createReadableStreamFromCerebrasStream(streamResponse, (content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: tokens.prompt + usage.prompt_tokens,
|
||||
completion: tokens.completion + usage.completion_tokens,
|
||||
total: tokens.total + usage.total_tokens,
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by the callback
|
||||
content: '',
|
||||
model: request.model || 'cerebras/llama-3.3-70b',
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
@@ -504,12 +494,12 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
cost: {
|
||||
total: (tokens.total || 0) * 0.0001,
|
||||
input: (tokens.prompt || 0) * 0.0001,
|
||||
output: (tokens.completion || 0) * 0.0001,
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [], // No block logs at provider level
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -519,7 +509,6 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
|
||||
// Return the streaming execution object
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
@@ -541,7 +530,6 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
// Include timing information even for errors
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -551,9 +539,8 @@ export const cerebrasProvider: ProviderConfig = {
|
||||
duration: totalDuration,
|
||||
})
|
||||
|
||||
// Create a new error with timing information
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
// @ts-ignore - Adding timing property to error for debugging
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
|
||||
@@ -1,23 +1,26 @@
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'

interface CerebrasChunk {
  choices?: Array<{
    delta?: {
      content?: string
    }
  }>
  usage?: {
    prompt_tokens?: number
    completion_tokens?: number
    total_tokens?: number
  }
}

/**
 * Helper to convert a Cerebras streaming response (async iterable) into a ReadableStream.
 * Enqueues only the model's text delta chunks as UTF-8 encoded bytes.
 * Creates a ReadableStream from a Cerebras streaming response.
 * Uses the shared OpenAI-compatible streaming utility.
 */
export function createReadableStreamFromCerebrasStream(
  cerebrasStream: AsyncIterable<any>
): ReadableStream {
  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of cerebrasStream) {
          const content = chunk.choices?.[0]?.delta?.content || ''
          if (content) {
            controller.enqueue(new TextEncoder().encode(content))
          }
        }
        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
  cerebrasStream: AsyncIterable<CerebrasChunk>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  return createOpenAICompatibleStream(cerebrasStream as any, 'Cerebras', onComplete)
}

@@ -11,6 +11,7 @@ import type {
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
@@ -34,21 +35,17 @@ export const deepseekProvider: ProviderConfig = {
|
||||
throw new Error('API key is required for Deepseek')
|
||||
}
|
||||
|
||||
// Start execution timer for the entire provider execution
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
// Deepseek uses the OpenAI SDK with a custom baseURL
|
||||
const deepseek = new OpenAI({
|
||||
apiKey: request.apiKey,
|
||||
baseURL: 'https://api.deepseek.com/v1',
|
||||
})
|
||||
|
||||
// Start with an empty array for all messages
|
||||
const allMessages = []
|
||||
|
||||
// Add system prompt if present
|
||||
if (request.systemPrompt) {
|
||||
allMessages.push({
|
||||
role: 'system',
|
||||
@@ -56,7 +53,6 @@ export const deepseekProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add context if present
|
||||
if (request.context) {
|
||||
allMessages.push({
|
||||
role: 'user',
|
||||
@@ -64,12 +60,10 @@ export const deepseekProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
}
|
||||
|
||||
// Transform tools to OpenAI format if provided
|
||||
const tools = request.tools?.length
|
||||
? request.tools.map((tool) => ({
|
||||
type: 'function',
|
||||
@@ -82,15 +76,13 @@ export const deepseekProvider: ProviderConfig = {
|
||||
: undefined
|
||||
|
||||
const payload: any = {
|
||||
model: 'deepseek-chat', // Hardcode to deepseek-chat regardless of what's selected in the UI
|
||||
model: 'deepseek-chat',
|
||||
messages: allMessages,
|
||||
}
|
||||
|
||||
// Add optional parameters
|
||||
if (request.temperature !== undefined) payload.temperature = request.temperature
|
||||
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
|
||||
|
||||
// Handle tools and tool usage control
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
|
||||
if (tools?.length) {
|
||||
@@ -118,7 +110,6 @@ export const deepseekProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
// EARLY STREAMING: if streaming requested and no tools to execute, stream directly
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info('Using streaming response for DeepSeek request (no tools)')
|
||||
|
||||
@@ -127,22 +118,35 @@ export const deepseekProvider: ProviderConfig = {
|
||||
stream: true,
|
||||
})
|
||||
|
||||
// Start collecting token usage
|
||||
const tokenUsage = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
|
||||
// Create a StreamingExecution response with a readable stream
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromDeepseekStream(streamResponse),
|
||||
stream: createReadableStreamFromDeepseekStream(
|
||||
streamResponse as any,
|
||||
(content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.prompt_tokens,
|
||||
completion: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
}
|
||||
),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by streaming content in chat component
|
||||
content: '',
|
||||
model: request.model || 'deepseek-chat',
|
||||
tokens: tokenUsage,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
@@ -158,14 +162,9 @@ export const deepseekProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
// Estimate token cost
|
||||
cost: {
|
||||
total: 0.0,
|
||||
input: 0.0,
|
||||
output: 0.0,
|
||||
},
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [], // No block logs for direct streaming
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -175,17 +174,11 @@ export const deepseekProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
|
||||
// Return the streaming execution object
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
// Make the initial API request
|
||||
const initialCallTime = Date.now()
|
||||
|
||||
// Track the original tool_choice for forced tool tracking
|
||||
const originalToolChoice = payload.tool_choice
|
||||
|
||||
// Track forced tools and their usage
|
||||
const forcedTools = preparedTools?.forcedTools || []
|
||||
let usedForcedTools: string[] = []
|
||||
|
||||
@@ -194,11 +187,8 @@ export const deepseekProvider: ProviderConfig = {
|
||||
|
||||
let content = currentResponse.choices[0]?.message?.content || ''
|
||||
|
||||
// Clean up the response content if it exists
|
||||
if (content) {
|
||||
// Remove any markdown code block markers
|
||||
content = content.replace(/```json\n?|\n?```/g, '')
|
||||
// Trim any whitespace
|
||||
content = content.trim()
|
||||
}
|
||||
|
||||
@@ -211,15 +201,10 @@ export const deepseekProvider: ProviderConfig = {
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
// Track if a forced tool has been used
|
||||
let hasUsedForcedTool = false
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
|
||||
// Track each model and tool call segment with timestamps
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
@@ -230,7 +215,6 @@ export const deepseekProvider: ProviderConfig = {
|
||||
},
|
||||
]
|
||||
|
||||
// Check if a forced tool was used in the first response
|
||||
if (
|
||||
typeof originalToolChoice === 'object' &&
|
||||
currentResponse.choices[0]?.message?.tool_calls
|
||||
@@ -250,133 +234,148 @@ export const deepseekProvider: ProviderConfig = {
|
||||
|
||||
try {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
}
|
||||
|
||||
// Track time for tool calls in this batch
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
// Process each tool call
|
||||
for (const toolCall of toolCallsInResponse) {
|
||||
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
try {
|
||||
const toolName = toolCall.function.name
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
// Get the tool from the tools registry
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) continue
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
if (!tool) return null
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
// Add to time segments for both success and failure
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams,
|
||||
result,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
// Prepare result content for the LLM
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
// Include error information so LLM can respond appropriately
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
// Add the tool call and result to messages (both success and failure)
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: [
|
||||
{
|
||||
id: toolCall.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: toolName,
|
||||
arguments: toolCall.function.arguments,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error processing tool call:', { error })
|
||||
const toolCallEndTime = Date.now()
|
||||
logger.error('Error processing tool call:', { error, toolName })
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams: {},
|
||||
result: {
|
||||
success: false,
|
||||
output: undefined,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
},
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const executionResults = await Promise.allSettled(toolExecutionPromises)
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: toolCallsInResponse.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
},
|
||||
})),
|
||||
})
|
||||
|
||||
for (const settledResult of executionResults) {
|
||||
if (settledResult.status === 'rejected' || !settledResult.value) continue
|
||||
|
||||
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
|
||||
settledResult.value
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: startTime,
|
||||
endTime: endTime,
|
||||
duration: duration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(startTime).toISOString(),
|
||||
endTime: new Date(endTime).toISOString(),
|
||||
duration: duration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
}
|
||||
|
||||
// Calculate tool call time for this iteration
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
// Make the next request with updated messages
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
}
|
||||
|
||||
// Update tool_choice based on which forced tools have been used
|
||||
if (
|
||||
typeof originalToolChoice === 'object' &&
|
||||
hasUsedForcedTool &&
|
||||
forcedTools.length > 0
|
||||
) {
|
||||
// If we have remaining forced tools, get the next one to force
|
||||
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))
|
||||
|
||||
if (remainingTools.length > 0) {
|
||||
// Force the next tool
|
||||
nextPayload.tool_choice = {
|
||||
type: 'function',
|
||||
function: { name: remainingTools[0] },
|
||||
}
|
||||
logger.info(`Forcing next tool: ${remainingTools[0]}`)
|
||||
} else {
|
||||
// All forced tools have been used, switch to auto
|
||||
nextPayload.tool_choice = 'auto'
|
||||
logger.info('All forced tools have been used, switching to auto tool_choice')
|
||||
}
|
||||
}
|
||||
|
||||
// Time the next model call
|
||||
const nextModelStartTime = Date.now()
|
||||
|
||||
// Make the next request
|
||||
currentResponse = await deepseek.chat.completions.create(nextPayload)
|
||||
|
||||
// Check if any forced tools were used in this response
|
||||
if (
|
||||
typeof nextPayload.tool_choice === 'object' &&
|
||||
currentResponse.choices[0]?.message?.tool_calls
|
||||
@@ -397,7 +396,6 @@ export const deepseekProvider: ProviderConfig = {
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
@@ -406,18 +404,14 @@ export const deepseekProvider: ProviderConfig = {
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
// Add to model time
|
||||
modelTime += thisModelTime
|
||||
|
||||
// Update content if we have a text response
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
// Clean up the response content
|
||||
content = content.replace(/```json\n?|\n?```/g, '')
|
||||
content = content.trim()
|
||||
}
|
||||
|
||||
// Update token counts
|
||||
if (currentResponse.usage) {
|
||||
tokens.prompt += currentResponse.usage.prompt_tokens || 0
|
||||
tokens.completion += currentResponse.usage.completion_tokens || 0
|
||||
@@ -430,33 +424,51 @@ export const deepseekProvider: ProviderConfig = {
|
||||
logger.error('Error in Deepseek request:', { error })
|
||||
}
|
||||
|
||||
// Calculate overall timing
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
|
||||
// POST-TOOL STREAMING: stream final response after tool calls if requested
|
||||
if (request.stream) {
|
||||
logger.info('Using streaming for final DeepSeek response after tool processing')
|
||||
|
||||
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
|
||||
// This prevents the API from trying to force tool usage again in the final streaming response
|
||||
const streamingPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto', // Always use 'auto' for the streaming response after tool calls
|
||||
tool_choice: 'auto',
|
||||
stream: true,
|
||||
}
|
||||
|
||||
const streamResponse = await deepseek.chat.completions.create(streamingPayload)
|
||||
|
||||
// Create a StreamingExecution response with all collected data
|
||||
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromDeepseekStream(streamResponse),
|
||||
stream: createReadableStreamFromDeepseekStream(
|
||||
streamResponse as any,
|
||||
(content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: tokens.prompt + usage.prompt_tokens,
|
||||
completion: tokens.completion + usage.completion_tokens,
|
||||
total: tokens.total + usage.total_tokens,
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}
|
||||
),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by the callback
|
||||
content: '',
|
||||
model: request.model || 'deepseek-chat',
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
@@ -481,12 +493,12 @@ export const deepseekProvider: ProviderConfig = {
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
cost: {
|
||||
total: (tokens.total || 0) * 0.0001,
|
||||
input: (tokens.prompt || 0) * 0.0001,
|
||||
output: (tokens.completion || 0) * 0.0001,
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [], // No block logs at provider level
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -496,7 +508,6 @@ export const deepseekProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
|
||||
// Return the streaming execution object
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
@@ -518,7 +529,6 @@ export const deepseekProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
// Include timing information even for errors
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -528,9 +538,8 @@ export const deepseekProvider: ProviderConfig = {
|
||||
duration: totalDuration,
|
||||
})
|
||||
|
||||
// Create a new error with timing information
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
// @ts-ignore
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
|
||||
@@ -1,21 +1,14 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'

/**
 * Helper function to convert a DeepSeek (OpenAI-compatible) stream to a ReadableStream
 * of text chunks that can be consumed by the browser.
 * Creates a ReadableStream from a DeepSeek streaming response.
 * Uses the shared OpenAI-compatible streaming utility.
 */
export function createReadableStreamFromDeepseekStream(deepseekStream: any): ReadableStream {
  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of deepseekStream) {
          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            controller.enqueue(new TextEncoder().encode(content))
          }
        }
        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
export function createReadableStreamFromDeepseekStream(
  deepseekStream: AsyncIterable<ChatCompletionChunk>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  return createOpenAICompatibleStream(deepseekStream, 'Deepseek', onComplete)
}

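// Usage sketch (caller side, not part of this diff): the returned stream
// carries UTF-8 text chunks, while final token usage arrives via onComplete.
// The consuming function below is hypothetical; types come from the imports above.
async function consumeDeepseekStream(deepseekStream: AsyncIterable<ChatCompletionChunk>) {
  let finalUsage: CompletionUsage | undefined
  const stream = createReadableStreamFromDeepseekStream(deepseekStream, (_content, usage) => {
    finalUsage = usage
  })

  const reader = stream.getReader()
  const decoder = new TextDecoder()
  let text = ''
  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    if (value) text += decoder.decode(value, { stream: true })
  }
  return { text, usage: finalUsage }
}
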
@@ -15,6 +15,7 @@ import type {
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
@@ -23,15 +24,15 @@ import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('GoogleProvider')
|
||||
|
||||
/**
 * Creates a ReadableStream from Google's Gemini stream response
 */
interface GeminiStreamUsage {
  promptTokenCount: number
  candidatesTokenCount: number
  totalTokenCount: number
}

function createReadableStreamFromGeminiStream(
  response: Response,
  onComplete?: (
    content: string,
    usage?: { promptTokenCount?: number; candidatesTokenCount?: number; totalTokenCount?: number }
  ) => void
  onComplete?: (content: string, usage: GeminiStreamUsage) => void
): ReadableStream<Uint8Array> {
  const reader = response.body?.getReader()
  if (!reader) {
@@ -43,11 +44,32 @@ function createReadableStreamFromGeminiStream(
|
||||
try {
|
||||
let buffer = ''
|
||||
let fullContent = ''
|
||||
let usageData: {
|
||||
promptTokenCount?: number
|
||||
candidatesTokenCount?: number
|
||||
totalTokenCount?: number
|
||||
} | null = null
|
||||
let promptTokenCount = 0
|
||||
let candidatesTokenCount = 0
|
||||
let totalTokenCount = 0
|
||||
|
||||
const updateUsage = (metadata: any) => {
|
||||
if (metadata) {
|
||||
promptTokenCount = metadata.promptTokenCount ?? promptTokenCount
|
||||
candidatesTokenCount = metadata.candidatesTokenCount ?? candidatesTokenCount
|
||||
totalTokenCount = metadata.totalTokenCount ?? totalTokenCount
|
||||
}
|
||||
}
|
||||
|
||||
const buildUsage = (): GeminiStreamUsage => ({
|
||||
promptTokenCount,
|
||||
candidatesTokenCount,
|
||||
totalTokenCount,
|
||||
})
|
||||
|
||||
const complete = () => {
|
||||
if (onComplete) {
|
||||
if (promptTokenCount === 0 && candidatesTokenCount === 0) {
|
||||
logger.warn('Gemini stream completed without usage metadata')
|
||||
}
|
||||
onComplete(fullContent, buildUsage())
|
||||
}
|
||||
}
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
@@ -55,20 +77,15 @@ function createReadableStreamFromGeminiStream(
|
||||
if (buffer.trim()) {
|
||||
try {
|
||||
const data = JSON.parse(buffer.trim())
|
||||
if (data.usageMetadata) {
|
||||
usageData = data.usageMetadata
|
||||
}
|
||||
updateUsage(data.usageMetadata)
|
||||
const candidate = data.candidates?.[0]
|
||||
if (candidate?.content?.parts) {
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
if (functionCall) {
|
||||
logger.debug(
|
||||
'Function call detected in final buffer, ending stream to execute tool',
|
||||
{
|
||||
functionName: functionCall.name,
|
||||
}
|
||||
)
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
logger.debug('Function call detected in final buffer', {
|
||||
functionName: functionCall.name,
|
||||
})
|
||||
complete()
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
@@ -84,20 +101,15 @@ function createReadableStreamFromGeminiStream(
|
||||
const dataArray = JSON.parse(buffer.trim())
|
||||
if (Array.isArray(dataArray)) {
|
||||
for (const item of dataArray) {
|
||||
if (item.usageMetadata) {
|
||||
usageData = item.usageMetadata
|
||||
}
|
||||
updateUsage(item.usageMetadata)
|
||||
const candidate = item.candidates?.[0]
|
||||
if (candidate?.content?.parts) {
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
if (functionCall) {
|
||||
logger.debug(
|
||||
'Function call detected in array item, ending stream to execute tool',
|
||||
{
|
||||
functionName: functionCall.name,
|
||||
}
|
||||
)
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
logger.debug('Function call detected in array item', {
|
||||
functionName: functionCall.name,
|
||||
})
|
||||
complete()
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
@@ -109,13 +121,11 @@ function createReadableStreamFromGeminiStream(
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (arrayError) {
|
||||
// Buffer is not valid JSON array
|
||||
}
|
||||
} catch (_) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
complete()
|
||||
controller.close()
|
||||
break
|
||||
}
|
||||
@@ -162,25 +172,17 @@ function createReadableStreamFromGeminiStream(
|
||||
|
||||
try {
|
||||
const data = JSON.parse(jsonStr)
|
||||
|
||||
if (data.usageMetadata) {
|
||||
usageData = data.usageMetadata
|
||||
}
|
||||
|
||||
updateUsage(data.usageMetadata)
|
||||
const candidate = data.candidates?.[0]
|
||||
|
||||
if (candidate?.finishReason === 'UNEXPECTED_TOOL_CALL') {
|
||||
logger.warn('Gemini returned UNEXPECTED_TOOL_CALL in streaming mode', {
|
||||
finishReason: candidate.finishReason,
|
||||
hasContent: !!candidate?.content,
|
||||
hasParts: !!candidate?.content?.parts,
|
||||
})
|
||||
logger.warn('Gemini returned UNEXPECTED_TOOL_CALL in streaming mode')
|
||||
const textContent = extractTextContent(candidate)
|
||||
if (textContent) {
|
||||
fullContent += textContent
|
||||
controller.enqueue(new TextEncoder().encode(textContent))
|
||||
}
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
complete()
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
@@ -188,13 +190,10 @@ function createReadableStreamFromGeminiStream(
|
||||
if (candidate?.content?.parts) {
|
||||
const functionCall = extractFunctionCall(candidate)
|
||||
if (functionCall) {
|
||||
logger.debug(
|
||||
'Function call detected in stream, ending stream to execute tool',
|
||||
{
|
||||
functionName: functionCall.name,
|
||||
}
|
||||
)
|
||||
if (onComplete) onComplete(fullContent, usageData || undefined)
|
||||
logger.debug('Function call detected in stream', {
|
||||
functionName: functionCall.name,
|
||||
})
|
||||
complete()
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
@@ -214,7 +213,6 @@ function createReadableStreamFromGeminiStream(
|
||||
buffer = buffer.substring(closeBrace + 1)
|
||||
searchIndex = 0
|
||||
} else {
|
||||
// No complete JSON object found, wait for more data
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -395,6 +393,7 @@ export const googleProvider: ProviderConfig = {
        },
      ],
    },
    cost: { input: 0, output: 0, total: 0 },
  },
  logs: [],
  metadata: {
@@ -410,6 +409,22 @@ export const googleProvider: ProviderConfig = {
  response,
  (content, usage) => {
    streamingResult.execution.output.content = content
    streamingResult.execution.output.tokens = {
      prompt: usage.promptTokenCount,
      completion: usage.candidatesTokenCount,
      total: usage.totalTokenCount || usage.promptTokenCount + usage.candidatesTokenCount,
    }

    const costResult = calculateCost(
      request.model,
      usage.promptTokenCount,
      usage.candidatesTokenCount
    )
    streamingResult.execution.output.cost = {
      input: costResult.input,
      output: costResult.output,
      total: costResult.total,
    }

    const streamEndTime = Date.now()
    const streamEndTimeISO = new Date(streamEndTime).toISOString()
@@ -426,16 +441,6 @@ export const googleProvider: ProviderConfig = {
          streamEndTime - providerStartTime
      }
    }

    if (usage) {
      streamingResult.execution.output.tokens = {
        prompt: usage.promptTokenCount || 0,
        completion: usage.candidatesTokenCount || 0,
        total:
          usage.totalTokenCount ||
          (usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0),
      }
    }
  }
)
@@ -463,6 +468,11 @@ export const googleProvider: ProviderConfig = {
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
const cost = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
total: 0,
|
||||
}
|
||||
const toolCalls = []
|
||||
const toolResults = []
|
||||
let iterationCount = 0
|
||||
@@ -705,6 +715,15 @@ export const googleProvider: ProviderConfig = {
|
||||
tokens.total +=
|
||||
(checkResult.usageMetadata.promptTokenCount || 0) +
|
||||
(checkResult.usageMetadata.candidatesTokenCount || 0)
|
||||
|
||||
const iterationCost = calculateCost(
|
||||
request.model,
|
||||
checkResult.usageMetadata.promptTokenCount || 0,
|
||||
checkResult.usageMetadata.candidatesTokenCount || 0
|
||||
)
|
||||
cost.input += iterationCost.input
|
||||
cost.output += iterationCost.output
|
||||
cost.total += iterationCost.total
|
||||
}
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
@@ -724,8 +743,6 @@ export const googleProvider: ProviderConfig = {
|
||||
}
|
||||
logger.info('No function call detected, proceeding with streaming response')
|
||||
|
||||
// Apply structured output for the final response if responseFormat is specified
|
||||
// This works regardless of whether tools were forced or auto
|
||||
if (request.responseFormat) {
|
||||
streamingPayload.tools = undefined
|
||||
streamingPayload.toolConfig = undefined
|
||||
@@ -806,6 +823,7 @@ export const googleProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments,
|
||||
},
|
||||
cost,
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
@@ -822,6 +840,28 @@ export const googleProvider: ProviderConfig = {
|
||||
(content, usage) => {
|
||||
streamingExecution.execution.output.content = content
|
||||
|
||||
const existingTokens = streamingExecution.execution.output.tokens
|
||||
streamingExecution.execution.output.tokens = {
|
||||
prompt: (existingTokens?.prompt ?? 0) + usage.promptTokenCount,
|
||||
completion: (existingTokens?.completion ?? 0) + usage.candidatesTokenCount,
|
||||
total:
|
||||
(existingTokens?.total ?? 0) +
|
||||
(usage.totalTokenCount ||
|
||||
usage.promptTokenCount + usage.candidatesTokenCount),
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.promptTokenCount,
|
||||
usage.candidatesTokenCount
|
||||
)
|
||||
const existingCost = streamingExecution.execution.output.cost as any
|
||||
streamingExecution.execution.output.cost = {
|
||||
input: (existingCost?.input ?? 0) + streamCost.input,
|
||||
output: (existingCost?.output ?? 0) + streamCost.output,
|
||||
total: (existingCost?.total ?? 0) + streamCost.total,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
@@ -831,23 +871,6 @@ export const googleProvider: ProviderConfig = {
|
||||
streamingExecution.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
const existingTokens = streamingExecution.execution.output.tokens || {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
streamingExecution.execution.output.tokens = {
|
||||
prompt: (existingTokens.prompt || 0) + (usage.promptTokenCount || 0),
|
||||
completion:
|
||||
(existingTokens.completion || 0) + (usage.candidatesTokenCount || 0),
|
||||
total:
|
||||
(existingTokens.total || 0) +
|
||||
(usage.totalTokenCount ||
|
||||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -926,7 +949,6 @@ export const googleProvider: ProviderConfig = {
|
||||
const nextFunctionCall = extractFunctionCall(nextCandidate)
|
||||
|
||||
if (!nextFunctionCall) {
|
||||
// If responseFormat is specified, make one final request with structured output
|
||||
if (request.responseFormat) {
|
||||
const finalPayload = {
|
||||
...payload,
|
||||
@@ -969,6 +991,15 @@ export const googleProvider: ProviderConfig = {
|
||||
tokens.total +=
|
||||
(finalResult.usageMetadata.promptTokenCount || 0) +
|
||||
(finalResult.usageMetadata.candidatesTokenCount || 0)
|
||||
|
||||
const iterationCost = calculateCost(
|
||||
request.model,
|
||||
finalResult.usageMetadata.promptTokenCount || 0,
|
||||
finalResult.usageMetadata.candidatesTokenCount || 0
|
||||
)
|
||||
cost.input += iterationCost.input
|
||||
cost.output += iterationCost.output
|
||||
cost.total += iterationCost.total
|
||||
}
|
||||
} else {
|
||||
logger.warn(
|
||||
@@ -1054,7 +1085,7 @@ export const googleProvider: ProviderConfig = {
|
||||
})
|
||||
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
// @ts-ignore
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
|
||||
@@ -1,3 +1,4 @@
import type { Candidate } from '@google/genai'
import type { ProviderRequest } from '@/providers/types'

/**
@@ -23,7 +24,7 @@ export function cleanSchemaForGemini(schema: any): any {
/**
 * Extracts text content from a Gemini response candidate, handling structured output
 */
export function extractTextContent(candidate: any): string {
export function extractTextContent(candidate: Candidate | undefined): string {
  if (!candidate?.content?.parts) return ''

  if (candidate.content.parts?.length === 1 && candidate.content.parts[0].text) {
@@ -32,9 +33,7 @@ export function extractTextContent(candidate: any): string {
      try {
        JSON.parse(text)
        return text
      } catch (_e) {
        /* Not valid JSON, continue with normal extraction */
      }
      } catch (_e) {}
    }
  }

@@ -47,34 +46,20 @@ export function extractTextContent(candidate: any): string {
/**
 * Extracts a function call from a Gemini response candidate
 */
export function extractFunctionCall(candidate: any): { name: string; args: any } | null {
export function extractFunctionCall(
  candidate: Candidate | undefined
): { name: string; args: any } | null {
  if (!candidate?.content?.parts) return null

  for (const part of candidate.content.parts) {
    if (part.functionCall) {
      const args = part.functionCall.args || {}
      if (
        typeof part.functionCall.args === 'string' &&
        part.functionCall.args.trim().startsWith('{')
      ) {
        try {
          return { name: part.functionCall.name, args: JSON.parse(part.functionCall.args) }
        } catch (_e) {
          return { name: part.functionCall.name, args: part.functionCall.args }
        }
      return {
        name: part.functionCall.name ?? '',
        args: part.functionCall.args ?? {},
      }
      return { name: part.functionCall.name, args }
    }
  }

  if (candidate.content.function_call) {
    const args =
      typeof candidate.content.function_call.arguments === 'string'
        ? JSON.parse(candidate.content.function_call.arguments || '{}')
        : candidate.content.function_call.arguments || {}
    return { name: candidate.content.function_call.name, args }
  }

  return null
}

@@ -11,6 +11,7 @@ import type {
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
@@ -34,13 +35,10 @@ export const groqProvider: ProviderConfig = {
|
||||
throw new Error('API key is required for Groq')
|
||||
}
|
||||
|
||||
// Create Groq client
|
||||
const groq = new Groq({ apiKey: request.apiKey })
|
||||
|
||||
// Start with an empty array for all messages
|
||||
const allMessages = []
|
||||
|
||||
// Add system prompt if present
|
||||
if (request.systemPrompt) {
|
||||
allMessages.push({
|
||||
role: 'system',
|
||||
@@ -48,7 +46,6 @@ export const groqProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add context if present
|
||||
if (request.context) {
|
||||
allMessages.push({
|
||||
role: 'user',
|
||||
@@ -56,12 +53,10 @@ export const groqProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
}
|
||||
|
||||
// Transform tools to function format if provided
|
||||
const tools = request.tools?.length
|
||||
? request.tools.map((tool) => ({
|
||||
type: 'function',
|
||||
@@ -73,7 +68,6 @@ export const groqProvider: ProviderConfig = {
|
||||
}))
|
||||
: undefined
|
||||
|
||||
// Build the request payload
|
||||
const payload: any = {
|
||||
model: (request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct').replace(
|
||||
'groq/',
|
||||
@@ -82,20 +76,20 @@ export const groqProvider: ProviderConfig = {
|
||||
messages: allMessages,
|
||||
}
|
||||
|
||||
// Add optional parameters
|
||||
if (request.temperature !== undefined) payload.temperature = request.temperature
|
||||
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
|
||||
|
||||
// Add response format for structured output if specified
|
||||
if (request.responseFormat) {
|
||||
payload.response_format = {
|
||||
type: 'json_schema',
|
||||
schema: request.responseFormat.schema || request.responseFormat,
|
||||
json_schema: {
|
||||
name: request.responseFormat.name || 'response_schema',
|
||||
schema: request.responseFormat.schema || request.responseFormat,
|
||||
strict: request.responseFormat.strict !== false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tools and tool usage control
|
||||
// Groq supports full OpenAI-compatible tool_choice including forcing specific tools
|
||||
let originalToolChoice: any
|
||||
let forcedTools: string[] = []
|
||||
let hasFilteredTools = false
|
||||
@@ -120,12 +114,9 @@ export const groqProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
// EARLY STREAMING: if caller requested streaming and there are no tools to execute,
|
||||
// we can directly stream the completion.
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info('Using streaming response for Groq request (no tools)')
|
||||
|
||||
// Start execution timer for the entire provider execution
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
@@ -134,22 +125,32 @@ export const groqProvider: ProviderConfig = {
|
||||
stream: true,
|
||||
})
|
||||
|
||||
// Start collecting token usage
|
||||
const tokenUsage = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
|
||||
// Create a StreamingExecution response with a readable stream
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromGroqStream(streamResponse),
|
||||
stream: createReadableStreamFromGroqStream(streamResponse as any, (content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.prompt_tokens,
|
||||
completion: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by streaming content in chat component
|
||||
content: '',
|
||||
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
|
||||
tokens: tokenUsage,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
@@ -165,13 +166,9 @@ export const groqProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
cost: {
|
||||
total: 0.0,
|
||||
input: 0.0,
|
||||
output: 0.0,
|
||||
},
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [], // No block logs for direct streaming
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -181,16 +178,13 @@ export const groqProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
|
||||
// Return the streaming execution object
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
// Start execution timer for the entire provider execution
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
// Make the initial API request
|
||||
const initialCallTime = Date.now()
|
||||
|
||||
let currentResponse = await groq.chat.completions.create(payload)
|
||||
@@ -206,12 +200,9 @@ export const groqProvider: ProviderConfig = {
|
||||
const toolResults = []
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
|
||||
// Track each model and tool call segment with timestamps
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
@@ -224,98 +215,121 @@ export const groqProvider: ProviderConfig = {
|
||||
|
||||
try {
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
}
|
||||
|
||||
// Track time for tool calls in this batch
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
// Process each tool call
|
||||
for (const toolCall of toolCallsInResponse) {
|
||||
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
try {
|
||||
const toolName = toolCall.function.name
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
// Get the tool from the tools registry
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) continue
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
if (!tool) return null
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
// Add to time segments for both success and failure
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams,
|
||||
result,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
// Prepare result content for the LLM
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
// Include error information so LLM can respond appropriately
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
// Add the tool call and result to messages (both success and failure)
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: [
|
||||
{
|
||||
id: toolCall.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: toolName,
|
||||
arguments: toolCall.function.arguments,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error processing tool call:', { error })
|
||||
const toolCallEndTime = Date.now()
|
||||
logger.error('Error processing tool call:', { error, toolName })
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams: {},
|
||||
result: {
|
||||
success: false,
|
||||
output: undefined,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
},
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const executionResults = await Promise.allSettled(toolExecutionPromises)
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: toolCallsInResponse.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
},
|
||||
})),
|
||||
})
|
||||
|
||||
for (const settledResult of executionResults) {
|
||||
if (settledResult.status === 'rejected' || !settledResult.value) continue
|
||||
|
||||
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
|
||||
settledResult.value
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: startTime,
|
||||
endTime: endTime,
|
||||
duration: duration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(startTime).toISOString(),
|
||||
endTime: new Date(endTime).toISOString(),
|
||||
duration: duration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
}
|
||||
|
||||
// Calculate tool call time for this iteration
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
// Check if we used any forced tools and update tool_choice for the next iteration
|
||||
let usedForcedTools: string[] = []
|
||||
if (typeof originalToolChoice === 'object' && forcedTools.length > 0) {
|
||||
const toolTracking = trackForcedToolUsage(
|
||||
@@ -329,31 +343,24 @@ export const groqProvider: ProviderConfig = {
|
||||
usedForcedTools = toolTracking.usedForcedTools
|
||||
const nextToolChoice = toolTracking.nextToolChoice
|
||||
|
||||
// Update tool_choice for next iteration if we're still forcing tools
|
||||
if (nextToolChoice && typeof nextToolChoice === 'object') {
|
||||
payload.tool_choice = nextToolChoice
|
||||
} else if (nextToolChoice === 'auto' || !nextToolChoice) {
|
||||
// All forced tools have been used, switch to auto
|
||||
payload.tool_choice = 'auto'
|
||||
}
|
||||
}
|
||||
|
||||
// Make the next request with updated messages
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
}
|
||||
|
||||
// Time the next model call
|
||||
const nextModelStartTime = Date.now()
|
||||
|
||||
// Make the next request
|
||||
currentResponse = await groq.chat.completions.create(nextPayload)
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
@@ -362,15 +369,12 @@ export const groqProvider: ProviderConfig = {
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
// Add to model time
|
||||
modelTime += thisModelTime
|
||||
|
||||
// Update content if we have a text response
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
// Update token counts
|
||||
if (currentResponse.usage) {
|
||||
tokens.prompt += currentResponse.usage.prompt_tokens || 0
|
||||
tokens.completion += currentResponse.usage.completion_tokens || 0
|
||||
@@ -383,28 +387,44 @@ export const groqProvider: ProviderConfig = {
|
||||
logger.error('Error in Groq request:', { error })
|
||||
}
|
||||
|
||||
// After all tool processing complete, if streaming was requested and we have messages, use streaming for the final response
|
||||
if (request.stream) {
|
||||
logger.info('Using streaming for final Groq response after tool processing')
|
||||
|
||||
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
|
||||
// This prevents the API from trying to force tool usage again in the final streaming response
|
||||
const streamingPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto', // Always use 'auto' for the streaming response after tool calls
|
||||
tool_choice: 'auto',
|
||||
stream: true,
|
||||
}
|
||||
|
||||
const streamResponse = await groq.chat.completions.create(streamingPayload)
|
||||
|
||||
// Create a StreamingExecution response with all collected data
|
||||
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromGroqStream(streamResponse),
|
||||
stream: createReadableStreamFromGroqStream(streamResponse as any, (content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: tokens.prompt + usage.prompt_tokens,
|
||||
completion: tokens.completion + usage.completion_tokens,
|
||||
total: tokens.total + usage.total_tokens,
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by the callback
|
||||
content: '',
|
||||
model: request.model || 'groq/meta-llama/llama-4-scout-17b-16e-instruct',
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
@@ -429,12 +449,12 @@ export const groqProvider: ProviderConfig = {
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
cost: {
|
||||
total: (tokens.total || 0) * 0.0001,
|
||||
input: (tokens.prompt || 0) * 0.0001,
|
||||
output: (tokens.completion || 0) * 0.0001,
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [], // No block logs at provider level
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -444,11 +464,9 @@ export const groqProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
|
||||
// Return the streaming execution object
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
// Calculate overall timing
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -471,7 +489,6 @@ export const groqProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
// Include timing information even for errors
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -481,9 +498,8 @@ export const groqProvider: ProviderConfig = {
|
||||
duration: totalDuration,
|
||||
})
|
||||
|
||||
// Create a new error with timing information
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
// @ts-ignore
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
|
||||
@@ -1,23 +1,14 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'

/**
 * Helper to wrap Groq streaming into a browser-friendly ReadableStream
 * of raw assistant text chunks.
 *
 * @param groqStream - The Groq streaming response
 * @returns A ReadableStream that emits text chunks
 * Creates a ReadableStream from a Groq streaming response.
 * Uses the shared OpenAI-compatible streaming utility.
 */
export function createReadableStreamFromGroqStream(groqStream: any): ReadableStream {
  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of groqStream) {
          if (chunk.choices[0]?.delta?.content) {
            controller.enqueue(new TextEncoder().encode(chunk.choices[0].delta.content))
          }
        }
        controller.close()
      } catch (err) {
        controller.error(err)
      }
    },
  })
export function createReadableStreamFromGroqStream(
  groqStream: AsyncIterable<ChatCompletionChunk>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  return createOpenAICompatibleStream(groqStream, 'Groq', onComplete)
}

@@ -1,4 +1,5 @@
|
||||
import OpenAI from 'openai'
|
||||
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
@@ -11,6 +12,7 @@ import type {
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
@@ -100,8 +102,6 @@ export const mistralProvider: ProviderConfig = {
|
||||
strict: request.responseFormat.strict !== false,
|
||||
},
|
||||
}
|
||||
|
||||
logger.info('Added JSON schema response format to request')
|
||||
}
|
||||
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
@@ -138,24 +138,32 @@ export const mistralProvider: ProviderConfig = {
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info('Using streaming response for Mistral request')
|
||||
|
||||
const streamResponse = await mistral.chat.completions.create({
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
})
|
||||
|
||||
const tokenUsage = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
|
||||
let _streamContent = ''
|
||||
const streamResponse = await mistral.chat.completions.create(streamingParams)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromMistralStream(streamResponse, (content, usage) => {
|
||||
_streamContent = content
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.prompt_tokens,
|
||||
completion: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
@@ -172,23 +180,13 @@ export const mistralProvider: ProviderConfig = {
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokenUsage.prompt,
|
||||
completion: usage.completion_tokens || tokenUsage.completion,
|
||||
total: usage.total_tokens || tokenUsage.total,
|
||||
}
|
||||
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: tokenUsage,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
@@ -204,6 +202,7 @@ export const mistralProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
@@ -275,6 +274,10 @@ export const mistralProvider: ProviderConfig = {
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
@@ -285,78 +288,103 @@ export const mistralProvider: ProviderConfig = {
|
||||
)
|
||||
|
||||
const toolsStartTime = Date.now()
|
||||
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
for (const toolCall of toolCallsInResponse) {
|
||||
try {
|
||||
const toolName = toolCall.function.name
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) continue
|
||||
|
||||
const toolCallStartTime = Date.now()
|
||||
if (!tool) return null
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams,
|
||||
result,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: [
|
||||
{
|
||||
id: toolCall.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: toolName,
|
||||
arguments: toolCall.function.arguments,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error processing tool call:', {
|
||||
error,
|
||||
toolName: toolCall?.function?.name,
|
||||
})
|
||||
const toolCallEndTime = Date.now()
|
||||
logger.error('Error processing tool call:', { error, toolName })
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams: {},
|
||||
result: {
|
||||
success: false,
|
||||
output: undefined,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
},
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const executionResults = await Promise.allSettled(toolExecutionPromises)
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: toolCallsInResponse.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
},
|
||||
})),
|
||||
})
|
||||
|
||||
for (const settledResult of executionResults) {
|
||||
if (settledResult.status === 'rejected' || !settledResult.value) continue
|
||||
|
||||
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
|
||||
settledResult.value
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: startTime,
|
||||
endTime: endTime,
|
||||
duration: duration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(startTime).toISOString(),
|
||||
endTime: new Date(endTime).toISOString(),
|
||||
duration: duration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
}
|
||||
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
@@ -417,31 +445,35 @@ export const mistralProvider: ProviderConfig = {
|
||||
if (request.stream) {
|
||||
logger.info('Using streaming for final response after tool processing')
|
||||
|
||||
const streamingPayload = {
|
||||
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
|
||||
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto',
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
|
||||
const streamResponse = await mistral.chat.completions.create(streamingPayload)
|
||||
|
||||
let _streamContent = ''
|
||||
const streamResponse = await mistral.chat.completions.create(streamingParams)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromMistralStream(streamResponse, (content, usage) => {
|
||||
_streamContent = content
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: tokens.prompt + usage.prompt_tokens,
|
||||
completion: tokens.completion + usage.completion_tokens,
|
||||
total: tokens.total + usage.total_tokens,
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokens.prompt,
|
||||
completion: usage.completion_tokens || tokens.completion,
|
||||
total: usage.total_tokens || tokens.total,
|
||||
}
|
||||
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
@@ -471,6 +503,11 @@ export const mistralProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
@@ -516,7 +553,7 @@ export const mistralProvider: ProviderConfig = {
|
||||
})
|
||||
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
// @ts-ignore - Adding timing property to error for debugging
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
|
||||
@@ -1,39 +1,14 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'

/**
 * Creates a ReadableStream from a Mistral AI streaming response
 * @param mistralStream - The Mistral AI stream object
 * @param onComplete - Optional callback when streaming completes
 * @returns A ReadableStream that yields text chunks
 * Creates a ReadableStream from a Mistral streaming response.
 * Uses the shared OpenAI-compatible streaming utility.
 */
export function createReadableStreamFromMistralStream(
  mistralStream: any,
  onComplete?: (content: string, usage?: any) => void
): ReadableStream {
  let fullContent = ''
  let usageData: any = null

  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of mistralStream) {
          if (chunk.usage) {
            usageData = chunk.usage
          }

          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }

        if (onComplete) {
          onComplete(fullContent, usageData)
        }

        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
  mistralStream: AsyncIterable<ChatCompletionChunk>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  return createOpenAICompatibleStream(mistralStream, 'Mistral', onComplete)
}

@@ -23,13 +23,7 @@ import {
|
||||
VllmIcon,
|
||||
xAIIcon,
|
||||
} from '@/components/icons'
|
||||
|
||||
export interface ModelPricing {
|
||||
input: number // Per 1M tokens
|
||||
cachedInput?: number // Per 1M tokens (if supported)
|
||||
output: number // Per 1M tokens
|
||||
updatedAt: string
|
||||
}
|
||||
import type { ModelPricing } from '@/providers/types'
|
||||
|
||||
export interface ModelCapabilities {
|
||||
temperature?: {
|
||||
@@ -38,6 +32,7 @@ export interface ModelCapabilities {
|
||||
}
|
||||
toolUsageControl?: boolean
|
||||
computerUse?: boolean
|
||||
nativeStructuredOutputs?: boolean
|
||||
reasoningEffort?: {
|
||||
values: string[]
|
||||
}
|
||||
@@ -50,7 +45,7 @@ export interface ModelDefinition {
|
||||
id: string
|
||||
pricing: ModelPricing
|
||||
capabilities: ModelCapabilities
|
||||
contextWindow?: number // Maximum context window in tokens (may be undefined for dynamic providers)
|
||||
contextWindow?: number
|
||||
}
|
||||
|
||||
export interface ProviderDefinition {
|
||||
@@ -62,13 +57,9 @@ export interface ProviderDefinition {
|
||||
modelPatterns?: RegExp[]
|
||||
icon?: React.ComponentType<{ className?: string }>
|
||||
capabilities?: ModelCapabilities
|
||||
// Indicates whether reliable context window information is available for this provider's models
|
||||
contextInformationAvailable?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Comprehensive provider definitions, single source of truth
|
||||
*/
|
||||
export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
openrouter: {
|
||||
id: 'openrouter',
|
||||
@@ -616,6 +607,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 1 },
|
||||
nativeStructuredOutputs: true,
|
||||
},
|
||||
contextWindow: 200000,
|
||||
},
|
||||
@@ -629,6 +621,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 1 },
|
||||
nativeStructuredOutputs: true,
|
||||
},
|
||||
contextWindow: 200000,
|
||||
},
|
||||
@@ -655,6 +648,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 1 },
|
||||
nativeStructuredOutputs: true,
|
||||
},
|
||||
contextWindow: 200000,
|
||||
},
|
||||
@@ -668,6 +662,7 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
capabilities: {
|
||||
temperature: { min: 0, max: 1 },
|
||||
nativeStructuredOutputs: true,
|
||||
},
|
||||
contextWindow: 200000,
|
||||
},
|
||||
@@ -1619,23 +1614,14 @@ export const PROVIDER_DEFINITIONS: Record<string, ProviderDefinition> = {
|
||||
},
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models for a specific provider
|
||||
*/
|
||||
export function getProviderModels(providerId: string): string[] {
|
||||
return PROVIDER_DEFINITIONS[providerId]?.models.map((m) => m.id) || []
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the default model for a specific provider
|
||||
*/
|
||||
export function getProviderDefaultModel(providerId: string): string {
|
||||
return PROVIDER_DEFINITIONS[providerId]?.defaultModel || ''
|
||||
}
|
||||
|
||||
/**
|
||||
* Get pricing information for a specific model
|
||||
*/
|
||||
export function getModelPricing(modelId: string): ModelPricing | null {
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase())
|
||||
@@ -1646,20 +1632,15 @@ export function getModelPricing(modelId: string): ModelPricing | null {
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Get capabilities for a specific model
|
||||
*/
|
||||
export function getModelCapabilities(modelId: string): ModelCapabilities | null {
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
const model = provider.models.find((m) => m.id.toLowerCase() === modelId.toLowerCase())
|
||||
if (model) {
|
||||
// Merge provider capabilities with model capabilities, model takes precedence
|
||||
const capabilities: ModelCapabilities = { ...provider.capabilities, ...model.capabilities }
|
||||
return capabilities
|
||||
}
|
||||
}
|
||||
|
||||
// If no model found, check for provider-level capabilities for dynamically fetched models
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
if (provider.modelPatterns) {
|
||||
for (const pattern of provider.modelPatterns) {
|
||||
@@ -1673,9 +1654,6 @@ export function getModelCapabilities(modelId: string): ModelCapabilities | null
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models that support temperature
|
||||
*/
|
||||
export function getModelsWithTemperatureSupport(): string[] {
|
||||
const models: string[] = []
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
@@ -1688,9 +1666,6 @@ export function getModelsWithTemperatureSupport(): string[] {
|
||||
return models
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models with temperature range 0-1
|
||||
*/
|
||||
export function getModelsWithTempRange01(): string[] {
|
||||
const models: string[] = []
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
@@ -1703,9 +1678,6 @@ export function getModelsWithTempRange01(): string[] {
|
||||
return models
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models with temperature range 0-2
|
||||
*/
|
||||
export function getModelsWithTempRange02(): string[] {
|
||||
const models: string[] = []
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
@@ -1718,9 +1690,6 @@ export function getModelsWithTempRange02(): string[] {
|
||||
return models
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all providers that support tool usage control
|
||||
*/
|
||||
export function getProvidersWithToolUsageControl(): string[] {
|
||||
const providers: string[] = []
|
||||
for (const [providerId, provider] of Object.entries(PROVIDER_DEFINITIONS)) {
|
||||
@@ -1731,9 +1700,6 @@ export function getProvidersWithToolUsageControl(): string[] {
|
||||
return providers
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models that are hosted (don't require user API keys)
|
||||
*/
|
||||
export function getHostedModels(): string[] {
|
||||
return [
|
||||
...getProviderModels('openai'),
|
||||
@@ -1742,9 +1708,6 @@ export function getHostedModels(): string[] {
|
||||
]
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all computer use models
|
||||
*/
|
||||
export function getComputerUseModels(): string[] {
|
||||
const models: string[] = []
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
@@ -1757,32 +1720,20 @@ export function getComputerUseModels(): string[] {
|
||||
return models
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model supports temperature
|
||||
*/
|
||||
export function supportsTemperature(modelId: string): boolean {
|
||||
const capabilities = getModelCapabilities(modelId)
|
||||
return !!capabilities?.temperature
|
||||
}
|
||||
|
||||
/**
|
||||
* Get maximum temperature for a model
|
||||
*/
|
||||
export function getMaxTemperature(modelId: string): number | undefined {
|
||||
const capabilities = getModelCapabilities(modelId)
|
||||
return capabilities?.temperature?.max
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a provider supports tool usage control
|
||||
*/
|
||||
export function supportsToolUsageControl(providerId: string): boolean {
|
||||
return getProvidersWithToolUsageControl().includes(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Update Ollama models dynamically
|
||||
*/
|
||||
export function updateOllamaModels(models: string[]): void {
|
||||
PROVIDER_DEFINITIONS.ollama.models = models.map((modelId) => ({
|
||||
id: modelId,
|
||||
@@ -1795,9 +1746,6 @@ export function updateOllamaModels(models: string[]): void {
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Update vLLM models dynamically
|
||||
*/
|
||||
export function updateVLLMModels(models: string[]): void {
|
||||
PROVIDER_DEFINITIONS.vllm.models = models.map((modelId) => ({
|
||||
id: modelId,
|
||||
@@ -1810,9 +1758,6 @@ export function updateVLLMModels(models: string[]): void {
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Update OpenRouter models dynamically
|
||||
*/
|
||||
export function updateOpenRouterModels(models: string[]): void {
|
||||
PROVIDER_DEFINITIONS.openrouter.models = models.map((modelId) => ({
|
||||
id: modelId,
|
||||
@@ -1825,9 +1770,6 @@ export function updateOpenRouterModels(models: string[]): void {
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Embedding model pricing - separate from chat models
|
||||
*/
|
||||
export const EMBEDDING_MODEL_PRICING: Record<string, ModelPricing> = {
|
||||
'text-embedding-3-small': {
|
||||
input: 0.02, // $0.02 per 1M tokens
|
||||
@@ -1846,16 +1788,10 @@ export const EMBEDDING_MODEL_PRICING: Record<string, ModelPricing> = {
|
||||
},
|
||||
}
|
||||
|
||||
/**
|
||||
* Get pricing for embedding models specifically
|
||||
*/
|
||||
export function getEmbeddingModelPricing(modelId: string): ModelPricing | null {
|
||||
return EMBEDDING_MODEL_PRICING[modelId] || null
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models that support reasoning effort
|
||||
*/
|
||||
export function getModelsWithReasoningEffort(): string[] {
|
||||
const models: string[] = []
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
@@ -1882,9 +1818,6 @@ export function getReasoningEffortValuesForModel(modelId: string): string[] | nu
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models that support verbosity
|
||||
*/
|
||||
export function getModelsWithVerbosity(): string[] {
|
||||
const models: string[] = []
|
||||
for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
|
||||
@@ -1910,3 +1843,24 @@ export function getVerbosityValuesForModel(modelId: string): string[] | null {
|
||||
}
|
||||
  return null
}

/**
 * Check if a model supports native structured outputs.
 * Handles model IDs with date suffixes (e.g., claude-sonnet-4-5-20250514).
 */
export function supportsNativeStructuredOutputs(modelId: string): boolean {
  const normalizedModelId = modelId.toLowerCase()

  for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
    for (const model of provider.models) {
      if (model.capabilities.nativeStructuredOutputs) {
        const baseModelId = model.id.toLowerCase()
        // Check exact match or date-suffixed version (e.g., claude-sonnet-4-5-20250514)
        if (normalizedModelId === baseModelId || normalizedModelId.startsWith(`${baseModelId}-`)) {
          return true
        }
      }
    }
  }
  return false
}

@@ -1,4 +1,5 @@
|
||||
import OpenAI from 'openai'
|
||||
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
@@ -11,7 +12,7 @@ import type {
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import { prepareToolExecution } from '@/providers/utils'
|
||||
import { calculateCost, prepareToolExecution } from '@/providers/utils'
|
||||
import { useProvidersStore } from '@/stores/providers/store'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
@@ -23,10 +24,9 @@ export const ollamaProvider: ProviderConfig = {
|
||||
name: 'Ollama',
|
||||
description: 'Local Ollama server for LLM inference',
|
||||
version: '1.0.0',
|
||||
models: [], // Will be populated dynamically
|
||||
models: [],
|
||||
defaultModel: '',
|
||||
|
||||
// Initialize the provider by fetching available models
|
||||
async initialize() {
|
||||
if (typeof window !== 'undefined') {
|
||||
logger.info('Skipping Ollama initialization on client side to avoid CORS issues')
|
||||
@@ -63,16 +63,13 @@ export const ollamaProvider: ProviderConfig = {
|
||||
stream: !!request.stream,
|
||||
})
|
||||
|
||||
// Create Ollama client using OpenAI-compatible API
|
||||
const ollama = new OpenAI({
|
||||
apiKey: 'empty',
|
||||
baseURL: `${OLLAMA_HOST}/v1`,
|
||||
})
|
||||
|
||||
// Start with an empty array for all messages
|
||||
const allMessages = []
|
||||
|
||||
// Add system prompt if present
|
||||
if (request.systemPrompt) {
|
||||
allMessages.push({
|
||||
role: 'system',
|
||||
@@ -80,7 +77,6 @@ export const ollamaProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add context if present
|
||||
if (request.context) {
|
||||
allMessages.push({
|
||||
role: 'user',
|
||||
@@ -88,12 +84,10 @@ export const ollamaProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
}
|
||||
|
||||
// Transform tools to OpenAI format if provided
|
||||
const tools = request.tools?.length
|
||||
? request.tools.map((tool) => ({
|
||||
type: 'function',
|
||||
@@ -105,19 +99,15 @@ export const ollamaProvider: ProviderConfig = {
|
||||
}))
|
||||
: undefined
|
||||
|
||||
// Build the request payload
|
||||
const payload: any = {
|
||||
model: request.model,
|
||||
messages: allMessages,
|
||||
}
|
||||
|
||||
// Add optional parameters
|
||||
if (request.temperature !== undefined) payload.temperature = request.temperature
|
||||
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
|
||||
|
||||
// Add response format for structured output if specified
|
||||
if (request.responseFormat) {
|
||||
// Use OpenAI's JSON schema format (Ollama supports this)
|
||||
payload.response_format = {
|
||||
type: 'json_schema',
|
||||
json_schema: {
|
||||
@@ -130,21 +120,13 @@ export const ollamaProvider: ProviderConfig = {
|
||||
logger.info('Added JSON schema response format to Ollama request')
|
||||
}
|
||||
|
||||
// Handle tools and tool usage control
|
||||
// NOTE: Ollama does NOT support the tool_choice parameter beyond basic 'auto' behavior
|
||||
// According to official documentation, tool_choice is silently ignored
|
||||
// Ollama only supports basic function calling where the model autonomously decides
|
||||
if (tools?.length) {
|
||||
// Filter out tools with usageControl='none'
|
||||
// Treat 'force' as 'auto' since Ollama doesn't support forced tool selection
|
||||
const filteredTools = tools.filter((tool) => {
|
||||
const toolId = tool.function?.name
|
||||
const toolConfig = request.tools?.find((t) => t.id === toolId)
|
||||
// Only filter out 'none', treat 'force' as 'auto'
|
||||
return toolConfig?.usageControl !== 'none'
|
||||
})
|
||||
|
||||
// Check if any tools were forcibly marked
|
||||
const hasForcedTools = tools.some((tool) => {
|
||||
const toolId = tool.function?.name
|
||||
const toolConfig = request.tools?.find((t) => t.id === toolId)
|
||||
@@ -160,55 +142,58 @@ export const ollamaProvider: ProviderConfig = {
|
||||
|
||||
if (filteredTools?.length) {
|
||||
payload.tools = filteredTools
|
||||
// Ollama only supports 'auto' behavior - model decides whether to use tools
|
||||
payload.tool_choice = 'auto'
|
||||
|
||||
logger.info('Ollama request configuration:', {
|
||||
toolCount: filteredTools.length,
|
||||
toolChoice: 'auto', // Ollama always uses auto
|
||||
toolChoice: 'auto',
|
||||
forcedToolsIgnored: hasForcedTools,
|
||||
model: request.model,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Start execution timer for the entire provider execution
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
// Check if we can stream directly (no tools required)
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info('Using streaming response for Ollama request')
|
||||
|
||||
// Create a streaming request with token usage tracking
|
||||
const streamResponse = await ollama.chat.completions.create({
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
})
|
||||
|
||||
// Start collecting token usage from the stream
|
||||
const tokenUsage = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
const streamResponse = await ollama.chat.completions.create(streamingParams)
|
||||
|
||||
// Create a StreamingExecution response with a callback to update content and tokens
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromOllamaStream(streamResponse, (content, usage) => {
|
||||
// Update the execution data with the final content and token usage
|
||||
streamingResult.execution.output.content = content
|
||||
|
||||
// Clean up the response content
|
||||
if (content) {
|
||||
streamingResult.execution.output.content = content
|
||||
.replace(/```json\n?|\n?```/g, '')
|
||||
.trim()
|
||||
}
|
||||
|
||||
// Update the timing information with the actual completion time
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.prompt_tokens,
|
||||
completion: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
@@ -217,7 +202,6 @@ export const ollamaProvider: ProviderConfig = {
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
|
||||
// Update the time segment as well
|
||||
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
|
||||
streamEndTime
|
||||
@@ -225,24 +209,13 @@ export const ollamaProvider: ProviderConfig = {
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
|
||||
// Update token usage if available from the stream
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokenUsage.prompt,
|
||||
completion: usage.completion_tokens || tokenUsage.completion,
|
||||
total: usage.total_tokens || tokenUsage.total,
|
||||
}
|
||||
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by the stream completion callback
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: tokenUsage,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
@@ -258,8 +231,9 @@ export const ollamaProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [], // No block logs for direct streaming
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -268,11 +242,9 @@ export const ollamaProvider: ProviderConfig = {
|
||||
},
|
||||
} as StreamingExecution
|
||||
|
||||
// Return the streaming execution object
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
// Make the initial API request
|
||||
const initialCallTime = Date.now()
|
||||
|
||||
let currentResponse = await ollama.chat.completions.create(payload)
|
||||
@@ -280,13 +252,11 @@ export const ollamaProvider: ProviderConfig = {
|
||||
|
||||
let content = currentResponse.choices[0]?.message?.content || ''
|
||||
|
||||
// Clean up the response content if it exists
|
||||
if (content) {
|
||||
content = content.replace(/```json\n?|\n?```/g, '')
|
||||
content = content.trim()
|
||||
}
|
||||
|
||||
// Collect token information
|
||||
const tokens = {
|
||||
prompt: currentResponse.usage?.prompt_tokens || 0,
|
||||
completion: currentResponse.usage?.completion_tokens || 0,
|
||||
@@ -297,11 +267,9 @@ export const ollamaProvider: ProviderConfig = {
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
|
||||
// Track each model and tool call segment with timestamps
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
@@ -313,7 +281,10 @@ export const ollamaProvider: ProviderConfig = {
|
||||
]
|
||||
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
@@ -323,109 +294,124 @@ export const ollamaProvider: ProviderConfig = {
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
// Track time for tool calls in this batch
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
// Process each tool call
|
||||
for (const toolCall of toolCallsInResponse) {
|
||||
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
try {
|
||||
const toolName = toolCall.function.name
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
// Get the tool from the tools registry
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) continue
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
if (!tool) return null
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
// Add to time segments for both success and failure
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams,
|
||||
result,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
// Prepare result content for the LLM
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
// Include error information so LLM can respond appropriately
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
// Add the tool call and result to messages (both success and failure)
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: [
|
||||
{
|
||||
id: toolCall.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: toolName,
|
||||
arguments: toolCall.function.arguments,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error processing tool call:', {
|
||||
error,
|
||||
toolName: toolCall?.function?.name,
|
||||
})
|
||||
const toolCallEndTime = Date.now()
|
||||
logger.error('Error processing tool call:', { error, toolName })
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams: {},
|
||||
result: {
|
||||
success: false,
|
||||
output: undefined,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
},
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const executionResults = await Promise.allSettled(toolExecutionPromises)
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: toolCallsInResponse.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
},
|
||||
})),
|
||||
})
|
||||
|
||||
for (const settledResult of executionResults) {
|
||||
if (settledResult.status === 'rejected' || !settledResult.value) continue
|
||||
|
||||
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
|
||||
settledResult.value
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: startTime,
|
||||
endTime: endTime,
|
||||
duration: duration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(startTime).toISOString(),
|
||||
endTime: new Date(endTime).toISOString(),
|
||||
duration: duration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
        currentMessages.push({
          role: 'tool',
          tool_call_id: toolCall.id,
          content: JSON.stringify(resultContent),
        })
      }
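The loop above consumes results that were produced concurrently via `Promise.allSettled`. A condensed sketch of that pattern, with hypothetical names (`ToolCall`, `runTool`) standing in for the real provider types:

```typescript
interface ToolCall {
  id: string
  function: { name: string; arguments: string }
}

async function executeToolCallsInParallel(
  toolCalls: ToolCall[],
  runTool: (name: string, args: Record<string, unknown>) => Promise<unknown>
) {
  // Kick off every tool call at once; per-call failures are captured by allSettled.
  const settled = await Promise.allSettled(
    toolCalls.map(async (tc) => ({
      toolCall: tc,
      output: await runTool(tc.function.name, JSON.parse(tc.function.arguments)),
    }))
  )

  // One assistant message carries all tool_calls, then one tool message per result,
  // mirroring the message ordering the chat completions API expects.
  const messages: any[] = [
    {
      role: 'assistant',
      content: null,
      tool_calls: toolCalls.map((tc) => ({ id: tc.id, type: 'function', function: tc.function })),
    },
  ]
  for (let i = 0; i < settled.length; i++) {
    const s = settled[i]
    messages.push({
      role: 'tool',
      tool_call_id: toolCalls[i].id,
      content: JSON.stringify(s.status === 'fulfilled' ? s.value.output : { error: true }),
    })
  }
  return messages
}
```

Running the calls concurrently while appending the messages sequentially keeps latency low without changing the ordering the model sees.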
|
||||
// Calculate tool call time for this iteration
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
// Make the next request with updated messages
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
}
|
||||
|
||||
// Time the next model call
|
||||
const nextModelStartTime = Date.now()
|
||||
|
||||
// Make the next request
|
||||
currentResponse = await ollama.chat.completions.create(nextPayload)
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
@@ -434,18 +420,14 @@ export const ollamaProvider: ProviderConfig = {
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
// Add to model time
|
||||
modelTime += thisModelTime
|
||||
|
||||
// Update content if we have a text response
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
// Clean up the response content
|
||||
content = content.replace(/```json\n?|\n?```/g, '')
|
||||
content = content.trim()
|
||||
}
|
||||
|
||||
// Update token counts
|
||||
if (currentResponse.usage) {
|
||||
tokens.prompt += currentResponse.usage.prompt_tokens || 0
|
||||
tokens.completion += currentResponse.usage.completion_tokens || 0
|
||||
@@ -455,48 +437,51 @@ export const ollamaProvider: ProviderConfig = {
|
||||
iterationCount++
|
||||
}
|
||||
|
||||
// After all tool processing complete, if streaming was requested and we have messages, use streaming for the final response
|
||||
if (request.stream) {
|
||||
logger.info('Using streaming for final response after tool processing')
|
||||
|
||||
const streamingPayload = {
|
||||
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
|
||||
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto', // Always use 'auto' for the streaming response after tool calls
|
||||
tool_choice: 'auto',
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
const streamResponse = await ollama.chat.completions.create(streamingParams)
|
||||
|
||||
const streamResponse = await ollama.chat.completions.create(streamingPayload)
|
||||
|
||||
// Create the StreamingExecution object with all collected data
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromOllamaStream(streamResponse, (content, usage) => {
|
||||
// Update the execution data with the final content and token usage
|
||||
streamingResult.execution.output.content = content
|
||||
|
||||
// Clean up the response content
|
||||
if (content) {
|
||||
streamingResult.execution.output.content = content
|
||||
.replace(/```json\n?|\n?```/g, '')
|
||||
.trim()
|
||||
}
|
||||
|
||||
// Update token usage if available from the stream
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokens.prompt,
|
||||
completion: usage.completion_tokens || tokens.completion,
|
||||
total: usage.total_tokens || tokens.total,
|
||||
}
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: tokens.prompt + usage.prompt_tokens,
|
||||
completion: tokens.completion + usage.completion_tokens,
|
||||
total: tokens.total + usage.total_tokens,
|
||||
}
|
||||
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by the callback
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
@@ -520,8 +505,13 @@ export const ollamaProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [], // No block logs at provider level
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -530,11 +520,9 @@ export const ollamaProvider: ProviderConfig = {
|
||||
},
|
||||
} as StreamingExecution
|
||||
|
||||
// Return the streaming execution object
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
// Calculate overall timing
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -557,7 +545,6 @@ export const ollamaProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
// Include timing information even for errors
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -567,9 +554,8 @@ export const ollamaProvider: ProviderConfig = {
|
||||
duration: totalDuration,
|
||||
})
|
||||
|
||||
// Create a new error with timing information
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
// @ts-ignore
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
|
||||
@@ -1,37 +1,14 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'

/**
 * Helper function to convert an Ollama stream to a standard ReadableStream
 * and collect completion metrics
 * Creates a ReadableStream from an Ollama streaming response.
 * Uses the shared OpenAI-compatible streaming utility.
 */
export function createReadableStreamFromOllamaStream(
  ollamaStream: any,
  onComplete?: (content: string, usage?: any) => void
): ReadableStream {
  let fullContent = ''
  let usageData: any = null

  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of ollamaStream) {
          if (chunk.usage) {
            usageData = chunk.usage
          }

          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }

        if (onComplete) {
          onComplete(fullContent, usageData)
        }

        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
  ollamaStream: AsyncIterable<ChatCompletionChunk>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  return createOpenAICompatibleStream(ollamaStream, 'Ollama', onComplete)
}
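Both this helper and the OpenAI equivalent later in the diff now delegate to `createOpenAICompatibleStream` in `@/providers/utils`, whose body is not part of this diff. A plausible shape for it, inferred from the per-provider implementation being removed above — the provider-name argument is assumed to exist only for logging/diagnostics, and `onComplete` is assumed to fire once usage data has arrived:

```typescript
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'

export function createOpenAICompatibleStream(
  sourceStream: AsyncIterable<ChatCompletionChunk>,
  providerName: string,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  let fullContent = ''
  let usageData: CompletionUsage | undefined

  return new ReadableStream<Uint8Array>({
    async start(controller) {
      try {
        for await (const chunk of sourceStream) {
          // Usage arrives on the final chunk when stream_options.include_usage is set.
          if (chunk.usage) usageData = chunk.usage

          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }
        if (onComplete && usageData) onComplete(fullContent, usageData)
        controller.close()
      } catch (error) {
        // providerName would only matter here for logging/diagnostics.
        controller.error(error)
      }
    },
  })
}
```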
@@ -1,4 +1,5 @@
|
||||
import OpenAI from 'openai'
|
||||
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
@@ -11,6 +12,7 @@ import type {
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
@@ -19,9 +21,6 @@ import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('OpenAIProvider')
|
||||
|
||||
/**
|
||||
* OpenAI provider configuration
|
||||
*/
|
||||
export const openaiProvider: ProviderConfig = {
|
||||
id: 'openai',
|
||||
name: 'OpenAI',
|
||||
@@ -43,13 +42,10 @@ export const openaiProvider: ProviderConfig = {
|
||||
stream: !!request.stream,
|
||||
})
|
||||
|
||||
// API key is now handled server-side before this function is called
|
||||
const openai = new OpenAI({ apiKey: request.apiKey })
|
||||
|
||||
// Start with an empty array for all messages
|
||||
const allMessages = []
|
||||
|
||||
// Add system prompt if present
|
||||
if (request.systemPrompt) {
|
||||
allMessages.push({
|
||||
role: 'system',
|
||||
@@ -57,7 +53,6 @@ export const openaiProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add context if present
|
||||
if (request.context) {
|
||||
allMessages.push({
|
||||
role: 'user',
|
||||
@@ -65,12 +60,10 @@ export const openaiProvider: ProviderConfig = {
|
||||
})
|
||||
}
|
||||
|
||||
// Add remaining messages
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
}
|
||||
|
||||
// Transform tools to OpenAI format if provided
|
||||
const tools = request.tools?.length
|
||||
? request.tools.map((tool) => ({
|
||||
type: 'function',
|
||||
@@ -82,23 +75,18 @@ export const openaiProvider: ProviderConfig = {
|
||||
}))
|
||||
: undefined
|
||||
|
||||
// Build the request payload
|
||||
const payload: any = {
|
||||
model: request.model || 'gpt-4o',
|
||||
messages: allMessages,
|
||||
}
|
||||
|
||||
// Add optional parameters
|
||||
if (request.temperature !== undefined) payload.temperature = request.temperature
|
||||
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
|
||||
|
||||
// Add GPT-5 specific parameters
|
||||
if (request.reasoningEffort !== undefined) payload.reasoning_effort = request.reasoningEffort
|
||||
if (request.verbosity !== undefined) payload.verbosity = request.verbosity
|
||||
|
||||
// Add response format for structured output if specified
|
||||
if (request.responseFormat) {
|
||||
// Use OpenAI's JSON schema format
|
||||
payload.response_format = {
|
||||
type: 'json_schema',
|
||||
json_schema: {
|
||||
@@ -111,7 +99,6 @@ export const openaiProvider: ProviderConfig = {
|
||||
logger.info('Added JSON schema response format to request')
|
||||
}
|
||||
|
||||
// Handle tools and tool usage control
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
|
||||
if (tools?.length) {
|
||||
@@ -139,39 +126,40 @@ export const openaiProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
// Start execution timer for the entire provider execution
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
// Check if we can stream directly (no tools required)
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info('Using streaming response for OpenAI request')
|
||||
|
||||
// Create a streaming request with token usage tracking
|
||||
const streamResponse = await openai.chat.completions.create({
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
})
|
||||
|
||||
// Start collecting token usage from the stream
|
||||
const tokenUsage = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
const streamResponse = await openai.chat.completions.create(streamingParams)
|
||||
|
||||
let _streamContent = ''
|
||||
|
||||
// Create a StreamingExecution response with a callback to update content and tokens
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
|
||||
// Update the execution data with the final content and token usage
|
||||
_streamContent = content
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.prompt_tokens,
|
||||
completion: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
|
||||
// Update the timing information with the actual completion time
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
|
||||
@@ -180,7 +168,6 @@ export const openaiProvider: ProviderConfig = {
|
||||
streamingResult.execution.output.providerTiming.duration =
|
||||
streamEndTime - providerStartTime
|
||||
|
||||
// Update the time segment as well
|
||||
if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) {
|
||||
streamingResult.execution.output.providerTiming.timeSegments[0].endTime =
|
||||
streamEndTime
|
||||
@@ -188,25 +175,13 @@ export const openaiProvider: ProviderConfig = {
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
|
||||
// Update token usage if available from the stream
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokenUsage.prompt,
|
||||
completion: usage.completion_tokens || tokenUsage.completion,
|
||||
total: usage.total_tokens || tokenUsage.total,
|
||||
}
|
||||
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
}
|
||||
// We don't need to estimate tokens here as logger.ts will handle that
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by the stream completion callback
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: tokenUsage,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
@@ -222,9 +197,9 @@ export const openaiProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
// Cost will be calculated in logger
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [], // No block logs for direct streaming
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -233,21 +208,19 @@ export const openaiProvider: ProviderConfig = {
|
||||
},
|
||||
} as StreamingExecution
|
||||
|
||||
// Return the streaming execution object with explicit casting
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
// Make the initial API request
|
||||
const initialCallTime = Date.now()
|
||||
|
||||
// Track the original tool_choice for forced tool tracking
|
||||
const originalToolChoice = payload.tool_choice
|
||||
|
||||
// Track forced tools and their usage
|
||||
const forcedTools = preparedTools?.forcedTools || []
|
||||
let usedForcedTools: string[] = []
|
||||
|
||||
// Helper function to check for forced tool usage in responses
|
||||
/**
|
||||
* Helper function to check for forced tool usage in responses
|
||||
*/
|
||||
const checkForForcedToolUsage = (
|
||||
response: any,
|
||||
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }
|
||||
@@ -271,7 +244,6 @@ export const openaiProvider: ProviderConfig = {
|
||||
const firstResponseTime = Date.now() - initialCallTime
|
||||
|
||||
let content = currentResponse.choices[0]?.message?.content || ''
|
||||
// Collect token information but don't calculate costs - that will be done in logger.ts
|
||||
const tokens = {
|
||||
prompt: currentResponse.usage?.prompt_tokens || 0,
|
||||
completion: currentResponse.usage?.completion_tokens || 0,
|
||||
@@ -282,14 +254,11 @@ export const openaiProvider: ProviderConfig = {
|
||||
const currentMessages = [...allMessages]
|
||||
let iterationCount = 0
|
||||
|
||||
// Track time spent in model vs tools
|
||||
let modelTime = firstResponseTime
|
||||
let toolsTime = 0
|
||||
|
||||
// Track if a forced tool has been used
|
||||
let hasUsedForcedTool = false
|
||||
|
||||
// Track each model and tool call segment with timestamps
|
||||
const timeSegments: TimeSegment[] = [
|
||||
{
|
||||
type: 'model',
|
||||
@@ -300,145 +269,159 @@ export const openaiProvider: ProviderConfig = {
|
||||
},
|
||||
]
|
||||
|
||||
// Check if a forced tool was used in the first response
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
// Check for tool calls
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
`Processing ${toolCallsInResponse.length} tool calls in parallel (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})`
|
||||
)
|
||||
|
||||
// Track time for tool calls in this batch
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
// Process each tool call
|
||||
for (const toolCall of toolCallsInResponse) {
|
||||
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
try {
|
||||
const toolName = toolCall.function.name
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
// Get the tool from the tools registry
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) continue
|
||||
|
||||
// Execute the tool
|
||||
const toolCallStartTime = Date.now()
|
||||
if (!tool) {
|
||||
return null
|
||||
}
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
// Add to time segments for both success and failure
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams,
|
||||
result,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
// Prepare result content for the LLM
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
// Include error information so LLM can respond appropriately
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
// Add the tool call and result to messages (both success and failure)
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: [
|
||||
{
|
||||
id: toolCall.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: toolName,
|
||||
arguments: toolCall.function.arguments,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error processing tool call:', {
|
||||
error,
|
||||
toolName: toolCall?.function?.name,
|
||||
})
|
||||
const toolCallEndTime = Date.now()
|
||||
logger.error('Error processing tool call:', { error, toolName })
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams: {},
|
||||
result: {
|
||||
success: false,
|
||||
output: undefined,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
},
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const executionResults = await Promise.allSettled(toolExecutionPromises)
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: toolCallsInResponse.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
},
|
||||
})),
|
||||
})
|
||||
|
||||
for (const settledResult of executionResults) {
|
||||
if (settledResult.status === 'rejected' || !settledResult.value) continue
|
||||
|
||||
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
|
||||
settledResult.value
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: startTime,
|
||||
endTime: endTime,
|
||||
duration: duration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(startTime).toISOString(),
|
||||
endTime: new Date(endTime).toISOString(),
|
||||
duration: duration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
}
|
||||
|
||||
// Calculate tool call time for this iteration
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
// Make the next request with updated messages
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
}
|
||||
|
||||
        // Update tool_choice based on which forced tools have been used
        if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) {
          // If we have remaining forced tools, get the next one to force
          const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))

          if (remainingTools.length > 0) {
            // Force the next tool
            nextPayload.tool_choice = {
              type: 'function',
              function: { name: remainingTools[0] },
            }
            logger.info(`Forcing next tool: ${remainingTools[0]}`)
          } else {
            // All forced tools have been used, switch to auto
            nextPayload.tool_choice = 'auto'
            logger.info('All forced tools have been used, switching to auto tool_choice')
          }
        }
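A minimal sketch of the tool_choice progression implemented above (the helper and tool names are made up for illustration):

```typescript
type ToolChoice = 'auto' | { type: 'function'; function: { name: string } }

function nextToolChoice(forcedTools: string[], usedForcedTools: string[]): ToolChoice {
  const remaining = forcedTools.filter((t) => !usedForcedTools.includes(t))
  return remaining.length > 0
    ? { type: 'function', function: { name: remaining[0] } }
    : 'auto'
}

// e.g. forcedTools = ['search_docs', 'summarize']:
// nextToolChoice(['search_docs', 'summarize'], ['search_docs'])
//   => { type: 'function', function: { name: 'summarize' } }
// nextToolChoice(['search_docs', 'summarize'], ['search_docs', 'summarize'])
//   => 'auto'
```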
// Time the next model call
|
||||
const nextModelStartTime = Date.now()
|
||||
|
||||
// Make the next request
|
||||
currentResponse = await openai.chat.completions.create(nextPayload)
|
||||
|
||||
// Check if any forced tools were used in this response
|
||||
checkForForcedToolUsage(currentResponse, nextPayload.tool_choice)
|
||||
|
||||
const nextModelEndTime = Date.now()
|
||||
const thisModelTime = nextModelEndTime - nextModelStartTime
|
||||
|
||||
// Add to time segments
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: `Model response (iteration ${iterationCount + 1})`,
|
||||
@@ -447,15 +430,8 @@ export const openaiProvider: ProviderConfig = {
|
||||
duration: thisModelTime,
|
||||
})
|
||||
|
||||
// Add to model time
|
||||
modelTime += thisModelTime
|
||||
|
||||
// Update content if we have a text response
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
// Update token counts
|
||||
if (currentResponse.usage) {
|
||||
tokens.prompt += currentResponse.usage.prompt_tokens || 0
|
||||
tokens.completion += currentResponse.usage.completion_tokens || 0
|
||||
@@ -465,46 +441,44 @@ export const openaiProvider: ProviderConfig = {
|
||||
iterationCount++
|
||||
}
|
||||
|
||||
// After all tool processing complete, if streaming was requested, use streaming for the final response
|
||||
if (request.stream) {
|
||||
logger.info('Using streaming for final response after tool processing')
|
||||
|
||||
// When streaming after tool calls with forced tools, make sure tool_choice is set to 'auto'
|
||||
// This prevents OpenAI API from trying to force tool usage again in the final streaming response
|
||||
const streamingPayload = {
|
||||
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
|
||||
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto', // Always use 'auto' for the streaming response after tool calls
|
||||
tool_choice: 'auto',
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
|
||||
const streamResponse = await openai.chat.completions.create(streamingPayload)
|
||||
|
||||
// Create the StreamingExecution object with all collected data
|
||||
let _streamContent = ''
|
||||
const streamResponse = await openai.chat.completions.create(streamingParams)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
|
||||
// Update the execution data with the final content and token usage
|
||||
_streamContent = content
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: tokens.prompt + usage.prompt_tokens,
|
||||
completion: tokens.completion + usage.completion_tokens,
|
||||
total: tokens.total + usage.total_tokens,
|
||||
}
|
||||
|
||||
// Update token usage if available from the stream
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokens.prompt,
|
||||
completion: usage.completion_tokens || tokens.completion,
|
||||
total: usage.total_tokens || tokens.total,
|
||||
}
|
||||
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by the callback
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: {
|
||||
prompt: tokens.prompt,
|
||||
@@ -528,9 +502,13 @@ export const openaiProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
// Cost will be calculated in logger
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [], // No block logs at provider level
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -539,11 +517,9 @@ export const openaiProvider: ProviderConfig = {
|
||||
},
|
||||
} as StreamingExecution
|
||||
|
||||
// Return the streaming execution object with explicit casting
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
// Calculate overall timing
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -564,10 +540,8 @@ export const openaiProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
// We're not calculating cost here as it will be handled in logger.ts
|
||||
}
|
||||
} catch (error) {
|
||||
// Include timing information even for errors
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -577,7 +551,6 @@ export const openaiProvider: ProviderConfig = {
|
||||
duration: totalDuration,
|
||||
})
|
||||
|
||||
// Create a new error with timing information
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
enhancedError.timing = {
|
||||
|
||||
@@ -1,37 +1,15 @@
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import type { Stream } from 'openai/streaming'
import { createOpenAICompatibleStream } from '@/providers/utils'

/**
 * Helper function to convert an OpenAI stream to a standard ReadableStream
 * and collect completion metrics
 * Creates a ReadableStream from an OpenAI streaming response.
 * Uses the shared OpenAI-compatible streaming utility.
 */
export function createReadableStreamFromOpenAIStream(
  openaiStream: any,
  onComplete?: (content: string, usage?: any) => void
): ReadableStream {
  let fullContent = ''
  let usageData: any = null

  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of openaiStream) {
          if (chunk.usage) {
            usageData = chunk.usage
          }

          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }

        if (onComplete) {
          onComplete(fullContent, usageData)
        }

        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
  openaiStream: Stream<ChatCompletionChunk>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  return createOpenAICompatibleStream(openaiStream, 'OpenAI', onComplete)
}
@@ -1,4 +1,5 @@
|
||||
import OpenAI from 'openai'
|
||||
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
@@ -6,6 +7,7 @@ import { getProviderDefaultModel, getProviderModels } from '@/providers/models'
|
||||
import {
|
||||
checkForForcedToolUsage,
|
||||
createReadableStreamFromOpenAIStream,
|
||||
supportsNativeStructuredOutputs,
|
||||
} from '@/providers/openrouter/utils'
|
||||
import type {
|
||||
ProviderConfig,
|
||||
@@ -13,11 +15,49 @@ import type {
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
|
||||
import {
|
||||
calculateCost,
|
||||
generateSchemaInstructions,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
} from '@/providers/utils'
|
||||
import { executeTool } from '@/tools'
|
||||
|
||||
const logger = createLogger('OpenRouterProvider')

/**
 * Applies structured output configuration to a payload based on model capabilities.
 * Uses json_schema with require_parameters for supported models, falls back to json_object with prompt instructions.
 */
async function applyResponseFormat(
  targetPayload: any,
  messages: any[],
  responseFormat: any,
  model: string
): Promise<any[]> {
  const useNative = await supportsNativeStructuredOutputs(model)

  if (useNative) {
    logger.info('Using native structured outputs for OpenRouter model', { model })
    targetPayload.response_format = {
      type: 'json_schema',
      json_schema: {
        name: responseFormat.name || 'response_schema',
        schema: responseFormat.schema || responseFormat,
        strict: responseFormat.strict !== false,
      },
    }
    targetPayload.provider = { ...targetPayload.provider, require_parameters: true }
    return messages
  }

  logger.info('Using json_object mode with prompt instructions for OpenRouter model', { model })
  const schema = responseFormat.schema || responseFormat
  const schemaInstructions = generateSchemaInstructions(schema, responseFormat.name)
  targetPayload.response_format = { type: 'json_object' }
  return [...messages, { role: 'user', content: schemaInstructions }]
}
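A minimal usage sketch of the helper, inside an async context; the model ID and schema below are made-up examples, while the real call sites pass `request.responseFormat` and `requestedModel`:

```typescript
const payload: any = {
  model: 'anthropic/claude-sonnet-4.5',
  messages: [{ role: 'user', content: 'Report the weather.' }],
}
const responseFormat = {
  name: 'weather_report',
  schema: { type: 'object', properties: { temperature: { type: 'number' } }, required: ['temperature'] },
}

// Mutates payload.response_format (and payload.provider for native models) in place,
// and returns the message array to use — possibly extended with schema instructions.
payload.messages = await applyResponseFormat(payload, payload.messages, responseFormat, payload.model)
```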
export const openRouterProvider: ProviderConfig = {
|
||||
id: 'openrouter',
|
||||
name: 'OpenRouter',
|
||||
@@ -83,17 +123,6 @@ export const openRouterProvider: ProviderConfig = {
|
||||
if (request.temperature !== undefined) payload.temperature = request.temperature
|
||||
if (request.maxTokens !== undefined) payload.max_tokens = request.maxTokens
|
||||
|
||||
if (request.responseFormat) {
|
||||
payload.response_format = {
|
||||
type: 'json_schema',
|
||||
json_schema: {
|
||||
name: request.responseFormat.name || 'response_schema',
|
||||
schema: request.responseFormat.schema || request.responseFormat,
|
||||
strict: request.responseFormat.strict !== false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
let hasActiveTools = false
|
||||
if (tools?.length) {
|
||||
@@ -110,26 +139,43 @@ export const openRouterProvider: ProviderConfig = {
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
try {
|
||||
if (request.responseFormat && !hasActiveTools) {
|
||||
payload.messages = await applyResponseFormat(
|
||||
payload,
|
||||
payload.messages,
|
||||
request.responseFormat,
|
||||
requestedModel
|
||||
)
|
||||
}
|
||||
|
||||
if (request.stream && (!tools || tools.length === 0 || !hasActiveTools)) {
|
||||
const streamResponse = await client.chat.completions.create({
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
})
|
||||
|
||||
const tokenUsage = { prompt: 0, completion: 0, total: 0 }
|
||||
}
|
||||
const streamResponse = await client.chat.completions.create(streamingParams)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokenUsage.prompt,
|
||||
completion: usage.completion_tokens || tokenUsage.completion,
|
||||
total: usage.total_tokens || tokenUsage.total,
|
||||
}
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
}
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.prompt_tokens,
|
||||
completion: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
requestedModel,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
|
||||
const end = Date.now()
|
||||
const endISO = new Date(end).toISOString()
|
||||
if (streamingResult.execution.output.providerTiming) {
|
||||
@@ -147,7 +193,7 @@ export const openRouterProvider: ProviderConfig = {
|
||||
output: {
|
||||
content: '',
|
||||
model: requestedModel,
|
||||
tokens: tokenUsage,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
@@ -163,6 +209,7 @@ export const openRouterProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
@@ -217,87 +264,125 @@ export const openRouterProvider: ProviderConfig = {
|
||||
usedForcedTools = forcedToolResult.usedForcedTools
|
||||
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
}
|
||||
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
}
|
||||
|
||||
const toolsStartTime = Date.now()
|
||||
for (const toolCall of toolCallsInResponse) {
|
||||
|
||||
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
try {
|
||||
const toolName = toolCall.function.name
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) continue
|
||||
|
||||
const toolCallStartTime = Date.now()
|
||||
if (!tool) return null
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams,
|
||||
result,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: [
|
||||
{
|
||||
id: toolCall.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: toolName,
|
||||
arguments: toolCall.function.arguments,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
} catch (error) {
|
||||
const toolCallEndTime = Date.now()
|
||||
logger.error('Error processing tool call (OpenRouter):', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
toolName: toolCall?.function?.name,
|
||||
toolName,
|
||||
})
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams: {},
|
||||
result: {
|
||||
success: false,
|
||||
output: undefined,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
},
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const executionResults = await Promise.allSettled(toolExecutionPromises)
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: toolCallsInResponse.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
},
|
||||
})),
|
||||
})
|
||||
|
||||
for (const settledResult of executionResults) {
|
||||
if (settledResult.status === 'rejected' || !settledResult.value) continue
|
||||
|
||||
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
|
||||
settledResult.value
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: startTime,
|
||||
endTime: endTime,
|
||||
duration: duration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(startTime).toISOString(),
|
||||
endTime: new Date(endTime).toISOString(),
|
||||
duration: duration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
}
|
||||
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
toolsTime += thisToolsTime
|
||||
|
||||
const nextPayload: any = {
|
||||
const nextPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
}
|
||||
@@ -343,26 +428,52 @@ export const openRouterProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
if (request.stream) {
|
||||
const streamingPayload = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto',
|
||||
const accumulatedCost = calculateCost(requestedModel, tokens.prompt, tokens.completion)
|
||||
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming & { provider?: any } = {
|
||||
model: payload.model,
|
||||
messages: [...currentMessages],
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
|
||||
const streamResponse = await client.chat.completions.create(streamingPayload)
|
||||
if (payload.temperature !== undefined) {
|
||||
streamingParams.temperature = payload.temperature
|
||||
}
|
||||
if (payload.max_tokens !== undefined) {
|
||||
streamingParams.max_tokens = payload.max_tokens
|
||||
}
|
||||
|
||||
if (request.responseFormat) {
|
||||
;(streamingParams as any).messages = await applyResponseFormat(
|
||||
streamingParams as any,
|
||||
streamingParams.messages,
|
||||
request.responseFormat,
|
||||
requestedModel
|
||||
)
|
||||
}
|
||||
|
||||
const streamResponse = await client.chat.completions.create(streamingParams)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => {
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokens.prompt,
|
||||
completion: usage.completion_tokens || tokens.completion,
|
||||
total: usage.total_tokens || tokens.total,
|
||||
}
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
}
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: tokens.prompt + usage.prompt_tokens,
|
||||
completion: tokens.completion + usage.completion_tokens,
|
||||
total: tokens.total + usage.total_tokens,
|
||||
}
|
||||
|
||||
const streamCost = calculateCost(
|
||||
requestedModel,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
@@ -387,6 +498,11 @@ export const openRouterProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
@@ -400,6 +516,49 @@ export const openRouterProvider: ProviderConfig = {
|
||||
return streamingResult as StreamingExecution
|
||||
}
|
||||
|
||||
if (request.responseFormat && hasActiveTools && toolCalls.length > 0) {
|
||||
const finalPayload: any = {
|
||||
model: payload.model,
|
||||
messages: [...currentMessages],
|
||||
}
|
||||
if (payload.temperature !== undefined) {
|
||||
finalPayload.temperature = payload.temperature
|
||||
}
|
||||
if (payload.max_tokens !== undefined) {
|
||||
finalPayload.max_tokens = payload.max_tokens
|
||||
}
|
||||
|
||||
finalPayload.messages = await applyResponseFormat(
|
||||
finalPayload,
|
||||
finalPayload.messages,
|
||||
request.responseFormat,
|
||||
requestedModel
|
||||
)
|
||||
|
||||
const finalStartTime = Date.now()
|
||||
const finalResponse = await client.chat.completions.create(finalPayload)
|
||||
const finalEndTime = Date.now()
|
||||
const finalDuration = finalEndTime - finalStartTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'model',
|
||||
name: 'Final structured response',
|
||||
startTime: finalStartTime,
|
||||
endTime: finalEndTime,
|
||||
duration: finalDuration,
|
||||
})
|
||||
modelTime += finalDuration
|
||||
|
||||
if (finalResponse.choices[0]?.message?.content) {
|
||||
content = finalResponse.choices[0].message.content
|
||||
}
|
||||
if (finalResponse.usage) {
|
||||
tokens.prompt += finalResponse.usage.prompt_tokens || 0
|
||||
tokens.completion += finalResponse.usage.completion_tokens || 0
|
||||
tokens.total += finalResponse.usage.total_tokens || 0
|
||||
}
|
||||
}
|
||||
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
@@ -425,10 +584,21 @@ export const openRouterProvider: ProviderConfig = {
|
||||
const providerEndTime = Date.now()
|
||||
const providerEndTimeISO = new Date(providerEndTime).toISOString()
|
||||
const totalDuration = providerEndTime - providerStartTime
|
||||
logger.error('Error in OpenRouter request:', {
|
||||
|
||||
const errorDetails: Record<string, any> = {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
duration: totalDuration,
|
||||
})
|
||||
}
|
||||
if (error && typeof error === 'object') {
|
||||
const err = error as any
|
||||
if (err.status) errorDetails.status = err.status
|
||||
if (err.code) errorDetails.code = err.code
|
||||
if (err.type) errorDetails.type = err.type
|
||||
if (err.error?.message) errorDetails.providerMessage = err.error.message
|
||||
if (err.error?.metadata) errorDetails.metadata = err.error.metadata
|
||||
}
|
||||
|
||||
logger.error('Error in OpenRouter request:', errorDetails)
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore
|
||||
enhancedError.timing = {
|
||||
|
||||
@@ -1,55 +1,107 @@
|
||||
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
|
||||
import type { CompletionUsage } from 'openai/resources/completions'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { trackForcedToolUsage } from '@/providers/utils'
|
||||
import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('OpenRouterProvider')
|
||||
const logger = createLogger('OpenRouterUtils')
|
||||
|
||||
interface OpenRouterModelData {
|
||||
id: string
|
||||
supported_parameters?: string[]
|
||||
}
|
||||
|
||||
interface ModelCapabilities {
|
||||
supportsStructuredOutputs: boolean
|
||||
supportsTools: boolean
|
||||
}
|
||||
|
||||
let modelCapabilitiesCache: Map<string, ModelCapabilities> | null = null
|
||||
let cacheTimestamp = 0
|
||||
const CACHE_TTL_MS = 5 * 60 * 1000 // 5 minutes
|
||||
|
||||
/**
|
||||
* Creates a ReadableStream from an OpenAI-compatible stream response
|
||||
* @param openaiStream - The OpenAI stream to convert
|
||||
* @param onComplete - Optional callback when streaming is complete with content and usage data
|
||||
* @returns ReadableStream that emits text chunks
|
||||
* Fetches and caches OpenRouter model capabilities from their API.
|
||||
*/
|
||||
export function createReadableStreamFromOpenAIStream(
|
||||
openaiStream: any,
|
||||
onComplete?: (content: string, usage?: any) => void
|
||||
): ReadableStream {
|
||||
let fullContent = ''
|
||||
let usageData: any = null
|
||||
async function fetchModelCapabilities(): Promise<Map<string, ModelCapabilities>> {
|
||||
try {
|
||||
const response = await fetch('https://openrouter.ai/api/v1/models', {
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
})
|
||||
|
||||
return new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of openaiStream) {
|
||||
if (chunk.usage) {
|
||||
usageData = chunk.usage
|
||||
}
|
||||
if (!response.ok) {
|
||||
logger.warn('Failed to fetch OpenRouter model capabilities', {
|
||||
status: response.status,
|
||||
})
|
||||
return new Map()
|
||||
}
|
||||
|
||||
const content = chunk.choices[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
fullContent += content
|
||||
controller.enqueue(new TextEncoder().encode(content))
|
||||
}
|
||||
}
|
||||
const data = await response.json()
|
||||
const capabilities = new Map<string, ModelCapabilities>()
|
||||
|
||||
if (onComplete) {
|
||||
onComplete(fullContent, usageData)
|
||||
}
|
||||
for (const model of (data.data ?? []) as OpenRouterModelData[]) {
|
||||
const supportedParams = model.supported_parameters ?? []
|
||||
capabilities.set(model.id, {
|
||||
supportsStructuredOutputs: supportedParams.includes('structured_outputs'),
|
||||
supportsTools: supportedParams.includes('tools'),
|
||||
})
|
||||
}
|
||||
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
})
|
||||
logger.info('Cached OpenRouter model capabilities', {
|
||||
modelCount: capabilities.size,
|
||||
withStructuredOutputs: Array.from(capabilities.values()).filter(
|
||||
(c) => c.supportsStructuredOutputs
|
||||
).length,
|
||||
})
|
||||
|
||||
return capabilities
|
||||
} catch (error) {
|
||||
logger.error('Error fetching OpenRouter model capabilities', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
return new Map()
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Checks if a forced tool was used in the response and updates tracking
 * @param response - The API response containing tool calls
 * @param toolChoice - The tool choice configuration (string or object)
 * @param forcedTools - Array of forced tool names
 * @param usedForcedTools - Array of already used forced tools
 * @returns Object with hasUsedForcedTool flag and updated usedForcedTools array
 * Gets capabilities for a specific OpenRouter model.
 * Fetches from API if cache is stale or empty.
 */
export async function getOpenRouterModelCapabilities(
  modelId: string
): Promise<ModelCapabilities | null> {
  const now = Date.now()

  if (!modelCapabilitiesCache || now - cacheTimestamp > CACHE_TTL_MS) {
    modelCapabilitiesCache = await fetchModelCapabilities()
    cacheTimestamp = now
  }

  const normalizedId = modelId.replace(/^openrouter\//, '')
  return modelCapabilitiesCache.get(normalizedId) ?? null
}

/**
 * Checks if a model supports native structured outputs (json_schema).
 */
export async function supportsNativeStructuredOutputs(modelId: string): Promise<boolean> {
  const capabilities = await getOpenRouterModelCapabilities(modelId)
  return capabilities?.supportsStructuredOutputs ?? false
}

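A hedged sketch of how the cached capability lookup might gate response-format handling; the model id and branch bodies are illustrative assumptions, not code from this diff:

const caps = await getOpenRouterModelCapabilities('openai/gpt-4o-mini') // hypothetical model id
if (caps?.supportsStructuredOutputs) {
  // native json_schema response_format can be sent to this model
} else {
  // fall back to prompt-based schema instructions instead
}
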
/**
|
||||
* Creates a ReadableStream from an OpenRouter streaming response.
|
||||
* Uses the shared OpenAI-compatible streaming utility.
|
||||
*/
|
||||
export function createReadableStreamFromOpenAIStream(
|
||||
openaiStream: AsyncIterable<ChatCompletionChunk>,
|
||||
onComplete?: (content: string, usage: CompletionUsage) => void
|
||||
): ReadableStream<Uint8Array> {
|
||||
return createOpenAICompatibleStream(openaiStream, 'OpenRouter', onComplete)
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a forced tool was used in an OpenRouter response.
|
||||
* Uses the shared OpenAI-compatible forced tool usage helper.
|
||||
*/
|
||||
export function checkForForcedToolUsage(
|
||||
response: any,
|
||||
@@ -57,22 +109,11 @@ export function checkForForcedToolUsage(
|
||||
forcedTools: string[],
|
||||
usedForcedTools: string[]
|
||||
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
|
||||
let hasUsedForcedTool = false
|
||||
let updatedUsedForcedTools = usedForcedTools
|
||||
|
||||
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
|
||||
const toolCallsResponse = response.choices[0].message.tool_calls
|
||||
const result = trackForcedToolUsage(
|
||||
toolCallsResponse,
|
||||
toolChoice,
|
||||
logger,
|
||||
'openrouter',
|
||||
forcedTools,
|
||||
updatedUsedForcedTools
|
||||
)
|
||||
hasUsedForcedTool = result.hasUsedForcedTool
|
||||
updatedUsedForcedTools = result.usedForcedTools
|
||||
}
|
||||
|
||||
return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
|
||||
return checkForForcedToolUsageOpenAI(
|
||||
response,
|
||||
toolChoice,
|
||||
'OpenRouter',
|
||||
forcedTools,
|
||||
usedForcedTools
|
||||
)
|
||||
}
|
||||
|
||||
@@ -19,15 +19,12 @@ export type ProviderId =
|
||||
 * Model pricing information per million tokens
 */
export interface ModelPricing {
  input: number // Cost per million tokens for input
  cachedInput?: number // Cost per million tokens for cached input (optional)
  output: number // Cost per million tokens for output
  updatedAt: string // ISO timestamp when pricing was last updated
  input: number // Per 1M tokens
  cachedInput?: number // Per 1M tokens (if supported)
  output: number // Per 1M tokens
  updatedAt: string // Last updated date
}

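A worked example of the per-million-token convention, using made-up prices for illustration:

const pricing: ModelPricing = { input: 2.5, output: 10, updatedAt: '2025-01-01' } // hypothetical rates
const promptTokens = 12_000
const completionTokens = 3_000
const inputCost = (promptTokens / 1_000_000) * pricing.input // 0.03 USD
const outputCost = (completionTokens / 1_000_000) * pricing.output // 0.03 USD
const totalCost = inputCost + outputCost // 0.06 USD
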
/**
|
||||
* Map of model IDs to their pricing information
|
||||
*/
|
||||
export type ModelPricingMap = Record<string, ModelPricing>
|
||||
|
||||
export interface TokenInfo {
|
||||
@@ -84,20 +81,20 @@ export interface ProviderResponse {
|
||||
toolCalls?: FunctionCallResponse[]
|
||||
toolResults?: any[]
|
||||
timing?: {
|
||||
startTime: string // ISO timestamp when provider execution started
|
||||
endTime: string // ISO timestamp when provider execution completed
|
||||
duration: number // Total duration in milliseconds
|
||||
modelTime?: number // Time spent in model generation (excluding tool calls)
|
||||
toolsTime?: number // Time spent in tool calls
|
||||
firstResponseTime?: number // Time to first token/response
|
||||
iterations?: number // Number of model calls for tool use
|
||||
timeSegments?: TimeSegment[] // Detailed timeline of all operations
|
||||
startTime: string
|
||||
endTime: string
|
||||
duration: number
|
||||
modelTime?: number
|
||||
toolsTime?: number
|
||||
firstResponseTime?: number
|
||||
iterations?: number
|
||||
timeSegments?: TimeSegment[]
|
||||
}
|
||||
cost?: {
|
||||
input: number // Cost in USD for input tokens
|
||||
output: number // Cost in USD for output tokens
|
||||
total: number // Total cost in USD
|
||||
pricing: ModelPricing // The pricing used for calculation
|
||||
input: number
|
||||
output: number
|
||||
total: number
|
||||
pricing: ModelPricing
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,27 +147,23 @@ export interface ProviderRequest {
|
||||
strict?: boolean
|
||||
}
|
||||
local_execution?: boolean
|
||||
workflowId?: string // Optional workflow ID for authentication context
|
||||
workspaceId?: string // Optional workspace ID for MCP tool scoping
|
||||
chatId?: string // Optional chat ID for checkpoint context
|
||||
userId?: string // Optional user ID for tool execution context
|
||||
workflowId?: string
|
||||
workspaceId?: string
|
||||
chatId?: string
|
||||
userId?: string
|
||||
stream?: boolean
|
||||
streamToolCalls?: boolean // Whether to stream tool call responses back to user (default: false)
|
||||
environmentVariables?: Record<string, string> // Environment variables for tool execution
|
||||
workflowVariables?: Record<string, any> // Workflow variables for <variable.name> resolution
|
||||
blockData?: Record<string, any> // Runtime block outputs for <block.field> resolution in custom tools
|
||||
blockNameMapping?: Record<string, string> // Mapping of block names to IDs for resolution
|
||||
isCopilotRequest?: boolean // Flag to indicate this request is from the copilot system
|
||||
// Azure OpenAI specific parameters
|
||||
streamToolCalls?: boolean
|
||||
environmentVariables?: Record<string, string>
|
||||
workflowVariables?: Record<string, any>
|
||||
blockData?: Record<string, any>
|
||||
blockNameMapping?: Record<string, string>
|
||||
isCopilotRequest?: boolean
|
||||
azureEndpoint?: string
|
||||
azureApiVersion?: string
|
||||
// Vertex AI specific parameters
|
||||
vertexProject?: string
|
||||
vertexLocation?: string
|
||||
// GPT-5 specific parameters
|
||||
reasoningEffort?: string
|
||||
verbosity?: string
|
||||
}
|
||||
|
||||
// Map of provider IDs to their configurations
|
||||
export const providers: Record<string, ProviderConfig> = {}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
|
||||
import type { CompletionUsage } from 'openai/resources/completions'
|
||||
import { getEnv, isTruthy } from '@/lib/core/config/env'
|
||||
import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { createLogger, type Logger } from '@/lib/logs/console/logger'
|
||||
import { anthropicProvider } from '@/providers/anthropic'
|
||||
import { azureOpenAIProvider } from '@/providers/azure-openai'
|
||||
import { cerebrasProvider } from '@/providers/cerebras'
|
||||
@@ -40,9 +42,6 @@ import { useProvidersStore } from '@/stores/providers/store'
|
||||
|
||||
const logger = createLogger('ProviderUtils')
|
||||
|
||||
/**
|
||||
* Provider configurations - built from the comprehensive definitions
|
||||
*/
|
||||
export const providers: Record<
|
||||
ProviderId,
|
||||
ProviderConfig & {
|
||||
@@ -213,7 +212,6 @@ export function getProviderFromModel(model: string): ProviderId {
|
||||
}
|
||||
|
||||
export function getProvider(id: string): ProviderConfig | undefined {
|
||||
// Handle both formats: 'openai' and 'openai/chat'
|
||||
const providerId = id.split('/')[0] as ProviderId
|
||||
return providers[providerId]
|
||||
}
|
||||
@@ -273,14 +271,27 @@ export function filterBlacklistedModels(models: string[]): string[] {
|
||||
return models.filter((model) => !isModelBlacklisted(model))
|
||||
}
|
||||
|
||||
/**
|
||||
* Get provider icon for a given model
|
||||
*/
|
||||
export function getProviderIcon(model: string): React.ComponentType<{ className?: string }> | null {
|
||||
const providerId = getProviderFromModel(model)
|
||||
return PROVIDER_DEFINITIONS[providerId]?.icon || null
|
||||
}
|
||||
|
||||
/**
 * Generates prompt instructions for structured JSON output from a JSON schema.
 * Used as a fallback when native structured outputs are not supported.
 */
export function generateSchemaInstructions(schema: any, schemaName?: string): string {
  const name = schemaName || 'response'
  return `IMPORTANT: You must respond with a valid JSON object that conforms to the following schema.
Do not include any text before or after the JSON object. Only output the JSON.

Schema name: ${name}
JSON Schema:
${JSON.stringify(schema, null, 2)}

Your response must be valid JSON that exactly matches this schema structure.`
}

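A brief sketch of how this fallback might be attached, assuming a toy schema; the message handling below is illustrative, not part of this change:

const schema = {
  type: 'object',
  properties: { answer: { type: 'string' } },
  required: ['answer'],
}
const instructions = generateSchemaInstructions(schema, 'qa_response')
// Appended as an extra system message when the model lacks native json_schema support
messages.push({ role: 'system', content: instructions })
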
export function generateStructuredOutputInstructions(responseFormat: any): string {
|
||||
if (!responseFormat) return ''
|
||||
|
||||
@@ -479,7 +490,6 @@ export async function transformBlockTool(
|
||||
|
||||
const llmSchema = await createLLMToolSchema(toolConfig, userProvidedParams)
|
||||
|
||||
// Create unique tool ID by appending resource ID for multi-instance tools
|
||||
let uniqueToolId = toolConfig.id
|
||||
if (toolId === 'workflow_executor' && userProvidedParams.workflowId) {
|
||||
uniqueToolId = `${toolConfig.id}_${userProvidedParams.workflowId}`
|
||||
@@ -554,9 +564,6 @@ export function calculateCost(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get pricing information for a specific model (including embedding models)
|
||||
*/
|
||||
export function getModelPricing(modelId: string): any {
|
||||
const embeddingPricing = getEmbeddingModelPricing(modelId)
|
||||
if (embeddingPricing) {
|
||||
@@ -615,29 +622,25 @@ export function shouldBillModelUsage(model: string): boolean {
|
||||
* For use server-side only
|
||||
*/
|
||||
export function getApiKey(provider: string, model: string, userProvidedKey?: string): string {
|
||||
// If user provided a key, use it as a fallback
|
||||
const hasUserKey = !!userProvidedKey
|
||||
|
||||
// Ollama and vLLM models don't require API keys
|
||||
const isOllamaModel =
|
||||
provider === 'ollama' || useProvidersStore.getState().providers.ollama.models.includes(model)
|
||||
if (isOllamaModel) {
|
||||
return 'empty' // Ollama uses 'empty' as a placeholder API key
|
||||
return 'empty'
|
||||
}
|
||||
|
||||
const isVllmModel =
|
||||
provider === 'vllm' || useProvidersStore.getState().providers.vllm.models.includes(model)
|
||||
if (isVllmModel) {
|
||||
return userProvidedKey || 'empty' // vLLM uses 'empty' as a placeholder if no key provided
|
||||
return userProvidedKey || 'empty'
|
||||
}
|
||||
|
||||
// Use server key rotation for all OpenAI models, Anthropic's Claude models, and Google's Gemini models on the hosted platform
|
||||
const isOpenAIModel = provider === 'openai'
|
||||
const isClaudeModel = provider === 'anthropic'
|
||||
const isGeminiModel = provider === 'google'
|
||||
|
||||
if (isHosted && (isOpenAIModel || isClaudeModel || isGeminiModel)) {
|
||||
// Only use server key if model is explicitly in our hosted list
|
||||
const hostedModels = getHostedModels()
|
||||
const isModelHosted = hostedModels.some((m) => m.toLowerCase() === model.toLowerCase())
|
||||
|
||||
@@ -656,7 +659,6 @@ export function getApiKey(provider: string, model: string, userProvidedKey?: str
|
||||
}
|
||||
}
|
||||
|
||||
// For all other cases, require user-provided key
|
||||
if (!hasUserKey) {
|
||||
throw new Error(`API key is required for ${provider} ${model}`)
|
||||
}
|
||||
@@ -688,16 +690,14 @@ export function prepareToolsWithUsageControl(
|
||||
| { type: 'any'; any: { model: string; name: string } }
|
||||
| undefined
|
||||
toolConfig?: {
|
||||
// Add toolConfig for Google's format
|
||||
functionCallingConfig: {
|
||||
mode: 'AUTO' | 'ANY' | 'NONE'
|
||||
allowedFunctionNames?: string[]
|
||||
}
|
||||
}
|
||||
hasFilteredTools: boolean
|
||||
forcedTools: string[] // Return all forced tool IDs
|
||||
forcedTools: string[]
|
||||
} {
|
||||
// If no tools, return early
|
||||
if (!tools || tools.length === 0) {
|
||||
return {
|
||||
tools: undefined,
|
||||
@@ -707,14 +707,12 @@ export function prepareToolsWithUsageControl(
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out tools marked with usageControl='none'
|
||||
const filteredTools = tools.filter((tool) => {
|
||||
const toolId = tool.function?.name || tool.name
|
||||
const toolConfig = providerTools?.find((t) => t.id === toolId)
|
||||
return toolConfig?.usageControl !== 'none'
|
||||
})
|
||||
|
||||
// Check if any tools were filtered out
|
||||
const hasFilteredTools = filteredTools.length < tools.length
|
||||
if (hasFilteredTools) {
|
||||
logger.info(
|
||||
@@ -722,7 +720,6 @@ export function prepareToolsWithUsageControl(
|
||||
)
|
||||
}
|
||||
|
||||
// If all tools were filtered out, return empty
|
||||
if (filteredTools.length === 0) {
|
||||
logger.info('All tools were filtered out due to usageControl="none"')
|
||||
return {
|
||||
@@ -733,11 +730,9 @@ export function prepareToolsWithUsageControl(
|
||||
}
|
||||
}
|
||||
|
||||
// Get all tools that should be forced
|
||||
const forcedTools = providerTools?.filter((tool) => tool.usageControl === 'force') || []
|
||||
const forcedToolIds = forcedTools.map((tool) => tool.id)
|
||||
|
||||
// Determine tool_choice setting
|
||||
let toolChoice:
|
||||
| 'auto'
|
||||
| 'none'
|
||||
@@ -745,7 +740,6 @@ export function prepareToolsWithUsageControl(
|
||||
| { type: 'tool'; name: string }
|
||||
| { type: 'any'; any: { model: string; name: string } } = 'auto'
|
||||
|
||||
// For Google, we'll use a separate toolConfig object
|
||||
let toolConfig:
|
||||
| {
|
||||
functionCallingConfig: {
|
||||
@@ -756,30 +750,22 @@ export function prepareToolsWithUsageControl(
|
||||
| undefined
|
||||
|
||||
if (forcedTools.length > 0) {
|
||||
// Force the first tool that has usageControl='force'
|
||||
const forcedTool = forcedTools[0]
|
||||
|
||||
// Adjust format based on provider
|
||||
if (provider === 'anthropic') {
|
||||
toolChoice = {
|
||||
type: 'tool',
|
||||
name: forcedTool.id,
|
||||
}
|
||||
} else if (provider === 'google') {
|
||||
// Google Gemini format uses a separate toolConfig object
|
||||
toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: 'ANY',
|
||||
allowedFunctionNames:
|
||||
forcedTools.length === 1
|
||||
? [forcedTool.id] // If only one tool, specify just that one
|
||||
: forcedToolIds, // If multiple tools, include all of them
|
||||
allowedFunctionNames: forcedTools.length === 1 ? [forcedTool.id] : forcedToolIds,
|
||||
},
|
||||
}
|
||||
// Keep toolChoice as 'auto' since we use toolConfig instead
|
||||
toolChoice = 'auto'
|
||||
} else {
|
||||
// Default OpenAI format
|
||||
toolChoice = {
|
||||
type: 'function',
|
||||
function: { name: forcedTool.id },
|
||||
@@ -794,7 +780,6 @@ export function prepareToolsWithUsageControl(
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Default to auto if no forced tools
|
||||
toolChoice = 'auto'
|
||||
if (provider === 'google') {
|
||||
toolConfig = { functionCallingConfig: { mode: 'AUTO' } }
|
||||
@@ -845,7 +830,6 @@ export function trackForcedToolUsage(
|
||||
}
|
||||
}
|
||||
} {
|
||||
// Default to keeping the original tool_choice
|
||||
let hasUsedForcedTool = false
|
||||
let nextToolChoice = originalToolChoice
|
||||
let nextToolConfig:
|
||||
@@ -859,13 +843,10 @@ export function trackForcedToolUsage(
|
||||
|
||||
const updatedUsedForcedTools = [...usedForcedTools]
|
||||
|
||||
// Special handling for Google format
|
||||
const isGoogleFormat = provider === 'google'
|
||||
|
||||
// Get the name of the current forced tool(s)
|
||||
let forcedToolNames: string[] = []
|
||||
if (isGoogleFormat && originalToolChoice?.functionCallingConfig?.allowedFunctionNames) {
|
||||
// For Google format
|
||||
forcedToolNames = originalToolChoice.functionCallingConfig.allowedFunctionNames
|
||||
} else if (
|
||||
typeof originalToolChoice === 'object' &&
|
||||
@@ -873,7 +854,6 @@ export function trackForcedToolUsage(
|
||||
(originalToolChoice?.type === 'tool' && originalToolChoice?.name) ||
|
||||
(originalToolChoice?.type === 'any' && originalToolChoice?.any?.name))
|
||||
) {
|
||||
// For other providers
|
||||
forcedToolNames = [
|
||||
originalToolChoice?.function?.name ||
|
||||
originalToolChoice?.name ||
|
||||
@@ -881,27 +861,20 @@ export function trackForcedToolUsage(
|
||||
].filter(Boolean)
|
||||
}
|
||||
|
||||
// If we're forcing specific tools and we have tool calls in the response
|
||||
if (forcedToolNames.length > 0 && toolCallsResponse && toolCallsResponse.length > 0) {
|
||||
// Check if any of the tool calls used the forced tools
|
||||
const toolNames = toolCallsResponse.map((tc) => tc.function?.name || tc.name || tc.id)
|
||||
|
||||
// Find any forced tools that were used
|
||||
const usedTools = forcedToolNames.filter((toolName) => toolNames.includes(toolName))
|
||||
|
||||
if (usedTools.length > 0) {
|
||||
// At least one forced tool was used
|
||||
hasUsedForcedTool = true
|
||||
updatedUsedForcedTools.push(...usedTools)
|
||||
|
||||
// Find the next tools to force that haven't been used yet
|
||||
const remainingTools = forcedTools.filter((tool) => !updatedUsedForcedTools.includes(tool))
|
||||
|
||||
if (remainingTools.length > 0) {
|
||||
// There are still forced tools to use
|
||||
const nextToolToForce = remainingTools[0]
|
||||
|
||||
// Format based on provider
|
||||
if (provider === 'anthropic') {
|
||||
nextToolChoice = {
|
||||
type: 'tool',
|
||||
@@ -912,13 +885,10 @@ export function trackForcedToolUsage(
|
||||
functionCallingConfig: {
|
||||
mode: 'ANY',
|
||||
allowedFunctionNames:
|
||||
remainingTools.length === 1
|
||||
? [nextToolToForce] // If only one tool left, specify just that one
|
||||
: remainingTools, // If multiple tools, include all remaining
|
||||
remainingTools.length === 1 ? [nextToolToForce] : remainingTools,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Default OpenAI format
|
||||
nextToolChoice = {
|
||||
type: 'function',
|
||||
function: { name: nextToolToForce },
|
||||
@@ -929,9 +899,7 @@ export function trackForcedToolUsage(
|
||||
`Forced tool(s) ${usedTools.join(', ')} used, switching to next forced tool(s): ${remainingTools.join(', ')}`
|
||||
)
|
||||
} else {
|
||||
// All forced tools have been used, switch to auto mode
|
||||
if (provider === 'anthropic') {
|
||||
// Anthropic: return null to signal the parameter should be deleted/omitted
|
||||
nextToolChoice = null
|
||||
} else if (provider === 'google') {
|
||||
nextToolConfig = { functionCallingConfig: { mode: 'AUTO' } }
|
||||
@@ -963,9 +931,6 @@ export const MODELS_WITH_REASONING_EFFORT = getModelsWithReasoningEffort()
|
||||
export const MODELS_WITH_VERBOSITY = getModelsWithVerbosity()
|
||||
export const PROVIDERS_WITH_TOOL_USAGE_CONTROL = getProvidersWithToolUsageControl()
|
||||
|
||||
/**
|
||||
* Check if a model supports temperature parameter
|
||||
*/
|
||||
export function supportsTemperature(model: string): boolean {
|
||||
return supportsTemperatureFromDefinitions(model)
|
||||
}
|
||||
@@ -978,9 +943,6 @@ export function getMaxTemperature(model: string): number | undefined {
|
||||
return getMaxTempFromDefinitions(model)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a provider supports tool usage control
|
||||
*/
|
||||
export function supportsToolUsageControl(provider: string): boolean {
|
||||
return supportsToolUsageControlFromDefinitions(provider)
|
||||
}
|
||||
@@ -1021,8 +983,6 @@ export function prepareToolExecution(
  toolParams: Record<string, any>
  executionParams: Record<string, any>
} {
  // Filter out empty/null/undefined values from user params
  // so that cleared fields don't override LLM-generated values
  const filteredUserParams: Record<string, any> = {}
  if (tool.params) {
    for (const [key, value] of Object.entries(tool.params)) {
@@ -1032,13 +992,11 @@
    }
  }

  // User-provided params take precedence over LLM-generated params
  const toolParams = {
    ...llmArgs,
    ...filteredUserParams,
  }

  // Add system parameters for execution
  const executionParams = {
    ...toolParams,
    ...(request.workflowId
@@ -1055,9 +1013,107 @@
    ...(request.workflowVariables ? { workflowVariables: request.workflowVariables } : {}),
    ...(request.blockData ? { blockData: request.blockData } : {}),
    ...(request.blockNameMapping ? { blockNameMapping: request.blockNameMapping } : {}),
    // Pass tool schema for MCP tools to skip discovery
    ...(tool.parameters ? { _toolSchema: tool.parameters } : {}),
  }

  return { toolParams, executionParams }
}

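A small worked illustration of the precedence rule above; the tool id and values are hypothetical, not taken from this diff:

const llmArgs = { query: 'latest release', limit: 5 }
const tool = { params: { limit: '', channel: 'releases' } } // empty user value is filtered out
// merged toolParams: { query: 'latest release', limit: 5, channel: 'releases' }
// user-provided 'channel' wins; the cleared 'limit' falls back to the LLM-generated value
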
/**
 * Creates a ReadableStream from an OpenAI-compatible streaming response.
 * This is a shared utility used by all OpenAI-compatible providers:
 * OpenAI, Groq, DeepSeek, xAI, OpenRouter, Mistral, Ollama, vLLM, Azure OpenAI, Cerebras
 *
 * @param stream - The async iterable stream from the provider
 * @param providerName - Name of the provider for logging purposes
 * @param onComplete - Optional callback called when stream completes with full content and usage
 * @returns A ReadableStream that can be used for streaming responses
 */
export function createOpenAICompatibleStream(
  stream: AsyncIterable<ChatCompletionChunk>,
  providerName: string,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  const streamLogger = createLogger(`${providerName}Utils`)
  let fullContent = ''
  let promptTokens = 0
  let completionTokens = 0
  let totalTokens = 0

  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of stream) {
          if (chunk.usage) {
            promptTokens = chunk.usage.prompt_tokens ?? 0
            completionTokens = chunk.usage.completion_tokens ?? 0
            totalTokens = chunk.usage.total_tokens ?? 0
          }

          const content = chunk.choices?.[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }

        if (onComplete) {
          if (promptTokens === 0 && completionTokens === 0) {
            streamLogger.warn(`${providerName} stream completed without usage data`)
          }
          onComplete(fullContent, {
            prompt_tokens: promptTokens,
            completion_tokens: completionTokens,
            total_tokens: totalTokens || promptTokens + completionTokens,
          })
        }

        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
}

/**
 * Checks if a forced tool was used in an OpenAI-compatible response and updates tracking.
 * This is a shared utility used by OpenAI-compatible providers:
 * OpenAI, Groq, DeepSeek, xAI, OpenRouter, Mistral, Ollama, vLLM, Azure OpenAI, Cerebras
 *
 * @param response - The API response containing tool calls
 * @param toolChoice - The tool choice configuration (string or object)
 * @param providerName - Name of the provider for logging purposes
 * @param forcedTools - Array of forced tool names
 * @param usedForcedTools - Array of already used forced tools
 * @param customLogger - Optional custom logger instance
 * @returns Object with hasUsedForcedTool flag and updated usedForcedTools array
 */
export function checkForForcedToolUsageOpenAI(
  response: any,
  toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
  providerName: string,
  forcedTools: string[],
  usedForcedTools: string[],
  customLogger?: Logger
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
  const checkLogger = customLogger || createLogger(`${providerName}Utils`)
  let hasUsedForcedTool = false
  let updatedUsedForcedTools = [...usedForcedTools]

  if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
    const toolCallsResponse = response.choices[0].message.tool_calls
    const result = trackForcedToolUsage(
      toolCallsResponse,
      toolChoice,
      checkLogger,
      providerName.toLowerCase().replace(/\s+/g, '-'),
      forcedTools,
      updatedUsedForcedTools
    )
    hasUsedForcedTool = result.hasUsedForcedTool
    updatedUsedForcedTools = result.usedForcedTools
  }

  return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
}

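A minimal usage sketch for the shared streaming helper, assuming an OpenAI-SDK-compatible client; the client setup and model name are illustrative assumptions, not part of this change:

const client = new OpenAI({ apiKey, baseURL })
const sdkStream = await client.chat.completions.create({
  model: 'gpt-4o-mini', // hypothetical model id
  messages,
  stream: true,
  stream_options: { include_usage: true }, // required so usage arrives on the final chunk
})
const readable = createOpenAICompatibleStream(sdkStream, 'OpenAI', (content, usage) => {
  // usage is normalized even when the provider omits it (zeros plus a logged warning)
  logger.info('Stream finished', { characters: content.length, totalTokens: usage.total_tokens })
})
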
@@ -16,6 +16,7 @@ import type {
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
@@ -238,14 +239,21 @@ export const vertexProvider: ProviderConfig = {
|
||||
}
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.promptTokenCount || 0,
|
||||
completion: usage.candidatesTokenCount || 0,
|
||||
total:
|
||||
usage.totalTokenCount ||
|
||||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0),
|
||||
}
|
||||
const promptTokens = usage?.promptTokenCount || 0
|
||||
const completionTokens = usage?.candidatesTokenCount || 0
|
||||
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
|
||||
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: promptTokens,
|
||||
completion: completionTokens,
|
||||
total: totalTokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(request.model, promptTokens, completionTokens)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -541,8 +549,6 @@ export const vertexProvider: ProviderConfig = {
|
||||
|
||||
logger.info('No function call detected, proceeding with streaming response')
|
||||
|
||||
// Apply structured output for the final response if responseFormat is specified
|
||||
// This works regardless of whether tools were forced or auto
|
||||
if (request.responseFormat) {
|
||||
streamingPayload.tools = undefined
|
||||
streamingPayload.toolConfig = undefined
|
||||
@@ -654,21 +660,40 @@ export const vertexProvider: ProviderConfig = {
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
const existingTokens = streamingExecution.execution.output.tokens || {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
streamingExecution.execution.output.tokens = {
|
||||
prompt: (existingTokens.prompt || 0) + (usage.promptTokenCount || 0),
|
||||
completion:
|
||||
(existingTokens.completion || 0) + (usage.candidatesTokenCount || 0),
|
||||
total:
|
||||
(existingTokens.total || 0) +
|
||||
(usage.totalTokenCount ||
|
||||
(usage.promptTokenCount || 0) + (usage.candidatesTokenCount || 0)),
|
||||
}
|
||||
const promptTokens = usage?.promptTokenCount || 0
|
||||
const completionTokens = usage?.candidatesTokenCount || 0
|
||||
const totalTokens = usage?.totalTokenCount || promptTokens + completionTokens
|
||||
|
||||
const existingTokens = streamingExecution.execution.output.tokens || {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
|
||||
const existingPrompt = existingTokens.prompt || 0
|
||||
const existingCompletion = existingTokens.completion || 0
|
||||
const existingTotal = existingTokens.total || 0
|
||||
|
||||
streamingExecution.execution.output.tokens = {
|
||||
prompt: existingPrompt + promptTokens,
|
||||
completion: existingCompletion + completionTokens,
|
||||
total: existingTotal + totalTokens,
|
||||
}
|
||||
|
||||
const accumulatedCost = calculateCost(
|
||||
request.model,
|
||||
existingPrompt,
|
||||
existingCompletion
|
||||
)
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
promptTokens,
|
||||
completionTokens
|
||||
)
|
||||
streamingExecution.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -753,7 +778,6 @@ export const vertexProvider: ProviderConfig = {
|
||||
const nextFunctionCall = extractFunctionCall(nextCandidate)
|
||||
|
||||
if (!nextFunctionCall) {
|
||||
// If responseFormat is specified, make one final request with structured output
|
||||
if (request.responseFormat) {
|
||||
const finalPayload = {
|
||||
...payload,
|
||||
@@ -886,7 +910,7 @@ export const vertexProvider: ProviderConfig = {
|
||||
})
|
||||
|
||||
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
|
||||
// @ts-ignore - Adding timing property to the error
|
||||
// @ts-ignore
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
|
||||
@@ -89,9 +89,7 @@ export function createReadableStreamFromVertexStream(
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (arrayError) {
|
||||
// Buffer is not valid JSON array
|
||||
}
|
||||
} catch (arrayError) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import OpenAI from 'openai'
|
||||
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
@@ -11,6 +12,7 @@ import type {
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
trackForcedToolUsage,
|
||||
@@ -180,17 +182,12 @@ export const vllmProvider: ProviderConfig = {
|
||||
if (request.stream && (!tools || tools.length === 0 || !hasActiveTools)) {
|
||||
logger.info('Using streaming response for vLLM request')
|
||||
|
||||
const streamResponse = await vllm.chat.completions.create({
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
})
|
||||
|
||||
const tokenUsage = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
const streamResponse = await vllm.chat.completions.create(streamingParams)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromVLLMStream(streamResponse, (content, usage) => {
|
||||
@@ -200,6 +197,22 @@ export const vllmProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
streamingResult.execution.output.content = cleanContent
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.prompt_tokens,
|
||||
completion: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
|
||||
const streamEndTime = Date.now()
|
||||
const streamEndTimeISO = new Date(streamEndTime).toISOString()
|
||||
@@ -216,23 +229,13 @@ export const vllmProvider: ProviderConfig = {
|
||||
streamEndTime - providerStartTime
|
||||
}
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokenUsage.prompt,
|
||||
completion: usage.completion_tokens || tokenUsage.completion,
|
||||
total: usage.total_tokens || tokenUsage.total,
|
||||
}
|
||||
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '',
|
||||
model: request.model,
|
||||
tokens: tokenUsage,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
@@ -248,6 +251,7 @@ export const vllmProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
@@ -324,6 +328,13 @@ export const vllmProvider: ProviderConfig = {
|
||||
checkForForcedToolUsage(currentResponse, originalToolChoice)
|
||||
|
||||
while (iterationCount < MAX_TOOL_ITERATIONS) {
|
||||
if (currentResponse.choices[0]?.message?.content) {
|
||||
content = currentResponse.choices[0].message.content
|
||||
if (request.responseFormat) {
|
||||
content = content.replace(/```json\n?|\n?```/g, '').trim()
|
||||
}
|
||||
}
|
||||
|
||||
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
|
||||
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
|
||||
break
|
||||
@@ -335,77 +346,105 @@ export const vllmProvider: ProviderConfig = {
|
||||
|
||||
const toolsStartTime = Date.now()
|
||||
|
||||
for (const toolCall of toolCallsInResponse) {
|
||||
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
|
||||
const toolCallStartTime = Date.now()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
try {
|
||||
const toolName = toolCall.function.name
|
||||
const toolArgs = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
const tool = request.tools?.find((t) => t.id === toolName)
|
||||
if (!tool) continue
|
||||
|
||||
const toolCallStartTime = Date.now()
|
||||
if (!tool) return null
|
||||
|
||||
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
|
||||
const result = await executeTool(toolName, executionParams, true)
|
||||
const toolCallEndTime = Date.now()
|
||||
const toolCallDuration = toolCallEndTime - toolCallStartTime
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams,
|
||||
result,
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallDuration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(toolCallStartTime).toISOString(),
|
||||
endTime: new Date(toolCallEndTime).toISOString(),
|
||||
duration: toolCallDuration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: [
|
||||
{
|
||||
id: toolCall.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: toolName,
|
||||
arguments: toolCall.function.arguments,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error processing tool call:', {
|
||||
error,
|
||||
toolName: toolCall?.function?.name,
|
||||
})
|
||||
const toolCallEndTime = Date.now()
|
||||
logger.error('Error processing tool call:', { error, toolName })
|
||||
|
||||
return {
|
||||
toolCall,
|
||||
toolName,
|
||||
toolParams: {},
|
||||
result: {
|
||||
success: false,
|
||||
output: undefined,
|
||||
error: error instanceof Error ? error.message : 'Tool execution failed',
|
||||
},
|
||||
startTime: toolCallStartTime,
|
||||
endTime: toolCallEndTime,
|
||||
duration: toolCallEndTime - toolCallStartTime,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const executionResults = await Promise.allSettled(toolExecutionPromises)
|
||||
|
||||
currentMessages.push({
|
||||
role: 'assistant',
|
||||
content: null,
|
||||
tool_calls: toolCallsInResponse.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
},
|
||||
})),
|
||||
})
|
||||
|
||||
for (const settledResult of executionResults) {
|
||||
if (settledResult.status === 'rejected' || !settledResult.value) continue
|
||||
|
||||
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
|
||||
settledResult.value
|
||||
|
||||
timeSegments.push({
|
||||
type: 'tool',
|
||||
name: toolName,
|
||||
startTime: startTime,
|
||||
endTime: endTime,
|
||||
duration: duration,
|
||||
})
|
||||
|
||||
let resultContent: any
|
||||
if (result.success) {
|
||||
toolResults.push(result.output)
|
||||
resultContent = result.output
|
||||
} else {
|
||||
resultContent = {
|
||||
error: true,
|
||||
message: result.error || 'Tool execution failed',
|
||||
tool: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls.push({
|
||||
name: toolName,
|
||||
arguments: toolParams,
|
||||
startTime: new Date(startTime).toISOString(),
|
||||
endTime: new Date(endTime).toISOString(),
|
||||
duration: duration,
|
||||
result: resultContent,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
currentMessages.push({
|
||||
role: 'tool',
|
||||
tool_call_id: toolCall.id,
|
||||
content: JSON.stringify(resultContent),
|
||||
})
|
||||
}
|
||||
|
||||
const thisToolsTime = Date.now() - toolsStartTime
|
||||
@@ -469,15 +508,16 @@ export const vllmProvider: ProviderConfig = {
|
||||
if (request.stream) {
|
||||
logger.info('Using streaming for final response after tool processing')
|
||||
|
||||
const streamingPayload = {
|
||||
const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)
|
||||
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = {
|
||||
...payload,
|
||||
messages: currentMessages,
|
||||
tool_choice: 'auto',
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
|
||||
const streamResponse = await vllm.chat.completions.create(streamingPayload)
|
||||
const streamResponse = await vllm.chat.completions.create(streamingParams)
|
||||
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromVLLMStream(streamResponse, (content, usage) => {
|
||||
@@ -487,15 +527,21 @@ export const vllmProvider: ProviderConfig = {
|
||||
}
|
||||
|
||||
streamingResult.execution.output.content = cleanContent
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: tokens.prompt + usage.prompt_tokens,
|
||||
completion: tokens.completion + usage.completion_tokens,
|
||||
total: tokens.total + usage.total_tokens,
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
const newTokens = {
|
||||
prompt: usage.prompt_tokens || tokens.prompt,
|
||||
completion: usage.completion_tokens || tokens.completion,
|
||||
total: usage.total_tokens || tokens.total,
|
||||
}
|
||||
|
||||
streamingResult.execution.output.tokens = newTokens
|
||||
const streamCost = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: accumulatedCost.input + streamCost.input,
|
||||
output: accumulatedCost.output + streamCost.output,
|
||||
total: accumulatedCost.total + streamCost.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
@@ -525,6 +571,11 @@ export const vllmProvider: ProviderConfig = {
|
||||
iterations: iterationCount + 1,
|
||||
timeSegments: timeSegments,
|
||||
},
|
||||
cost: {
|
||||
input: accumulatedCost.input,
|
||||
output: accumulatedCost.output,
|
||||
total: accumulatedCost.total,
|
||||
},
|
||||
},
|
||||
logs: [],
|
||||
metadata: {
|
||||
@@ -585,7 +636,7 @@ export const vllmProvider: ProviderConfig = {
|
||||
})
|
||||
|
||||
const enhancedError = new Error(errorMessage)
|
||||
// @ts-ignore - Adding timing and vLLM error properties
|
||||
// @ts-ignore
|
||||
enhancedError.timing = {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: providerEndTimeISO,
|
||||
|
||||
@@ -1,37 +1,14 @@
|
||||
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { createOpenAICompatibleStream } from '@/providers/utils'

/**
 * Helper function to convert a vLLM stream to a standard ReadableStream
 * and collect completion metrics
 * Creates a ReadableStream from a vLLM streaming response.
 * Uses the shared OpenAI-compatible streaming utility.
 */
export function createReadableStreamFromVLLMStream(
  vllmStream: any,
  onComplete?: (content: string, usage?: any) => void
): ReadableStream {
  let fullContent = ''
  let usageData: any = null

  return new ReadableStream({
    async start(controller) {
      try {
        for await (const chunk of vllmStream) {
          if (chunk.usage) {
            usageData = chunk.usage
          }

          const content = chunk.choices[0]?.delta?.content || ''
          if (content) {
            fullContent += content
            controller.enqueue(new TextEncoder().encode(content))
          }
        }

        if (onComplete) {
          onComplete(fullContent, usageData)
        }

        controller.close()
      } catch (error) {
        controller.error(error)
      }
    },
  })
  vllmStream: AsyncIterable<ChatCompletionChunk>,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  return createOpenAICompatibleStream(vllmStream, 'vLLM', onComplete)
}

@@ -1,4 +1,5 @@
|
||||
import OpenAI from 'openai'
|
||||
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { MAX_TOOL_ITERATIONS } from '@/providers'
|
||||
@@ -9,7 +10,11 @@ import type {
|
||||
ProviderResponse,
|
||||
TimeSegment,
|
||||
} from '@/providers/types'
|
||||
import { prepareToolExecution, prepareToolsWithUsageControl } from '@/providers/utils'
|
||||
import {
|
||||
calculateCost,
|
||||
prepareToolExecution,
|
||||
prepareToolsWithUsageControl,
|
||||
} from '@/providers/utils'
|
||||
import {
|
||||
checkForForcedToolUsage,
|
||||
createReadableStreamFromXAIStream,
|
||||
@@ -66,8 +71,6 @@ export const xAIProvider: ProviderConfig = {
|
||||
if (request.messages) {
|
||||
allMessages.push(...request.messages)
|
||||
}
|
||||
|
||||
// Set up tools
|
||||
const tools = request.tools?.length
|
||||
? request.tools.map((tool) => ({
|
||||
type: 'function',
|
||||
@@ -78,15 +81,11 @@ export const xAIProvider: ProviderConfig = {
|
||||
},
|
||||
}))
|
||||
: undefined
|
||||
|
||||
// Log tools and response format conflict detection
|
||||
if (tools?.length && request.responseFormat) {
|
||||
logger.warn(
|
||||
'XAI Provider - Detected both tools and response format. Using tools first, then response format for final response.'
|
||||
)
|
||||
}
|
||||
|
||||
// Build the base request payload
|
||||
const basePayload: any = {
|
||||
model: request.model || 'grok-3-latest',
|
||||
messages: allMessages,
|
||||
@@ -94,52 +93,54 @@ export const xAIProvider: ProviderConfig = {
|
||||
|
||||
if (request.temperature !== undefined) basePayload.temperature = request.temperature
|
||||
if (request.maxTokens !== undefined) basePayload.max_tokens = request.maxTokens
|
||||
|
||||
// Handle tools and tool usage control
|
||||
let preparedTools: ReturnType<typeof prepareToolsWithUsageControl> | null = null
|
||||
|
||||
if (tools?.length) {
|
||||
preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'xai')
|
||||
}
|
||||
|
||||
// EARLY STREAMING: if caller requested streaming and there are no tools to execute,
|
||||
// we can directly stream the completion with response format if needed
|
||||
if (request.stream && (!tools || tools.length === 0)) {
|
||||
logger.info('XAI Provider - Using direct streaming (no tools)')
|
||||
|
||||
// Start execution timer for the entire provider execution
|
||||
const providerStartTime = Date.now()
|
||||
const providerStartTimeISO = new Date(providerStartTime).toISOString()
|
||||
|
||||
// Use response format payload if needed, otherwise use base payload
|
||||
const streamingPayload = request.responseFormat
|
||||
? createResponseFormatPayload(basePayload, allMessages, request.responseFormat)
|
||||
: { ...basePayload, stream: true }
|
||||
const streamingParams: ChatCompletionCreateParamsStreaming = request.responseFormat
|
||||
? {
|
||||
...createResponseFormatPayload(basePayload, allMessages, request.responseFormat),
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
: { ...basePayload, stream: true, stream_options: { include_usage: true } }
|
||||
|
||||
if (!request.responseFormat) {
|
||||
streamingPayload.stream = true
|
||||
} else {
|
||||
streamingPayload.stream = true
|
||||
}
|
||||
const streamResponse = await xai.chat.completions.create(streamingParams)
|
||||
|
||||
const streamResponse = await xai.chat.completions.create(streamingPayload)
|
||||
|
||||
// Start collecting token usage
|
||||
const tokenUsage = {
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
total: 0,
|
||||
}
|
||||
|
||||
// Create a StreamingExecution response with a readable stream
|
||||
const streamingResult = {
|
||||
stream: createReadableStreamFromXAIStream(streamResponse),
|
||||
stream: createReadableStreamFromXAIStream(streamResponse, (content, usage) => {
|
||||
streamingResult.execution.output.content = content
|
||||
streamingResult.execution.output.tokens = {
|
||||
prompt: usage.prompt_tokens,
|
||||
completion: usage.completion_tokens,
|
||||
total: usage.total_tokens,
|
||||
}
|
||||
|
||||
const costResult = calculateCost(
|
||||
request.model,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens
|
||||
)
|
||||
streamingResult.execution.output.cost = {
|
||||
input: costResult.input,
|
||||
output: costResult.output,
|
||||
total: costResult.total,
|
||||
}
|
||||
}),
|
||||
execution: {
|
||||
success: true,
|
||||
output: {
|
||||
content: '', // Will be filled by streaming content in chat component
|
||||
content: '',
|
||||
model: request.model || 'grok-3-latest',
|
||||
tokens: tokenUsage,
|
||||
tokens: { prompt: 0, completion: 0, total: 0 },
|
||||
toolCalls: undefined,
|
||||
providerTiming: {
|
||||
startTime: providerStartTimeISO,
|
||||
@@ -155,14 +156,9 @@ export const xAIProvider: ProviderConfig = {
|
||||
},
|
||||
],
|
||||
},
|
||||
// Estimate token cost
|
||||
cost: {
|
||||
total: 0.0,
|
||||
input: 0.0,
|
||||
output: 0.0,
|
||||
},
|
||||
cost: { input: 0, output: 0, total: 0 },
|
||||
},
|
||||
logs: [], // No block logs for direct streaming
|
||||
logs: [],
|
||||
metadata: {
|
||||
startTime: providerStartTimeISO,
|
||||
endTime: new Date().toISOString(),
|
||||
@@ -172,26 +168,18 @@ export const xAIProvider: ProviderConfig = {
|
||||
},
|
||||
}
|
||||
|
||||
// Return the streaming execution object
|
||||
return streamingResult as StreamingExecution
|
||||
}

// Start execution timer for the entire provider execution
const providerStartTime = Date.now()
const providerStartTimeISO = new Date(providerStartTime).toISOString()

try {
// Make the initial API request
const initialCallTime = Date.now()

// For the initial request with tools, we NEVER include response_format
// This is the key fix: tools and response_format cannot be used together with xAI
// xAI cannot use tools and response_format together in the same request
const initialPayload = { ...basePayload }

// Track the original tool_choice for forced tool tracking
let originalToolChoice: any

// Track forced tools and their usage
const forcedTools = preparedTools?.forcedTools || []
let usedForcedTools: string[] = []

@@ -201,7 +189,6 @@ export const xAIProvider: ProviderConfig = {
initialPayload.tool_choice = toolChoice
originalToolChoice = toolChoice
} else if (request.responseFormat) {
// Only add response format if there are no tools
const responseFormatPayload = createResponseFormatPayload(
basePayload,
allMessages,
@@ -224,14 +211,9 @@ export const xAIProvider: ProviderConfig = {
const currentMessages = [...allMessages]
let iterationCount = 0

// Track if a forced tool has been used
let hasUsedForcedTool = false

// Track time spent in model vs tools
let modelTime = firstResponseTime
let toolsTime = 0

// Track each model and tool call segment with timestamps
const timeSegments: TimeSegment[] = [
{
type: 'model',
@@ -241,8 +223,6 @@ export const xAIProvider: ProviderConfig = {
duration: firstResponseTime,
},
]

// Check if a forced tool was used in the first response
if (originalToolChoice) {
const result = checkForForcedToolUsage(
currentResponse,
@@ -256,119 +236,136 @@ export const xAIProvider: ProviderConfig = {

try {
while (iterationCount < MAX_TOOL_ITERATIONS) {
// Check for tool calls
if (currentResponse.choices[0]?.message?.content) {
content = currentResponse.choices[0].message.content
}

const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
break
}

// Track time for tool calls in this batch
const toolsStartTime = Date.now()
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
const toolCallStartTime = Date.now()
const toolName = toolCall.function.name

for (const toolCall of toolCallsInResponse) {
try {
const toolName = toolCall.function.name
const toolArgs = JSON.parse(toolCall.function.arguments)

const tool = request.tools?.find((t) => t.id === toolName)

if (!tool) {
logger.warn('XAI Provider - Tool not found:', { toolName })
continue
return null
}

const toolCallStartTime = Date.now()

const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)

const result = await executeTool(toolName, executionParams, true)
const toolCallEndTime = Date.now()
const toolCallDuration = toolCallEndTime - toolCallStartTime

// Add to time segments for both success and failure
timeSegments.push({
type: 'tool',
name: toolName,
return {
toolCall,
toolName,
toolParams,
result,
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallDuration,
})

// Prepare result content for the LLM
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
// Include error information so LLM can respond appropriately
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
tool: toolName,
}

logger.warn('XAI Provider - Tool execution failed:', {
toolName,
error: result.error,
})
duration: toolCallEndTime - toolCallStartTime,
}

toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(toolCallStartTime).toISOString(),
endTime: new Date(toolCallEndTime).toISOString(),
duration: toolCallDuration,
result: resultContent,
success: result.success,
})

// Add the tool call and result to messages (both success and failure)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: [
{
id: toolCall.id,
type: 'function',
function: {
name: toolName,
arguments: toolCall.function.arguments,
},
},
],
})

currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
} catch (error) {
const toolCallEndTime = Date.now()
logger.error('XAI Provider - Error processing tool call:', {
error: error instanceof Error ? error.message : String(error),
toolCall: toolCall.function.name,
})

return {
toolCall,
toolName,
toolParams: {},
result: {
success: false,
output: undefined,
error: error instanceof Error ? error.message : 'Tool execution failed',
},
startTime: toolCallStartTime,
endTime: toolCallEndTime,
duration: toolCallEndTime - toolCallStartTime,
}
}
})

const executionResults = await Promise.allSettled(toolExecutionPromises)
currentMessages.push({
role: 'assistant',
content: null,
tool_calls: toolCallsInResponse.map((tc) => ({
id: tc.id,
type: 'function',
function: {
name: tc.function.name,
arguments: tc.function.arguments,
},
})),
})

for (const settledResult of executionResults) {
if (settledResult.status === 'rejected' || !settledResult.value) continue

const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
settledResult.value

timeSegments.push({
type: 'tool',
name: toolName,
startTime: startTime,
endTime: endTime,
duration: duration,
})
let resultContent: any
if (result.success) {
toolResults.push(result.output)
resultContent = result.output
} else {
resultContent = {
error: true,
message: result.error || 'Tool execution failed',
tool: toolName,
}
logger.warn('XAI Provider - Tool execution failed:', {
toolName,
error: result.error,
})
}

toolCalls.push({
name: toolName,
arguments: toolParams,
startTime: new Date(startTime).toISOString(),
endTime: new Date(endTime).toISOString(),
duration: duration,
result: resultContent,
success: result.success,
})
currentMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: JSON.stringify(resultContent),
})
}

// Calculate tool call time for this iteration
const thisToolsTime = Date.now() - toolsStartTime
toolsTime += thisToolsTime

// After tool calls, create next payload based on whether we need more tools or final response
let nextPayload: any

// Update tool_choice based on which forced tools have been used
if (
typeof originalToolChoice === 'object' &&
hasUsedForcedTool &&
forcedTools.length > 0
) {
// If we have remaining forced tools, get the next one to force
const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool))

if (remainingTools.length > 0) {
// Force the next tool - continue with tools, no response format
nextPayload = {
...basePayload,
messages: currentMessages,
@@ -379,7 +376,6 @@ export const xAIProvider: ProviderConfig = {
},
}
} else {
// All forced tools have been used, check if we need response format for final response
if (request.responseFormat) {
nextPayload = createResponseFormatPayload(
basePayload,
@@ -397,9 +393,7 @@ export const xAIProvider: ProviderConfig = {
}
}
} else {
// Normal tool processing - check if this might be the final response
if (request.responseFormat) {
// Use response format for what might be the final response
nextPayload = createResponseFormatPayload(
basePayload,
allMessages,
@@ -416,12 +410,9 @@ export const xAIProvider: ProviderConfig = {
}
}

// Time the next model call
const nextModelStartTime = Date.now()

currentResponse = await xai.chat.completions.create(nextPayload)

// Check if any forced tools were used in this response
if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') {
const result = checkForForcedToolUsage(
currentResponse,
@@ -435,8 +426,6 @@ export const xAIProvider: ProviderConfig = {

const nextModelEndTime = Date.now()
const thisModelTime = nextModelEndTime - nextModelStartTime

// Add to time segments
timeSegments.push({
type: 'model',
name: `Model response (iteration ${iterationCount + 1})`,
@@ -445,7 +434,6 @@ export const xAIProvider: ProviderConfig = {
duration: thisModelTime,
})

// Add to model time
modelTime += thisModelTime

if (currentResponse.choices[0]?.message?.content) {
@@ -466,14 +454,10 @@ export const xAIProvider: ProviderConfig = {
iterationCount,
})
}

// After all tool processing complete, if streaming was requested, use streaming for the final response
if (request.stream) {
// For final streaming response, choose between tools (auto) or response_format (never both)
let finalStreamingPayload: any

if (request.responseFormat) {
// Use response format, no tools
finalStreamingPayload = {
...createResponseFormatPayload(
basePayload,
@@ -484,7 +468,6 @@ export const xAIProvider: ProviderConfig = {
stream: true,
}
} else {
// Use tools with auto choice
finalStreamingPayload = {
...basePayload,
messages: currentMessages,
@@ -494,15 +477,34 @@ export const xAIProvider: ProviderConfig = {
}
}

const streamResponse = await xai.chat.completions.create(finalStreamingPayload)
const streamResponse = await xai.chat.completions.create(finalStreamingPayload as any)

const accumulatedCost = calculateCost(request.model, tokens.prompt, tokens.completion)

// Create a StreamingExecution response with all collected data
const streamingResult = {
stream: createReadableStreamFromXAIStream(streamResponse),
stream: createReadableStreamFromXAIStream(streamResponse as any, (content, usage) => {
streamingResult.execution.output.content = content
streamingResult.execution.output.tokens = {
prompt: tokens.prompt + usage.prompt_tokens,
completion: tokens.completion + usage.completion_tokens,
total: tokens.total + usage.total_tokens,
}

const streamCost = calculateCost(
request.model,
usage.prompt_tokens,
usage.completion_tokens
)
streamingResult.execution.output.cost = {
input: accumulatedCost.input + streamCost.input,
output: accumulatedCost.output + streamCost.output,
total: accumulatedCost.total + streamCost.total,
}
}),
execution: {
success: true,
output: {
content: '', // Will be filled by the callback
content: '',
model: request.model || 'grok-3-latest',
tokens: {
prompt: tokens.prompt,
@@ -527,12 +529,12 @@ export const xAIProvider: ProviderConfig = {
timeSegments: timeSegments,
},
cost: {
total: (tokens.total || 0) * 0.0001,
input: (tokens.prompt || 0) * 0.0001,
output: (tokens.completion || 0) * 0.0001,
input: accumulatedCost.input,
output: accumulatedCost.output,
total: accumulatedCost.total,
},
},
logs: [], // No block logs at provider level
logs: [],
metadata: {
startTime: providerStartTimeISO,
endTime: new Date().toISOString(),
@@ -542,11 +544,8 @@ export const xAIProvider: ProviderConfig = {
},
}

// Return the streaming execution object
return streamingResult as StreamingExecution
}
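The cost bookkeeping above is plain addition of two calculateCost results — one for the tokens consumed during the tool-call iterations, one for the final stream. A sketch with invented numbers, assuming calculateCost(model, promptTokens, completionTokens) returns { input, output, total } as used above:

```ts
// Illustrative values only; calculateCost is the helper referenced in the provider code above.
const accumulated = calculateCost('grok-3-latest', 1200, 300) // tokens from the tool-call iterations
const streamed = calculateCost('grok-3-latest', 400, 250) // tokens from the final streamed response
const totalCost = {
  input: accumulated.input + streamed.input,
  output: accumulated.output + streamed.output,
  total: accumulated.total + streamed.total,
}
```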

// Calculate overall timing
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -577,7 +576,6 @@ export const xAIProvider: ProviderConfig = {
},
}
} catch (error) {
// Include timing information even for errors
const providerEndTime = Date.now()
const providerEndTimeISO = new Date(providerEndTime).toISOString()
const totalDuration = providerEndTime - providerStartTime
@@ -589,9 +587,8 @@ export const xAIProvider: ProviderConfig = {
hasResponseFormat: !!request.responseFormat,
})

// Create a new error with timing information
const enhancedError = new Error(error instanceof Error ? error.message : String(error))
// @ts-ignore - Adding timing property to the error
// @ts-ignore - Adding timing property to error for debugging
enhancedError.timing = {
startTime: providerStartTimeISO,
endTime: providerEndTimeISO,

@@ -1,32 +1,20 @@
import { createLogger } from '@/lib/logs/console/logger'
import { trackForcedToolUsage } from '@/providers/utils'

const logger = createLogger('XAIProvider')
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'
import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils'

/**
* Helper to wrap XAI (OpenAI-compatible) streaming into a browser-friendly
* ReadableStream of raw assistant text chunks.
* Creates a ReadableStream from an xAI streaming response.
* Uses the shared OpenAI-compatible streaming utility.
*/
export function createReadableStreamFromXAIStream(xaiStream: any): ReadableStream {
return new ReadableStream({
async start(controller) {
try {
for await (const chunk of xaiStream) {
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
controller.enqueue(new TextEncoder().encode(content))
}
}
controller.close()
} catch (err) {
controller.error(err)
}
},
})
export function createReadableStreamFromXAIStream(
xaiStream: AsyncIterable<ChatCompletionChunk>,
onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
return createOpenAICompatibleStream(xaiStream, 'xAI', onComplete)
}
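The shared utility itself is not part of this diff. Judging from the inline implementation it replaces, createOpenAICompatibleStream presumably enqueues each delta as UTF-8 bytes, accumulates the full text, and hands the text plus usage to onComplete when the stream ends. A sketch under that assumption, not the actual code in '@/providers/utils':

```ts
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions'

// Approximate shape only — the real helper may differ (providerName is presumably used for logging).
function createOpenAICompatibleStreamSketch(
  stream: AsyncIterable<ChatCompletionChunk>,
  providerName: string,
  onComplete?: (content: string, usage: CompletionUsage) => void
): ReadableStream<Uint8Array> {
  return new ReadableStream({
    async start(controller) {
      let content = ''
      let usage: CompletionUsage | undefined
      try {
        for await (const chunk of stream) {
          const delta = chunk.choices[0]?.delta?.content || ''
          if (delta) {
            content += delta
            controller.enqueue(new TextEncoder().encode(delta))
          }
          if (chunk.usage) usage = chunk.usage // the final chunk carries usage when requested
        }
        onComplete?.(content, usage ?? { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 })
        controller.close()
      } catch (err) {
        controller.error(err)
      }
    },
  })
}
```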

/**
* Creates a response format payload for XAI API requests.
* Creates a response format payload for xAI requests with JSON schema.
*/
export function createResponseFormatPayload(
basePayload: any,
@@ -54,7 +42,8 @@ export function createResponseFormatPayload(
}

/**
* Helper function to check for forced tool usage in responses.
* Checks if a forced tool was used in an xAI response.
* Uses the shared OpenAI-compatible forced tool usage helper.
*/
export function checkForForcedToolUsage(
response: any,
@@ -62,22 +51,5 @@ export function checkForForcedToolUsage(
forcedTools: string[],
usedForcedTools: string[]
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
let hasUsedForcedTool = false
let updatedUsedForcedTools = usedForcedTools

if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
const toolCallsResponse = response.choices[0].message.tool_calls
const result = trackForcedToolUsage(
toolCallsResponse,
toolChoice,
logger,
'xai',
forcedTools,
updatedUsedForcedTools
)
hasUsedForcedTool = result.hasUsedForcedTool
updatedUsedForcedTools = result.usedForcedTools
}

return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
return checkForForcedToolUsageOpenAI(response, toolChoice, 'xAI', forcedTools, usedForcedTools)
}
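For context, this wrapper is driven from the provider's tool loop roughly as follows; a fragment only, reusing the variable names from the provider code earlier in this diff:

```ts
// Sketch only; currentResponse, toolChoice, forcedTools, and usedForcedTools
// come from the surrounding provider code shown above.
const tracking = checkForForcedToolUsage(currentResponse, toolChoice, forcedTools, usedForcedTools)
if (tracking.hasUsedForcedTool) {
  usedForcedTools = tracking.usedForcedTools
  // Tools still owed a forced call decide the next tool_choice; when none remain,
  // the provider falls back to 'auto' or a response_format payload.
  const remainingTools = forcedTools.filter((t) => !usedForcedTools.includes(t))
  console.debug('remaining forced tools', remainingTools)
}
```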

@@ -1,6 +1,6 @@
import { create } from 'zustand'
import { createLogger } from '@/lib/logs/console/logger'
import type { ProvidersStore } from '@/stores/providers/types'
import type { OpenRouterModelInfo, ProvidersStore } from '@/stores/providers/types'

const logger = createLogger('ProvidersStore')

@@ -11,6 +11,7 @@ export const useProvidersStore = create<ProvidersStore>((set, get) => ({
vllm: { models: [], isLoading: false },
openrouter: { models: [], isLoading: false },
},
openRouterModelInfo: {},

setProviderModels: (provider, models) => {
logger.info(`Updated ${provider} models`, { count: models.length })
@@ -37,7 +38,22 @@ export const useProvidersStore = create<ProvidersStore>((set, get) => ({
}))
},

setOpenRouterModelInfo: (modelInfo: Record<string, OpenRouterModelInfo>) => {
const structuredOutputCount = Object.values(modelInfo).filter(
(m) => m.supportsStructuredOutputs
).length
logger.info('Updated OpenRouter model info', {
count: Object.keys(modelInfo).length,
withStructuredOutputs: structuredOutputCount,
})
set({ openRouterModelInfo: modelInfo })
},

getProvider: (provider) => {
return get().providers[provider]
},

getOpenRouterModelInfo: (modelId: string) => {
return get().openRouterModelInfo[modelId]
},
}))
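As a quick illustration (not part of the change), the new store fields might be used like this; the fetch URL is an assumption, and only the store methods come from the code above:

```ts
// Assumes useProvidersStore and OpenRouterModelInfo are imported from the modules shown above.
async function refreshOpenRouterModelInfo() {
  const res = await fetch('/api/providers/openrouter/models') // hypothetical endpoint
  const info: Record<string, OpenRouterModelInfo> = await res.json()
  useProvidersStore.getState().setOpenRouterModelInfo(info)
}

// Later, e.g. when deciding whether to send a JSON-schema response format:
const supportsStructured = useProvidersStore
  .getState()
  .getOpenRouterModelInfo('openai/gpt-4o')?.supportsStructuredOutputs
```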

@@ -1,5 +1,16 @@
export type ProviderName = 'ollama' | 'vllm' | 'openrouter' | 'base'

export interface OpenRouterModelInfo {
id: string
contextLength?: number
supportsStructuredOutputs?: boolean
supportsTools?: boolean
pricing?: {
input: number
output: number
}
}

export interface ProviderState {
models: string[]
isLoading: boolean
@@ -7,7 +18,10 @@ export interface ProviderState {

export interface ProvidersStore {
providers: Record<ProviderName, ProviderState>
openRouterModelInfo: Record<string, OpenRouterModelInfo>
setProviderModels: (provider: ProviderName, models: string[]) => void
setProviderLoading: (provider: ProviderName, isLoading: boolean) => void
setOpenRouterModelInfo: (modelInfo: Record<string, OpenRouterModelInfo>) => void
getProvider: (provider: ProviderName) => ProviderState
getOpenRouterModelInfo: (modelId: string) => OpenRouterModelInfo | undefined
}
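A concrete OpenRouterModelInfo value, purely to show the shape — the model id, numbers, and pricing units are made up:

```ts
const exampleModelInfo: OpenRouterModelInfo = {
  id: 'anthropic/claude-3.5-sonnet',
  contextLength: 200000,
  supportsStructuredOutputs: true,
  supportsTools: true,
  pricing: { input: 3, output: 15 }, // whatever unit the rest of the codebase expects
}
```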

@@ -13,7 +13,7 @@ export const memoryAddTool: ToolConfig<any, MemoryResponse> = {
type: 'string',
required: false,
description:
'Conversation identifier (e.g., user-123, session-abc). If a memory with this conversationId already exists for this block, the new message will be appended to it.',
'Conversation identifier (e.g., user-123, session-abc). If a memory with this conversationId already exists, the new message will be appended to it.',
},
id: {
type: 'string',
@@ -31,12 +31,6 @@ export const memoryAddTool: ToolConfig<any, MemoryResponse> = {
required: true,
description: 'Content for agent memory',
},
blockId: {
type: 'string',
required: false,
description:
'Optional block ID. If not provided, uses the current block ID from execution context, or defaults to "default".',
},
},

request: {
@@ -46,29 +40,24 @@ export const memoryAddTool: ToolConfig<any, MemoryResponse> = {
'Content-Type': 'application/json',
}),
body: (params) => {
const workflowId = params._context?.workflowId
const contextBlockId = params._context?.blockId
const workspaceId = params._context?.workspaceId

if (!workflowId) {
if (!workspaceId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message: 'workflowId is required and must be provided in execution context',
message: 'workspaceId is required and must be provided in execution context',
},
},
},
}
}

// Use 'id' as fallback for 'conversationId' for backwards compatibility
const conversationId = params.conversationId || params.id

// Default blockId to 'default' if not provided in params or context
const blockId = params.blockId || contextBlockId || 'default'

if (!conversationId || conversationId.trim() === '') {
return {
_errorResponse: {
@@ -97,11 +86,11 @@ export const memoryAddTool: ToolConfig<any, MemoryResponse> = {
}
}

const key = buildMemoryKey(conversationId, blockId)
const key = buildMemoryKey(conversationId)

const body: Record<string, any> = {
key,
workflowId,
workspaceId,
data: {
role: params.role,
content: params.content,
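Pieced together from the body builder above, a memory_add call now POSTs roughly this payload to /api/memory; the identifiers are placeholders:

```ts
const exampleAddBody = {
  key: 'user-123', // buildMemoryKey(conversationId) is now just the conversationId
  workspaceId: 'ws_abc', // taken from params._context.workspaceId
  data: {
    role: 'user',
    content: 'What did we decide yesterday?',
  },
}
```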

@@ -4,8 +4,7 @@ import type { ToolConfig } from '@/tools/types'
export const memoryDeleteTool: ToolConfig<any, MemoryResponse> = {
id: 'memory_delete',
name: 'Delete Memory',
description:
'Delete memories by conversationId, blockId, blockName, or a combination. Supports bulk deletion.',
description: 'Delete memories by conversationId.',
version: '1.0.0',

params: {
@@ -13,7 +12,7 @@ export const memoryDeleteTool: ToolConfig<any, MemoryResponse> = {
type: 'string',
required: false,
description:
'Conversation identifier (e.g., user-123, session-abc). If provided alone, deletes all memories for this conversation across all blocks.',
'Conversation identifier (e.g., user-123, session-abc). Deletes all memories for this conversation.',
},
id: {
type: 'string',
@@ -21,50 +20,36 @@ export const memoryDeleteTool: ToolConfig<any, MemoryResponse> = {
description:
'Legacy parameter for conversation identifier. Use conversationId instead. Provided for backwards compatibility.',
},
blockId: {
type: 'string',
required: false,
description:
'Block identifier. If provided alone, deletes all memories for this block across all conversations. If provided with conversationId, deletes memories for that specific conversation in this block.',
},
blockName: {
type: 'string',
required: false,
description:
'Block name. Alternative to blockId. If provided alone, deletes all memories for blocks with this name. If provided with conversationId, deletes memories for that conversation in blocks with this name.',
},
},

request: {
url: (params): any => {
const workflowId = params._context?.workflowId
const workspaceId = params._context?.workspaceId

if (!workflowId) {
if (!workspaceId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message: 'workflowId is required and must be provided in execution context',
message: 'workspaceId is required and must be provided in execution context',
},
},
},
}
}

// Use 'id' as fallback for 'conversationId' for backwards compatibility
const conversationId = params.conversationId || params.id

if (!conversationId && !params.blockId && !params.blockName) {
if (!conversationId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message:
'At least one of conversationId, id, blockId, or blockName must be provided',
message: 'conversationId or id must be provided',
},
},
},
@@ -72,17 +57,8 @@ export const memoryDeleteTool: ToolConfig<any, MemoryResponse> = {
}

const url = new URL('/api/memory', 'http://dummy')
url.searchParams.set('workflowId', workflowId)

if (conversationId) {
url.searchParams.set('conversationId', conversationId)
}
if (params.blockId) {
url.searchParams.set('blockId', params.blockId)
}
if (params.blockName) {
url.searchParams.set('blockName', params.blockName)
}
url.searchParams.set('workspaceId', workspaceId)
url.searchParams.set('conversationId', conversationId)

return url.pathname + url.search
},
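With the block parameters gone, the delete URL builder above reduces to a workspace plus conversation query; with placeholder values it produces something like this:

```ts
// Yields '/api/memory?workspaceId=ws_abc&conversationId=user-123'
const url = new URL('/api/memory', 'http://dummy')
url.searchParams.set('workspaceId', 'ws_abc')
url.searchParams.set('conversationId', 'user-123')
console.log(url.pathname + url.search)
```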

@@ -5,8 +5,7 @@ import type { ToolConfig } from '@/tools/types'
export const memoryGetTool: ToolConfig<any, MemoryResponse> = {
id: 'memory_get',
name: 'Get Memory',
description:
'Retrieve memory by conversationId, blockId, blockName, or a combination. Returns all matching memories.',
description: 'Retrieve memory by conversationId. Returns matching memories.',
version: '1.0.0',

params: {
@@ -14,7 +13,7 @@ export const memoryGetTool: ToolConfig<any, MemoryResponse> = {
type: 'string',
required: false,
description:
'Conversation identifier (e.g., user-123, session-abc). If provided alone, returns all memories for this conversation across all blocks.',
'Conversation identifier (e.g., user-123, session-abc). Returns memories for this conversation.',
},
id: {
type: 'string',
@@ -22,75 +21,47 @@ export const memoryGetTool: ToolConfig<any, MemoryResponse> = {
description:
'Legacy parameter for conversation identifier. Use conversationId instead. Provided for backwards compatibility.',
},
blockId: {
type: 'string',
required: false,
description:
'Block identifier. If provided alone, returns all memories for this block across all conversations. If provided with conversationId, returns memories for that specific conversation in this block.',
},
blockName: {
type: 'string',
required: false,
description:
'Block name. Alternative to blockId. If provided alone, returns all memories for blocks with this name. If provided with conversationId, returns memories for that conversation in blocks with this name.',
},
},

request: {
url: (params): any => {
const workflowId = params._context?.workflowId
const workspaceId = params._context?.workspaceId

if (!workflowId) {
if (!workspaceId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message: 'workflowId is required and must be provided in execution context',
message: 'workspaceId is required and must be provided in execution context',
},
},
},
}
}

// Use 'id' as fallback for 'conversationId' for backwards compatibility
const conversationId = params.conversationId || params.id

if (!conversationId && !params.blockId && !params.blockName) {
if (!conversationId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message:
'At least one of conversationId, id, blockId, or blockName must be provided',
message: 'conversationId or id must be provided',
},
},
},
}
}

let query = ''

if (conversationId && params.blockId) {
query = buildMemoryKey(conversationId, params.blockId)
} else if (conversationId) {
// Also check for legacy format (conversationId without blockId)
query = `${conversationId}:`
} else if (params.blockId) {
query = `:${params.blockId}`
}
const query = buildMemoryKey(conversationId)

const url = new URL('/api/memory', 'http://dummy')
url.searchParams.set('workflowId', workflowId)
if (query) {
url.searchParams.set('query', query)
}
if (params.blockName) {
url.searchParams.set('blockName', params.blockName)
}
url.searchParams.set('workspaceId', workspaceId)
url.searchParams.set('query', query)
url.searchParams.set('limit', '1000')

return url.pathname + url.search
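Likewise for memory_get, the simplified builder queries by conversationId only (placeholder values again):

```ts
// Yields '/api/memory?workspaceId=ws_abc&query=user-123&limit=1000'
const url = new URL('/api/memory', 'http://dummy')
url.searchParams.set('workspaceId', 'ws_abc')
url.searchParams.set('query', 'user-123') // buildMemoryKey(conversationId) === conversationId
url.searchParams.set('limit', '1000')
console.log(url.pathname + url.search)
```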
@@ -128,8 +99,7 @@ export const memoryGetTool: ToolConfig<any, MemoryResponse> = {
success: { type: 'boolean', description: 'Whether the memory was retrieved successfully' },
memories: {
type: 'array',
description:
'Array of memory objects with conversationId, blockId, blockName, and data fields',
description: 'Array of memory objects with conversationId and data fields',
},
message: { type: 'string', description: 'Success or error message' },
error: { type: 'string', description: 'Error message if operation failed' },

@@ -11,23 +11,23 @@ export const memoryGetAllTool: ToolConfig<any, MemoryResponse> = {

request: {
url: (params): any => {
const workflowId = params._context?.workflowId
const workspaceId = params._context?.workspaceId

if (!workflowId) {
if (!workspaceId) {
return {
_errorResponse: {
status: 400,
data: {
success: false,
error: {
message: 'workflowId is required and must be provided in execution context',
message: 'workspaceId is required and must be provided in execution context',
},
},
},
}
}

return `/api/memory?workflowId=${encodeURIComponent(workflowId)}`
return `/api/memory?workspaceId=${encodeURIComponent(workspaceId)}`
},
method: 'GET',
headers: () => ({
@@ -64,8 +64,7 @@ export const memoryGetAllTool: ToolConfig<any, MemoryResponse> = {
success: { type: 'boolean', description: 'Whether all memories were retrieved successfully' },
memories: {
type: 'array',
description:
'Array of all memory objects with key, conversationId, blockId, blockName, and data fields',
description: 'Array of all memory objects with key, conversationId, and data fields',
},
message: { type: 'string', description: 'Success or error message' },
error: { type: 'string', description: 'Error message if operation failed' },

@@ -1,45 +1,25 @@
/**
* Parse memory key into conversationId and blockId
* Supports two formats:
* - New format: conversationId:blockId (splits on LAST colon to handle IDs with colons)
* - Legacy format: id (without colon, treated as conversationId with blockId='default')
* @param key The memory key to parse
* @returns Object with conversationId and blockId, or null if invalid
* Parse memory key to extract conversationId
* Memory is now thread-scoped, so the key is just the conversationId
* @param key The memory key (conversationId)
* @returns Object with conversationId, or null if invalid
*/
export function parseMemoryKey(key: string): { conversationId: string; blockId: string } | null {
export function parseMemoryKey(key: string): { conversationId: string } | null {
if (!key) {
return null
}

const lastColonIndex = key.lastIndexOf(':')

// Legacy format: no colon found
if (lastColonIndex === -1) {
return {
conversationId: key,
blockId: 'default',
}
}

// Invalid: colon at start or end
if (lastColonIndex === 0 || lastColonIndex === key.length - 1) {
return null
}

// New format: split on last colon to handle IDs with colons
// This allows conversationIds like "user:123" to work correctly
return {
conversationId: key.substring(0, lastColonIndex),
blockId: key.substring(lastColonIndex + 1),
conversationId: key,
}
}

/**
* Build memory key from conversationId and blockId
* Build memory key from conversationId
* Memory is thread-scoped, so key is just the conversationId
* @param conversationId The conversation ID
* @param blockId The block ID
* @returns The memory key in format conversationId:blockId
* @returns The memory key (same as conversationId)
*/
export function buildMemoryKey(conversationId: string, blockId: string): string {
return `${conversationId}:${blockId}`
export function buildMemoryKey(conversationId: string): string {
return conversationId
}
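In short, the key format collapses from conversationId:blockId to the bare conversationId. A before/after sketch based on the two versions shown above (the blockId value is illustrative):

```ts
// Old scheme (removed): buildMemoryKey('user-123', 'agent-block') === 'user-123:agent-block'
// New scheme:           buildMemoryKey('user-123') === 'user-123'
parseMemoryKey('user-123') // => { conversationId: 'user-123' }
parseMemoryKey('') // => null
```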

@@ -16,8 +16,6 @@ export interface MemoryRecord {
id: string
key: string
conversationId: string
blockId: string
blockName: string
data: AgentMemoryData[]
createdAt: string
updatedAt: string