feat(kb): added cost for kb blocks (#654)

* added cost to kb upload + search

* small fix

* ack PR comments
This commit is contained in:
Waleed Latif
2025-07-10 13:53:20 -07:00
committed by GitHub
parent 614d826217
commit 3887733da5
12 changed files with 957 additions and 22 deletions

View File

@@ -0,0 +1,413 @@
/**
* Tests for knowledge document chunks API route
*
* @vitest-environment node
*/
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockRequest,
mockAuth,
mockConsoleLogger,
mockDrizzleOrm,
mockKnowledgeSchemas,
} from '@/app/api/__test-utils__/utils'
import type { DocumentAccessCheck } from '../../../../utils'
// Shared test-utility mocks, set up before any test imports the route module.
// NOTE(review): presumably these register mocks for the DB schema, drizzle-orm and the
// console logger — confirm against '@/app/api/__test-utils__/utils'.
mockKnowledgeSchemas()
mockDrizzleOrm()
mockConsoleLogger()
// Token estimator stub: every call reports a fixed count of 452.
vi.mock('@/lib/tokenization/estimators', () => ({
estimateTokenCount: vi.fn().mockReturnValue({ count: 452 }),
}))
// Cost calculator stub: returns fixed numbers so the tests can compare exact values.
vi.mock('@/providers/utils', () => ({
calculateCost: vi.fn().mockReturnValue({
input: 0.00000904,
output: 0,
total: 0.00000904,
pricing: {
input: 0.02,
output: 0,
updatedAt: '2025-07-10',
},
}),
}))
// Knowledge utils stub: checkDocumentAccess is programmed per test, and
// generateEmbeddings resolves a single five-element vector.
vi.mock('../../../../utils', () => ({
checkDocumentAccess: vi.fn(),
generateEmbeddings: vi.fn().mockResolvedValue([[0.1, 0.2, 0.3, 0.4, 0.5]]),
}))
describe('Knowledge Document Chunks API Route', () => {
// NOTE(review): the returned value is not read anywhere in this file; mockAuth() is
// presumably called for the mocks it registers — confirm.
const mockAuth$ = mockAuth()
// Stand-in for the `db` export of '@/db' (wired up in beforeEach).
// Builder methods return `this` so calls can be chained; `values` and `returning`
// resolve instead, and `transaction` is given an implementation by the tests that need it.
const mockDbChain = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
limit: vi.fn().mockReturnThis(),
offset: vi.fn().mockReturnThis(),
insert: vi.fn().mockReturnThis(),
values: vi.fn().mockResolvedValue(undefined),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
returning: vi.fn().mockResolvedValue([]),
delete: vi.fn().mockReturnThis(),
transaction: vi.fn(),
}
// Stub supplied as `getUserId` from '@/app/api/auth/oauth/utils'; each test sets what it resolves to.
const mockGetUserId = vi.fn()
beforeEach(async () => {
  vi.clearAllMocks()

  // Supply the DB and auth doubles for the route module each test imports.
  vi.doMock('@/db', () => ({ db: mockDbChain }))
  vi.doMock('@/app/api/auth/oauth/utils', () => ({ getUserId: mockGetUserId }))

  // Re-arm every chainable method so it returns `this` again; `values` and
  // `returning` are left alone because they resolve a value instead of chaining.
  const nonChainable = [mockDbChain.values, mockDbChain.returning]
  for (const mockFn of Object.values(mockDbChain)) {
    if (typeof mockFn !== 'function') continue
    if (nonChainable.includes(mockFn)) continue
    mockFn.mockClear().mockReturnThis()
  }

  // Deterministic replacement for the global `crypto`: fixed UUID and fixed hash digest.
  const hashStub = {
    update: vi.fn().mockReturnThis(),
    digest: vi.fn().mockReturnValue('mock-hash-123'),
  }
  vi.stubGlobal('crypto', {
    randomUUID: vi.fn().mockReturnValue('mock-chunk-uuid-1234'),
    createHash: vi.fn().mockReturnValue(hashStub),
  })
})
afterEach(() => {
  vi.clearAllMocks()
  // Restore the real global `crypto` that beforeEach replaced via vi.stubGlobal,
  // so the fake does not outlive the test that needed it.
  vi.unstubAllGlobals()
})
describe('POST /api/knowledge/[id]/documents/[documentId]/chunks', () => {
// Request body used by the tests that expect the chunk payload itself to be accepted.
const validChunkData = {
content: 'This is test chunk content for uploading to the knowledge base document.',
enabled: true,
}
// Value the mocked checkDocumentAccess resolves to when access is granted.
// tag1/tag2 carry values and tag3-tag7 are null; the tag-inheritance test asserts on tag1-tag3.
const mockDocumentAccess = {
hasAccess: true,
notFound: false,
reason: '',
document: {
id: 'doc-123',
processingStatus: 'completed',
tag1: 'tag1-value',
tag2: 'tag2-value',
tag3: null,
tag4: null,
tag5: null,
tag6: null,
tag7: null,
},
}
// Route params are handed to POST as a promise: POST(req, { params: mockParams }).
const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' })
it('should create chunk successfully with cost tracking', async () => {
  const { checkDocumentAccess } = await import('../../../../utils')
  const { estimateTokenCount } = await import('@/lib/tokenization/estimators')
  const { calculateCost } = await import('@/providers/utils')

  mockGetUserId.mockResolvedValue('user-123')
  vi.mocked(checkDocumentAccess).mockResolvedValue(mockDocumentAccess as DocumentAccessCheck)

  // Transaction double: the select chain ends at `limit`, which resolves one row
  // with chunkIndex 0; `values` resolves so the insert completes.
  const txStub = {
    select: vi.fn().mockReturnThis(),
    from: vi.fn().mockReturnThis(),
    where: vi.fn().mockReturnThis(),
    orderBy: vi.fn().mockReturnThis(),
    limit: vi.fn().mockResolvedValue([{ chunkIndex: 0 }]),
    insert: vi.fn().mockReturnThis(),
    values: vi.fn().mockResolvedValue(undefined),
    update: vi.fn().mockReturnThis(),
    set: vi.fn().mockReturnThis(),
  }
  mockDbChain.transaction.mockImplementation(async (run) => run(txStub))

  const request = createMockRequest('POST', validChunkData)
  const { POST } = await import('./route')
  const response = await POST(request, { params: mockParams })
  const body = await response.json()

  expect(response.status).toBe(200)
  expect(body.success).toBe(true)

  // The cost block must mirror what the mocked estimator and calculator returned.
  const { cost } = body.data
  expect(cost).toBeDefined()
  expect(cost.input).toBe(0.00000904)
  expect(cost.output).toBe(0)
  expect(cost.total).toBe(0.00000904)
  expect(cost.tokens).toEqual({
    prompt: 452,
    completion: 0,
    total: 452,
  })
  expect(cost.model).toBe('text-embedding-3-small')
  expect(cost.pricing).toEqual({
    input: 0.02,
    output: 0,
    updatedAt: '2025-07-10',
  })

  // The route must feed the chunk content and the estimated count into the helpers.
  expect(estimateTokenCount).toHaveBeenCalledWith(validChunkData.content, 'openai')
  expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 452, 0, false)
})
it('should handle workflow-based authentication', async () => {
  const { checkDocumentAccess } = await import('../../../../utils')

  mockGetUserId.mockResolvedValue('user-123')
  vi.mocked(checkDocumentAccess).mockResolvedValue(mockDocumentAccess as DocumentAccessCheck)

  // Transaction double whose `limit` resolves no rows.
  const txStub = {
    select: vi.fn().mockReturnThis(),
    from: vi.fn().mockReturnThis(),
    where: vi.fn().mockReturnThis(),
    orderBy: vi.fn().mockReturnThis(),
    limit: vi.fn().mockResolvedValue([]),
    insert: vi.fn().mockReturnThis(),
    values: vi.fn().mockResolvedValue(undefined),
    update: vi.fn().mockReturnThis(),
    set: vi.fn().mockReturnThis(),
  }
  mockDbChain.transaction.mockImplementation(async (run) => run(txStub))

  // Same payload as the happy path, plus the workflow id under test.
  const request = createMockRequest('POST', { ...validChunkData, workflowId: 'workflow-123' })
  const { POST } = await import('./route')
  const response = await POST(request, { params: mockParams })
  const body = await response.json()

  expect(response.status).toBe(200)
  expect(body.success).toBe(true)
  // The workflow id must be forwarded as the second argument of getUserId.
  expect(mockGetUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123')
})
// Runs sequentially on purpose: the test reprograms the suite-wide `mockGetUserId`,
// which other tests in this suite also set, so it must not overlap with them.
it('should return unauthorized for unauthenticated request', async () => {
  // No user can be resolved and no workflowId is sent.
  mockGetUserId.mockResolvedValue(null)

  const req = createMockRequest('POST', validChunkData)
  const { POST } = await import('./route')
  const response = await POST(req, { params: mockParams })
  const data = await response.json()

  expect(response.status).toBe(401)
  expect(data.error).toBe('Unauthorized')
})
it('should return not found for workflow that does not exist', async () => {
  // getUserId resolves null while a workflowId is present in the body.
  mockGetUserId.mockResolvedValue(null)

  const request = createMockRequest('POST', {
    ...validChunkData,
    workflowId: 'nonexistent-workflow',
  })
  const { POST } = await import('./route')
  const response = await POST(request, { params: mockParams })
  const body = await response.json()

  expect(response.status).toBe(404)
  expect(body.error).toBe('Workflow not found')
})
// Runs sequentially on purpose: it reprograms the shared `mockGetUserId` and the
// module-level `checkDocumentAccess` mock, which other tests in this suite also set.
it('should return not found for document access denied', async () => {
  const { checkDocumentAccess } = await import('../../../../utils')

  mockGetUserId.mockResolvedValue('user-123')
  // Access check reports that the document does not exist.
  vi.mocked(checkDocumentAccess).mockResolvedValue({
    hasAccess: false,
    notFound: true,
    reason: 'Document not found',
  })

  const req = createMockRequest('POST', validChunkData)
  const { POST } = await import('./route')
  const response = await POST(req, { params: mockParams })
  const data = await response.json()

  expect(response.status).toBe(404)
  expect(data.error).toBe('Document not found')
})
it('should return unauthorized for unauthorized document access', async () => {
  const { checkDocumentAccess } = await import('../../../../utils')

  mockGetUserId.mockResolvedValue('user-123')
  // The document exists (notFound: false) but access is refused.
  const deniedAccess = {
    hasAccess: false,
    notFound: false,
    reason: 'Unauthorized access',
  }
  vi.mocked(checkDocumentAccess).mockResolvedValue(deniedAccess)

  const request = createMockRequest('POST', validChunkData)
  const { POST } = await import('./route')
  const response = await POST(request, { params: mockParams })
  const body = await response.json()

  expect(response.status).toBe(401)
  expect(body.error).toBe('Unauthorized')
})
it('should reject chunks for failed documents', async () => {
  const { checkDocumentAccess } = await import('../../../../utils')

  mockGetUserId.mockResolvedValue('user-123')

  // Same access result as the happy path, except the document's processing has failed.
  const failedDocument = { ...mockDocumentAccess.document, processingStatus: 'failed' }
  vi.mocked(checkDocumentAccess).mockResolvedValue({
    ...mockDocumentAccess,
    document: failedDocument,
  } as DocumentAccessCheck)

  const request = createMockRequest('POST', validChunkData)
  const { POST } = await import('./route')
  const response = await POST(request, { params: mockParams })
  const body = await response.json()

  expect(response.status).toBe(400)
  expect(body.error).toBe('Cannot add chunks to failed document')
})
// Runs sequentially on purpose: it reprograms the shared `mockGetUserId` and the
// module-level `checkDocumentAccess` mock, which other tests in this suite also set.
it('should validate chunk data', async () => {
  const { checkDocumentAccess } = await import('../../../../utils')

  mockGetUserId.mockResolvedValue('user-123')
  vi.mocked(checkDocumentAccess).mockResolvedValue(mockDocumentAccess as DocumentAccessCheck)

  const invalidData = {
    content: '', // Empty content
    enabled: true,
  }

  const req = createMockRequest('POST', invalidData)
  const { POST } = await import('./route')
  const response = await POST(req, { params: mockParams })
  const data = await response.json()

  expect(response.status).toBe(400)
  expect(data.error).toBe('Invalid request data')
  expect(data.details).toBeDefined()
})
it('should inherit tags from parent document', async () => {
  const { checkDocumentAccess } = await import('../../../../utils')

  mockGetUserId.mockResolvedValue('user-123')
  vi.mocked(checkDocumentAccess).mockResolvedValue(mockDocumentAccess as DocumentAccessCheck)

  // `values` only records its argument. The tag assertions are made after the
  // request completes: an expectation thrown from inside the mock would be raised
  // within the route handler, where it can be caught and turned into an error
  // response instead of failing this test.
  const mockTx = {
    select: vi.fn().mockReturnThis(),
    from: vi.fn().mockReturnThis(),
    where: vi.fn().mockReturnThis(),
    orderBy: vi.fn().mockReturnThis(),
    limit: vi.fn().mockResolvedValue([]),
    insert: vi.fn().mockReturnThis(),
    values: vi.fn().mockResolvedValue(undefined),
    update: vi.fn().mockReturnThis(),
    set: vi.fn().mockReturnThis(),
  }
  mockDbChain.transaction.mockImplementation(async (callback) => {
    return await callback(mockTx)
  })

  const req = createMockRequest('POST', validChunkData)
  const { POST } = await import('./route')
  const response = await POST(req, { params: mockParams })

  // The request itself must succeed, otherwise the insert assertion below is meaningless.
  expect(response.status).toBe(200)

  // The inserted row must carry the parent document's tag values.
  expect(mockTx.values).toHaveBeenCalledWith(
    expect.objectContaining({
      tag1: 'tag1-value',
      tag2: 'tag2-value',
      tag3: null,
    })
  )
})
// Runs sequentially on purpose: it overrides the module-level estimator and cost
// mocks and reprograms the shared DB/auth doubles used by every test in this suite.
it('should handle cost calculation with different content lengths', async () => {
  const { estimateTokenCount } = await import('@/lib/tokenization/estimators')
  const { calculateCost } = await import('@/providers/utils')
  const { checkDocumentAccess } = await import('../../../../utils')

  // One-shot overrides: the defaults registered at the top of the file (452 tokens)
  // apply again after this request, so the larger values cannot leak into other tests.
  vi.mocked(estimateTokenCount).mockReturnValueOnce({
    count: 1000,
    confidence: 'high',
    provider: 'openai',
    method: 'precise',
  })
  vi.mocked(calculateCost).mockReturnValueOnce({
    input: 0.00002,
    output: 0,
    total: 0.00002,
    pricing: {
      input: 0.02,
      output: 0,
      updatedAt: '2025-07-10',
    },
  })

  const largeChunkData = {
    content:
      'This is a much larger chunk of content that would result in significantly more tokens when processed through the OpenAI tokenization system for embedding generation. This content is designed to test the cost calculation accuracy with larger input sizes.',
    enabled: true,
  }

  mockGetUserId.mockResolvedValue('user-123')
  vi.mocked(checkDocumentAccess).mockResolvedValue(mockDocumentAccess as DocumentAccessCheck)

  const mockTx = {
    select: vi.fn().mockReturnThis(),
    from: vi.fn().mockReturnThis(),
    where: vi.fn().mockReturnThis(),
    orderBy: vi.fn().mockReturnThis(),
    limit: vi.fn().mockResolvedValue([]),
    insert: vi.fn().mockReturnThis(),
    values: vi.fn().mockResolvedValue(undefined),
    update: vi.fn().mockReturnThis(),
    set: vi.fn().mockReturnThis(),
  }
  mockDbChain.transaction.mockImplementation(async (callback) => {
    return await callback(mockTx)
  })

  const req = createMockRequest('POST', largeChunkData)
  const { POST } = await import('./route')
  const response = await POST(req, { params: mockParams })
  const data = await response.json()

  expect(response.status).toBe(200)
  expect(data.data.cost.input).toBe(0.00002)
  expect(data.data.cost.tokens.prompt).toBe(1000)
  // The overridden token count must be what reaches the cost calculator.
  expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 1000, 0, false)
})
})
})

View File

@@ -4,9 +4,11 @@ import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console-logger'
import { estimateTokenCount } from '@/lib/tokenization/estimators'
import { getUserId } from '@/app/api/auth/oauth/utils'
import { db } from '@/db'
import { document, embedding } from '@/db/schema'
import { calculateCost } from '@/providers/utils'
import { checkDocumentAccess, generateEmbeddings } from '../../../../utils'
const logger = createLogger('DocumentChunksAPI')
@@ -217,6 +219,9 @@ export async function POST(
logger.info(`[${requestId}] Generating embedding for manual chunk`)
const embeddings = await generateEmbeddings([validatedData.content])
// Calculate accurate token count for both database storage and cost calculation
const tokenCount = estimateTokenCount(validatedData.content, 'openai')
const chunkId = crypto.randomUUID()
const now = new Date()
@@ -240,7 +245,7 @@ export async function POST(
chunkHash: crypto.createHash('sha256').update(validatedData.content).digest('hex'),
content: validatedData.content,
contentLength: validatedData.content.length,
tokenCount: Math.ceil(validatedData.content.length / 4), // Rough approximation
tokenCount: tokenCount.count, // Use accurate token count
embedding: embeddings[0],
embeddingModel: 'text-embedding-3-small',
startOffset: 0, // Manual chunks don't have document offsets
@@ -276,9 +281,38 @@ export async function POST(
logger.info(`[${requestId}] Manual chunk created: ${chunkId} in document ${documentId}`)
// Calculate cost for the embedding (with fallback if calculation fails)
let cost = null
try {
cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
} catch (error) {
logger.warn(`[${requestId}] Failed to calculate cost for chunk upload`, {
error: error instanceof Error ? error.message : 'Unknown error',
})
// Continue without cost information rather than failing the upload
}
return NextResponse.json({
success: true,
data: newChunk,
data: {
...newChunk,
...(cost
? {
cost: {
input: cost.input,
output: cost.output,
total: cost.total,
tokens: {
prompt: tokenCount.count,
completion: 0,
total: tokenCount.count,
},
model: 'text-embedding-3-small',
pricing: cost.pricing,
},
}
: {}),
},
})
} catch (validationError) {
if (validationError instanceof z.ZodError) {

View File

@@ -8,7 +8,6 @@ import { document, knowledgeBase } from '@/db/schema'
const logger = createLogger('KnowledgeBaseAPI')
// Schema for knowledge base creation
const CreateKnowledgeBaseSchema = z.object({
name: z.string().min(1, 'Name is required'),
description: z.string().optional(),

View File

@@ -34,6 +34,23 @@ vi.mock('@/lib/documents/utils', () => ({
retryWithExponentialBackoff: vi.fn().mockImplementation((fn) => fn()),
}))
vi.mock('@/lib/tokenization/estimators', () => ({
estimateTokenCount: vi.fn().mockReturnValue({ count: 521 }),
}))
vi.mock('@/providers/utils', () => ({
calculateCost: vi.fn().mockReturnValue({
input: 0.00001042,
output: 0,
total: 0.00001042,
pricing: {
input: 0.02,
output: 0,
updatedAt: '2025-07-10',
},
}),
}))
mockConsoleLogger()
describe('Knowledge Search API Route', () => {
@@ -206,7 +223,7 @@ describe('Knowledge Search API Route', () => {
expect(mockGetUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123')
})
it('should return unauthorized for unauthenticated request', async () => {
it.concurrent('should return unauthorized for unauthenticated request', async () => {
mockGetUserId.mockResolvedValue(null)
const req = createMockRequest('POST', validSearchData)
@@ -218,7 +235,7 @@ describe('Knowledge Search API Route', () => {
expect(data.error).toBe('Unauthorized')
})
it('should return not found for workflow that does not exist', async () => {
it.concurrent('should return not found for workflow that does not exist', async () => {
const workflowData = {
...validSearchData,
workflowId: 'nonexistent-workflow',
@@ -268,7 +285,7 @@ describe('Knowledge Search API Route', () => {
expect(data.error).toBe('Knowledge bases not found: kb-missing')
})
it('should validate search parameters', async () => {
it.concurrent('should validate search parameters', async () => {
const invalidData = {
knowledgeBaseIds: '', // Empty string
query: '', // Empty query
@@ -314,7 +331,7 @@ describe('Knowledge Search API Route', () => {
expect(data.data.topK).toBe(10) // Default value
})
it('should handle OpenAI API errors', async () => {
it.concurrent('should handle OpenAI API errors', async () => {
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases)
@@ -334,7 +351,7 @@ describe('Knowledge Search API Route', () => {
expect(data.error).toBe('Failed to perform vector search')
})
it('should handle missing OpenAI API key', async () => {
it.concurrent('should handle missing OpenAI API key', async () => {
vi.doMock('@/lib/env', () => ({
env: {
OPENAI_API_KEY: undefined,
@@ -353,7 +370,7 @@ describe('Knowledge Search API Route', () => {
expect(data.error).toBe('Failed to perform vector search')
})
it('should handle database errors during search', async () => {
it.concurrent('should handle database errors during search', async () => {
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases)
mockDbChain.limit.mockRejectedValueOnce(new Error('Database error'))
@@ -375,7 +392,7 @@ describe('Knowledge Search API Route', () => {
expect(data.error).toBe('Failed to perform vector search')
})
it('should handle invalid OpenAI response format', async () => {
it.concurrent('should handle invalid OpenAI response format', async () => {
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases)
@@ -395,5 +412,124 @@ describe('Knowledge Search API Route', () => {
expect(response.status).toBe(500)
expect(data.error).toBe('Failed to perform vector search')
})
describe('Cost tracking', () => {
// Runs sequentially on purpose: it queues one-shot results on the shared
// `mockDbChain` and reprograms the shared `mockGetUserId` / `mockFetch` doubles.
it('should include cost information in successful search response', async () => {
  mockGetUserId.mockResolvedValue('user-123')
  mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
  mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
  // fetch resolves an ok response whose JSON body carries one embedding.
  mockFetch.mockResolvedValue({
    ok: true,
    json: () =>
      Promise.resolve({
        data: [{ embedding: mockEmbedding }],
      }),
  })

  const req = createMockRequest('POST', validSearchData)
  const { POST } = await import('./route')
  const response = await POST(req)
  const data = await response.json()

  expect(response.status).toBe(200)
  expect(data.success).toBe(true)

  // Verify cost information is included and mirrors the mocked helpers
  expect(data.data.cost).toBeDefined()
  expect(data.data.cost.input).toBe(0.00001042)
  expect(data.data.cost.output).toBe(0)
  expect(data.data.cost.total).toBe(0.00001042)
  expect(data.data.cost.tokens).toEqual({
    prompt: 521,
    completion: 0,
    total: 521,
  })
  expect(data.data.cost.model).toBe('text-embedding-3-small')
  expect(data.data.cost.pricing).toEqual({
    input: 0.02,
    output: 0,
    updatedAt: '2025-07-10',
  })
})
it('should call cost calculation functions with correct parameters', async () => {
  const { estimateTokenCount } = await import('@/lib/tokenization/estimators')
  const { calculateCost } = await import('@/providers/utils')

  mockGetUserId.mockResolvedValue('user-123')
  mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
  mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
  // fetch resolves an ok response whose JSON body carries one embedding.
  mockFetch.mockResolvedValue({
    ok: true,
    json: async () => ({ data: [{ embedding: mockEmbedding }] }),
  })

  const request = createMockRequest('POST', validSearchData)
  const { POST } = await import('./route')
  await POST(request)

  // The estimator must receive the query text with the 'openai' provider...
  expect(estimateTokenCount).toHaveBeenCalledWith('test search query', 'openai')
  // ...and the calculator the embedding model, the mocked count of 521 and no completion tokens.
  expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 521, 0, false)
})
it('should handle cost calculation with different query lengths', async () => {
  const { estimateTokenCount } = await import('@/lib/tokenization/estimators')
  const { calculateCost } = await import('@/providers/utils')

  // Replace the helper defaults with the values expected for the longer query.
  const longQueryTokens = {
    count: 1042,
    confidence: 'high',
    provider: 'openai',
    method: 'precise',
  }
  vi.mocked(estimateTokenCount).mockReturnValue(longQueryTokens)
  const longQueryCost = {
    input: 0.00002084,
    output: 0,
    total: 0.00002084,
    pricing: {
      input: 0.02,
      output: 0,
      updatedAt: '2025-07-10',
    },
  }
  vi.mocked(calculateCost).mockReturnValue(longQueryCost)

  mockGetUserId.mockResolvedValue('user-123')
  mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
  mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
  // fetch resolves an ok response whose JSON body carries one embedding.
  mockFetch.mockResolvedValue({
    ok: true,
    json: async () => ({ data: [{ embedding: mockEmbedding }] }),
  })

  const request = createMockRequest('POST', {
    ...validSearchData,
    query:
      'This is a much longer search query with many more tokens to test cost calculation accuracy',
  })
  const { POST } = await import('./route')
  const response = await POST(request)
  const body = await response.json()

  expect(response.status).toBe(200)
  expect(body.data.cost.input).toBe(0.00002084)
  expect(body.data.cost.tokens.prompt).toBe(1042)
  expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 1042, 0, false)
})
})
})
})

View File

@@ -4,13 +4,14 @@ import { z } from 'zod'
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console-logger'
import { estimateTokenCount } from '@/lib/tokenization/estimators'
import { getUserId } from '@/app/api/auth/oauth/utils'
import { db } from '@/db'
import { embedding, knowledgeBase } from '@/db/schema'
import { calculateCost } from '@/providers/utils'
const logger = createLogger('VectorSearchAPI')
// Helper function to create tag filters
function getTagFilters(filters: Record<string, string>, embedding: any) {
return Object.entries(filters).map(([key, value]) => {
switch (key) {
@@ -51,7 +52,6 @@ const VectorSearchSchema = z.object({
]),
query: z.string().min(1, 'Search query is required'),
topK: z.number().min(1).max(100).default(10),
// Tag filters for pre-filtering
filters: z
.object({
tag1: z.string().optional(),
@@ -166,7 +166,6 @@ async function executeParallelQueries(
eq(embedding.knowledgeBaseId, kbId),
eq(embedding.enabled, true),
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`,
// Apply tag filters if provided (case-insensitive)
...(filters ? getTagFilters(filters, embedding) : [])
)
)
@@ -208,7 +207,6 @@ async function executeSingleQuery(
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
eq(embedding.enabled, true),
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`,
// Apply tag filters if provided (case-insensitive)
...(filters
? Object.entries(filters).map(([key, value]) => {
switch (key) {
@@ -321,6 +319,19 @@ export async function POST(request: NextRequest) {
)
}
// Calculate cost for the embedding (with fallback if calculation fails)
let cost = null
let tokenCount = null
try {
tokenCount = estimateTokenCount(validatedData.query, 'openai')
cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
} catch (error) {
logger.warn(`[${requestId}] Failed to calculate cost for search query`, {
error: error instanceof Error ? error.message : 'Unknown error',
})
// Continue without cost information rather than failing the search
}
return NextResponse.json({
success: true,
data: {
@@ -343,6 +354,22 @@ export async function POST(request: NextRequest) {
knowledgeBaseId: foundKbIds[0],
topK: validatedData.topK,
totalResults: results.length,
...(cost && tokenCount
? {
cost: {
input: cost.input,
output: cost.output,
total: cost.total,
tokens: {
prompt: tokenCount.count,
completion: 0,
total: tokenCount.count,
},
model: 'text-embedding-3-small',
pricing: cost.pricing,
},
}
: {}),
},
})
} catch (validationError) {

View File

@@ -73,7 +73,7 @@ describe('GenericBlockHandler', () => {
mockExecuteTool.mockResolvedValue({ success: true, output: { customResult: 'OK' } })
})
it('should always handle any block type', () => {
it.concurrent('should always handle any block type', () => {
const agentBlock: SerializedBlock = { ...mockBlock, metadata: { id: 'agent' } }
expect(handler.canHandle(agentBlock)).toBe(true)
expect(handler.canHandle(mockBlock)).toBe(true)
@@ -81,7 +81,7 @@ describe('GenericBlockHandler', () => {
expect(handler.canHandle(noMetaIdBlock)).toBe(true)
})
it('should execute generic block by calling its associated tool', async () => {
it.concurrent('should execute generic block by calling its associated tool', async () => {
const inputs = { param1: 'resolvedValue1' }
const expectedToolParams = {
...inputs,
@@ -133,7 +133,7 @@ describe('GenericBlockHandler', () => {
expect(mockExecuteTool).toHaveBeenCalledTimes(2) // Called twice now
})
it('should handle tool execution errors with no specific message', async () => {
it.concurrent('should handle tool execution errors with no specific message', async () => {
const inputs = { param1: 'value' }
const errorResult = { success: false, output: {} }
mockExecuteTool.mockResolvedValue(errorResult)
@@ -142,4 +142,218 @@ describe('GenericBlockHandler', () => {
'Block execution of Some Custom Tool failed with no error message'
)
})
describe('Knowledge block cost tracking', () => {
beforeEach(() => {
  // Point fresh copies of the shared block and tool fixtures at knowledge_search.
  mockBlock = { ...mockBlock, config: { tool: 'knowledge_search', params: {} } }
  mockTool = { ...mockTool, id: 'knowledge_search', name: 'Knowledge Search' }

  // Only the knowledge_search id resolves to a tool; any other id yields undefined.
  mockGetTool.mockImplementation((toolId) =>
    toolId === 'knowledge_search' ? mockTool : undefined
  )
})
// Runs sequentially on purpose: the test programs the shared `mockExecuteTool` and
// relies on the `mockBlock` fixture, both of which neighbouring tests rewrite.
it('should extract and restructure cost information from knowledge tools', async () => {
  const inputs = { query: 'test query' }
  const mockToolResponse = {
    success: true,
    output: {
      results: [],
      query: 'test query',
      totalResults: 0,
      cost: {
        input: 0.00001042,
        output: 0,
        total: 0.00001042,
        tokens: {
          prompt: 521,
          completion: 0,
          total: 521,
        },
        model: 'text-embedding-3-small',
        pricing: {
          input: 0.02,
          output: 0,
          updatedAt: '2025-07-10',
        },
      },
    },
  }
  mockExecuteTool.mockResolvedValue(mockToolResponse)

  const result = await handler.execute(mockBlock, inputs, mockContext)

  // Expected shape: `cost` keeps only input/output/total, while `tokens` and
  // `model` are lifted to the top level and `pricing` is dropped.
  expect(result).toEqual({
    results: [],
    query: 'test query',
    totalResults: 0,
    cost: {
      input: 0.00001042,
      output: 0,
      total: 0.00001042,
    },
    tokens: {
      prompt: 521,
      completion: 0,
      total: 521,
    },
    model: 'text-embedding-3-small',
  })
})
// Runs sequentially on purpose: the test mutates the shared `mockBlock` / `mockTool`
// fixtures and replaces the shared `mockGetTool` implementation, which would change
// what a concurrently running test observes.
it('should handle knowledge_upload_chunk cost information', async () => {
  // Update to upload_chunk tool
  mockBlock.config.tool = 'knowledge_upload_chunk'
  mockTool.id = 'knowledge_upload_chunk'
  mockTool.name = 'Knowledge Upload Chunk'
  mockGetTool.mockImplementation((toolId) => {
    if (toolId === 'knowledge_upload_chunk') {
      return mockTool
    }
    return undefined
  })

  const inputs = { content: 'test content' }
  const mockToolResponse = {
    success: true,
    output: {
      data: {
        id: 'chunk-123',
        content: 'test content',
        chunkIndex: 0,
      },
      message: 'Successfully uploaded chunk',
      documentId: 'doc-123',
      cost: {
        input: 0.00000521,
        output: 0,
        total: 0.00000521,
        tokens: {
          prompt: 260,
          completion: 0,
          total: 260,
        },
        model: 'text-embedding-3-small',
        pricing: {
          input: 0.02,
          output: 0,
          updatedAt: '2025-07-10',
        },
      },
    },
  }
  mockExecuteTool.mockResolvedValue(mockToolResponse)

  const result = await handler.execute(mockBlock, inputs, mockContext)

  // Expected shape: `cost` keeps only input/output/total, while `tokens` and
  // `model` are lifted to the top level and `pricing` is dropped.
  expect(result).toEqual({
    data: {
      id: 'chunk-123',
      content: 'test content',
      chunkIndex: 0,
    },
    message: 'Successfully uploaded chunk',
    documentId: 'doc-123',
    cost: {
      input: 0.00000521,
      output: 0,
      total: 0.00000521,
    },
    tokens: {
      prompt: 260,
      completion: 0,
      total: 260,
    },
    model: 'text-embedding-3-small',
  })
})
it('should pass through output unchanged for knowledge tools without cost info', async () => {
  // The tool succeeds but its output carries no `cost` field.
  mockExecuteTool.mockResolvedValue({
    success: true,
    output: { results: [], query: 'test query', totalResults: 0 },
  })

  const handlerOutput = await handler.execute(mockBlock, { query: 'test query' }, mockContext)

  // With nothing to restructure, the handler must return the tool output as it was.
  expect(handlerOutput).toEqual({ results: [], query: 'test query', totalResults: 0 })
})
// Runs sequentially on purpose: the test mutates the shared `mockBlock` / `mockTool`
// fixtures and replaces the shared `mockGetTool` implementation, which would change
// what a concurrently running test observes.
it('should not process cost info for non-knowledge tools', async () => {
  // Set up non-knowledge tool
  mockBlock.config.tool = 'some_other_tool'
  mockTool.id = 'some_other_tool'
  mockGetTool.mockImplementation((toolId) => {
    if (toolId === 'some_other_tool') {
      return mockTool
    }
    return undefined
  })

  const inputs = { param: 'value' }
  const mockToolResponse = {
    success: true,
    output: {
      result: 'success',
      cost: {
        input: 0.001,
        output: 0.002,
        total: 0.003,
        tokens: { prompt: 100, completion: 50, total: 150 },
        model: 'some-model',
      },
    },
  }
  mockExecuteTool.mockResolvedValue(mockToolResponse)

  const result = await handler.execute(mockBlock, inputs, mockContext)

  // The tool id does not start with `knowledge_`, so the nested cost object
  // must come back exactly as the tool produced it.
  expect(result).toEqual({
    result: 'success',
    cost: {
      input: 0.001,
      output: 0.002,
      total: 0.003,
      tokens: { prompt: 100, completion: 50, total: 150 },
      model: 'some-model',
    },
  })
})
})
})

View File

@@ -59,7 +59,30 @@ export class GenericBlockHandler implements BlockHandler {
throw error
}
return result.output
// Extract cost information from tool response if available
const output = result.output
let cost = null
// Check if the tool is a knowledge tool and has cost information
if (block.config.tool?.startsWith('knowledge_') && output?.cost) {
cost = output.cost
}
// Return the output with cost information if available
if (cost) {
return {
...output,
cost: {
input: cost.input,
output: cost.output,
total: cost.total,
},
tokens: cost.tokens,
model: cost.model,
}
}
return output
} catch (error: any) {
// Ensure we have a meaningful error message
if (!error.message || error.message === 'undefined (undefined)') {

View File

@@ -647,3 +647,31 @@ export function updateOllamaModels(models: string[]): void {
capabilities: {},
}))
}
/**
 * Embedding model pricing - separate from chat models.
 *
 * Keys are embedding model ids. `input` is the price in dollars per 1M tokens;
 * `output` is 0 for every entry in this table.
 */
export const EMBEDDING_MODEL_PRICING: Record<string, ModelPricing> = {
  'text-embedding-3-small': {
    input: 0.02, // $0.02 per 1M tokens
    output: 0.0,
    updatedAt: '2025-07-10',
  },
  'text-embedding-3-large': {
    input: 0.13, // $0.13 per 1M tokens
    output: 0.0,
    updatedAt: '2025-07-10',
  },
  'text-embedding-ada-002': {
    input: 0.1, // $0.1 per 1M tokens
    output: 0.0,
    updatedAt: '2025-07-10',
  },
}

/**
 * Get pricing for embedding models specifically.
 *
 * @param modelId - Model id to look up in {@link EMBEDDING_MODEL_PRICING}.
 * @returns The pricing entry for `modelId`, or `null` when the table has no entry of its own for it.
 */
export function getEmbeddingModelPricing(modelId: string): ModelPricing | null {
  // Own-property check: a bare `table[modelId]` lookup on a plain object also finds
  // inherited members, so ids such as 'toString', 'constructor' or '__proto__'
  // would otherwise come back as if they were pricing entries.
  if (!Object.prototype.hasOwnProperty.call(EMBEDDING_MODEL_PRICING, modelId)) {
    return null
  }
  return EMBEDDING_MODEL_PRICING[modelId] ?? null
}

View File

@@ -9,6 +9,7 @@ import { googleProvider } from './google'
import { groqProvider } from './groq'
import {
getComputerUseModels,
getEmbeddingModelPricing,
getHostedModels as getHostedModelsFromDefinitions,
getMaxTemperature as getMaxTempFromDefinitions,
getModelPricing as getModelPricingFromDefinitions,
@@ -444,7 +445,13 @@ export function calculateCost(
completionTokens = 0,
useCachedInput = false
) {
const pricing = getModelPricingFromDefinitions(model)
// First check if it's an embedding model
let pricing = getEmbeddingModelPricing(model)
// If not found, check chat models
if (!pricing) {
pricing = getModelPricingFromDefinitions(model)
}
// If no pricing found, return default pricing
if (!pricing) {
@@ -475,14 +482,32 @@ export function calculateCost(
const costMultiplier = getCostMultiplier()
const finalInputCost = inputCost * costMultiplier
const finalOutputCost = outputCost * costMultiplier
const finalTotalCost = totalCost * costMultiplier
return {
input: Number.parseFloat((inputCost * costMultiplier).toFixed(6)),
output: Number.parseFloat((outputCost * costMultiplier).toFixed(6)),
total: Number.parseFloat((totalCost * costMultiplier).toFixed(6)),
input: Number.parseFloat(finalInputCost.toFixed(8)), // Use 8 decimal places for small costs
output: Number.parseFloat(finalOutputCost.toFixed(8)),
total: Number.parseFloat(finalTotalCost.toFixed(8)),
pricing,
}
}
/**
 * Look up pricing for a model id, consulting the embedding pricing table
 * before the chat-model definitions.
 *
 * @param modelId - Model id to look up.
 * @returns The embedding pricing when one is found, otherwise whatever the
 * chat-model pricing lookup returns for `modelId`.
 */
export function getModelPricing(modelId: string): any {
  // Embedding pricing wins when present; otherwise defer to the chat-model lookup.
  return getEmbeddingModelPricing(modelId) || getModelPricingFromDefinitions(modelId)
}
/**
* Format cost as a currency string
*

View File

@@ -120,6 +120,7 @@ export const knowledgeSearchTool: ToolConfig<any, KnowledgeSearchResponse> = {
results: data.results || [],
query: data.query,
totalResults: data.totalResults || 0,
cost: data.cost,
},
}
} catch (error: any) {
@@ -129,6 +130,7 @@ export const knowledgeSearchTool: ToolConfig<any, KnowledgeSearchResponse> = {
results: [],
query: '',
totalResults: 0,
cost: undefined,
},
error: `Vector search failed: ${error.message || 'Unknown error'}`,
}
@@ -142,6 +144,7 @@ export const knowledgeSearchTool: ToolConfig<any, KnowledgeSearchResponse> = {
results: [],
query: '',
totalResults: 0,
cost: undefined,
},
error: errorMessage,
}

View File

@@ -13,6 +13,22 @@ export interface KnowledgeSearchResponse {
results: KnowledgeSearchResult[]
query: string
totalResults: number
cost?: {
input: number
output: number
total: number
tokens: {
prompt: number
completion: number
total: number
}
model: string
pricing: {
input: number
output: number
updatedAt: string
}
}
}
error?: string
}
@@ -40,6 +56,22 @@ export interface KnowledgeUploadChunkResponse {
data: KnowledgeUploadChunkResult
message: string
documentId: string
cost?: {
input: number
output: number
total: number
tokens: {
prompt: number
completion: number
total: number
}
model: string
pricing: {
input: number
output: number
updatedAt: string
}
}
}
error?: string
}

View File

@@ -69,6 +69,7 @@ export const knowledgeUploadChunkTool: ToolConfig<any, KnowledgeUploadChunkRespo
},
message: `Successfully uploaded chunk to document`,
documentId: data.documentId,
cost: data.cost,
},
}
} catch (error: any) {