diff --git a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.test.ts b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.test.ts new file mode 100644 index 0000000000..6ae1f715c2 --- /dev/null +++ b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.test.ts @@ -0,0 +1,413 @@ +/** + * Tests for knowledge document chunks API route + * + * @vitest-environment node + */ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { + createMockRequest, + mockAuth, + mockConsoleLogger, + mockDrizzleOrm, + mockKnowledgeSchemas, +} from '@/app/api/__test-utils__/utils' +import type { DocumentAccessCheck } from '../../../../utils' + +mockKnowledgeSchemas() +mockDrizzleOrm() +mockConsoleLogger() + +vi.mock('@/lib/tokenization/estimators', () => ({ + estimateTokenCount: vi.fn().mockReturnValue({ count: 452 }), +})) + +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn().mockReturnValue({ + input: 0.00000904, + output: 0, + total: 0.00000904, + pricing: { + input: 0.02, + output: 0, + updatedAt: '2025-07-10', + }, + }), +})) + +vi.mock('../../../../utils', () => ({ + checkDocumentAccess: vi.fn(), + generateEmbeddings: vi.fn().mockResolvedValue([[0.1, 0.2, 0.3, 0.4, 0.5]]), +})) + +describe('Knowledge Document Chunks API Route', () => { + const mockAuth$ = mockAuth() + + const mockDbChain = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + orderBy: vi.fn().mockReturnThis(), + limit: vi.fn().mockReturnThis(), + offset: vi.fn().mockReturnThis(), + insert: vi.fn().mockReturnThis(), + values: vi.fn().mockResolvedValue(undefined), + update: vi.fn().mockReturnThis(), + set: vi.fn().mockReturnThis(), + returning: vi.fn().mockResolvedValue([]), + delete: vi.fn().mockReturnThis(), + transaction: vi.fn(), + } + + const mockGetUserId = vi.fn() + + beforeEach(async () => { + vi.clearAllMocks() + + vi.doMock('@/db', () => ({ + db: mockDbChain, + 
})) + + vi.doMock('@/app/api/auth/oauth/utils', () => ({ + getUserId: mockGetUserId, + })) + + Object.values(mockDbChain).forEach((fn) => { + if (typeof fn === 'function' && fn !== mockDbChain.values && fn !== mockDbChain.returning) { + fn.mockClear().mockReturnThis() + } + }) + + vi.stubGlobal('crypto', { + randomUUID: vi.fn().mockReturnValue('mock-chunk-uuid-1234'), + createHash: vi.fn().mockReturnValue({ + update: vi.fn().mockReturnThis(), + digest: vi.fn().mockReturnValue('mock-hash-123'), + }), + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('POST /api/knowledge/[id]/documents/[documentId]/chunks', () => { + const validChunkData = { + content: 'This is test chunk content for uploading to the knowledge base document.', + enabled: true, + } + + const mockDocumentAccess = { + hasAccess: true, + notFound: false, + reason: '', + document: { + id: 'doc-123', + processingStatus: 'completed', + tag1: 'tag1-value', + tag2: 'tag2-value', + tag3: null, + tag4: null, + tag5: null, + tag6: null, + tag7: null, + }, + } + + const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' }) + + it('should create chunk successfully with cost tracking', async () => { + const { checkDocumentAccess } = await import('../../../../utils') + const { estimateTokenCount } = await import('@/lib/tokenization/estimators') + const { calculateCost } = await import('@/providers/utils') + + mockGetUserId.mockResolvedValue('user-123') + vi.mocked(checkDocumentAccess).mockResolvedValue(mockDocumentAccess as DocumentAccessCheck) + + // Mock transaction + const mockTx = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + orderBy: vi.fn().mockReturnThis(), + limit: vi.fn().mockResolvedValue([{ chunkIndex: 0 }]), + insert: vi.fn().mockReturnThis(), + values: vi.fn().mockResolvedValue(undefined), + update: vi.fn().mockReturnThis(), + set: vi.fn().mockReturnThis(), + } + + 
mockDbChain.transaction.mockImplementation(async (callback) => { + return await callback(mockTx) + }) + + const req = createMockRequest('POST', validChunkData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + + // Verify cost tracking + expect(data.data.cost).toBeDefined() + expect(data.data.cost.input).toBe(0.00000904) + expect(data.data.cost.output).toBe(0) + expect(data.data.cost.total).toBe(0.00000904) + expect(data.data.cost.tokens).toEqual({ + prompt: 452, + completion: 0, + total: 452, + }) + expect(data.data.cost.model).toBe('text-embedding-3-small') + expect(data.data.cost.pricing).toEqual({ + input: 0.02, + output: 0, + updatedAt: '2025-07-10', + }) + + // Verify function calls + expect(estimateTokenCount).toHaveBeenCalledWith(validChunkData.content, 'openai') + expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 452, 0, false) + }) + + it('should handle workflow-based authentication', async () => { + const { checkDocumentAccess } = await import('../../../../utils') + + const workflowData = { + ...validChunkData, + workflowId: 'workflow-123', + } + + mockGetUserId.mockResolvedValue('user-123') + vi.mocked(checkDocumentAccess).mockResolvedValue(mockDocumentAccess as DocumentAccessCheck) + + const mockTx = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + orderBy: vi.fn().mockReturnThis(), + limit: vi.fn().mockResolvedValue([]), + insert: vi.fn().mockReturnThis(), + values: vi.fn().mockResolvedValue(undefined), + update: vi.fn().mockReturnThis(), + set: vi.fn().mockReturnThis(), + } + + mockDbChain.transaction.mockImplementation(async (callback) => { + return await callback(mockTx) + }) + + const req = createMockRequest('POST', workflowData) + const { POST } = await import('./route') + const response = await POST(req, { 
params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(mockGetUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123') + }) + + it.concurrent('should return unauthorized for unauthenticated request', async () => { + mockGetUserId.mockResolvedValue(null) + + const req = createMockRequest('POST', validChunkData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should return not found for workflow that does not exist', async () => { + const workflowData = { + ...validChunkData, + workflowId: 'nonexistent-workflow', + } + + mockGetUserId.mockResolvedValue(null) + + const req = createMockRequest('POST', workflowData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Workflow not found') + }) + + it.concurrent('should return not found for document access denied', async () => { + const { checkDocumentAccess } = await import('../../../../utils') + + mockGetUserId.mockResolvedValue('user-123') + vi.mocked(checkDocumentAccess).mockResolvedValue({ + hasAccess: false, + notFound: true, + reason: 'Document not found', + }) + + const req = createMockRequest('POST', validChunkData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Document not found') + }) + + it('should return unauthorized for unauthorized document access', async () => { + const { checkDocumentAccess } = await import('../../../../utils') + + mockGetUserId.mockResolvedValue('user-123') + vi.mocked(checkDocumentAccess).mockResolvedValue({ + 
hasAccess: false, + notFound: false, + reason: 'Unauthorized access', + }) + + const req = createMockRequest('POST', validChunkData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should reject chunks for failed documents', async () => { + const { checkDocumentAccess } = await import('../../../../utils') + + mockGetUserId.mockResolvedValue('user-123') + vi.mocked(checkDocumentAccess).mockResolvedValue({ + ...mockDocumentAccess, + document: { + ...mockDocumentAccess.document!, + processingStatus: 'failed', + }, + } as DocumentAccessCheck) + + const req = createMockRequest('POST', validChunkData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toBe('Cannot add chunks to failed document') + }) + + it.concurrent('should validate chunk data', async () => { + const { checkDocumentAccess } = await import('../../../../utils') + + mockGetUserId.mockResolvedValue('user-123') + vi.mocked(checkDocumentAccess).mockResolvedValue(mockDocumentAccess as DocumentAccessCheck) + + const invalidData = { + content: '', // Empty content + enabled: true, + } + + const req = createMockRequest('POST', invalidData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toBe('Invalid request data') + expect(data.details).toBeDefined() + }) + + it('should inherit tags from parent document', async () => { + const { checkDocumentAccess } = await import('../../../../utils') + + mockGetUserId.mockResolvedValue('user-123') + vi.mocked(checkDocumentAccess).mockResolvedValue(mockDocumentAccess as DocumentAccessCheck) + + const 
mockTx = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + orderBy: vi.fn().mockReturnThis(), + limit: vi.fn().mockResolvedValue([]), + insert: vi.fn().mockReturnThis(), + values: vi.fn().mockImplementation((data) => { + // Verify that tags are inherited from document + expect(data.tag1).toBe('tag1-value') + expect(data.tag2).toBe('tag2-value') + expect(data.tag3).toBe(null) + return Promise.resolve(undefined) + }), + update: vi.fn().mockReturnThis(), + set: vi.fn().mockReturnThis(), + } + + mockDbChain.transaction.mockImplementation(async (callback) => { + return await callback(mockTx) + }) + + const req = createMockRequest('POST', validChunkData) + const { POST } = await import('./route') + await POST(req, { params: mockParams }) + + expect(mockTx.values).toHaveBeenCalled() + }) + + it.concurrent('should handle cost calculation with different content lengths', async () => { + const { estimateTokenCount } = await import('@/lib/tokenization/estimators') + const { calculateCost } = await import('@/providers/utils') + const { checkDocumentAccess } = await import('../../../../utils') + + // Mock larger content with more tokens + vi.mocked(estimateTokenCount).mockReturnValue({ + count: 1000, + confidence: 'high', + provider: 'openai', + method: 'precise', + }) + vi.mocked(calculateCost).mockReturnValue({ + input: 0.00002, + output: 0, + total: 0.00002, + pricing: { + input: 0.02, + output: 0, + updatedAt: '2025-07-10', + }, + }) + + const largeChunkData = { + content: + 'This is a much larger chunk of content that would result in significantly more tokens when processed through the OpenAI tokenization system for embedding generation. 
This content is designed to test the cost calculation accuracy with larger input sizes.', + enabled: true, + } + + mockGetUserId.mockResolvedValue('user-123') + vi.mocked(checkDocumentAccess).mockResolvedValue(mockDocumentAccess as DocumentAccessCheck) + + const mockTx = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + orderBy: vi.fn().mockReturnThis(), + limit: vi.fn().mockResolvedValue([]), + insert: vi.fn().mockReturnThis(), + values: vi.fn().mockResolvedValue(undefined), + update: vi.fn().mockReturnThis(), + set: vi.fn().mockReturnThis(), + } + + mockDbChain.transaction.mockImplementation(async (callback) => { + return await callback(mockTx) + }) + + const req = createMockRequest('POST', largeChunkData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.data.cost.input).toBe(0.00002) + expect(data.data.cost.tokens.prompt).toBe(1000) + expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 1000, 0, false) + }) + }) +}) diff --git a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts index ace701827f..91c666e724 100644 --- a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts +++ b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts @@ -4,9 +4,11 @@ import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console-logger' +import { estimateTokenCount } from '@/lib/tokenization/estimators' import { getUserId } from '@/app/api/auth/oauth/utils' import { db } from '@/db' import { document, embedding } from '@/db/schema' +import { calculateCost } from '@/providers/utils' import { checkDocumentAccess, generateEmbeddings } 
from '../../../../utils' const logger = createLogger('DocumentChunksAPI') @@ -217,6 +219,9 @@ export async function POST( logger.info(`[${requestId}] Generating embedding for manual chunk`) const embeddings = await generateEmbeddings([validatedData.content]) + // Calculate accurate token count for both database storage and cost calculation + const tokenCount = estimateTokenCount(validatedData.content, 'openai') + const chunkId = crypto.randomUUID() const now = new Date() @@ -240,7 +245,7 @@ export async function POST( chunkHash: crypto.createHash('sha256').update(validatedData.content).digest('hex'), content: validatedData.content, contentLength: validatedData.content.length, - tokenCount: Math.ceil(validatedData.content.length / 4), // Rough approximation + tokenCount: tokenCount.count, // Use accurate token count embedding: embeddings[0], embeddingModel: 'text-embedding-3-small', startOffset: 0, // Manual chunks don't have document offsets @@ -276,9 +281,38 @@ export async function POST( logger.info(`[${requestId}] Manual chunk created: ${chunkId} in document ${documentId}`) + // Calculate cost for the embedding (with fallback if calculation fails) + let cost = null + try { + cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false) + } catch (error) { + logger.warn(`[${requestId}] Failed to calculate cost for chunk upload`, { + error: error instanceof Error ? error.message : 'Unknown error', + }) + // Continue without cost information rather than failing the upload + } + return NextResponse.json({ success: true, - data: newChunk, + data: { + ...newChunk, + ...(cost + ? 
{ + cost: { + input: cost.input, + output: cost.output, + total: cost.total, + tokens: { + prompt: tokenCount.count, + completion: 0, + total: tokenCount.count, + }, + model: 'text-embedding-3-small', + pricing: cost.pricing, + }, + } + : {}), + }, }) } catch (validationError) { if (validationError instanceof z.ZodError) { diff --git a/apps/sim/app/api/knowledge/route.ts b/apps/sim/app/api/knowledge/route.ts index 8d57c3a710..de47ffc995 100644 --- a/apps/sim/app/api/knowledge/route.ts +++ b/apps/sim/app/api/knowledge/route.ts @@ -8,7 +8,6 @@ import { document, knowledgeBase } from '@/db/schema' const logger = createLogger('KnowledgeBaseAPI') -// Schema for knowledge base creation const CreateKnowledgeBaseSchema = z.object({ name: z.string().min(1, 'Name is required'), description: z.string().optional(), diff --git a/apps/sim/app/api/knowledge/search/route.test.ts b/apps/sim/app/api/knowledge/search/route.test.ts index 8cf86c202c..1824d9b708 100644 --- a/apps/sim/app/api/knowledge/search/route.test.ts +++ b/apps/sim/app/api/knowledge/search/route.test.ts @@ -34,6 +34,23 @@ vi.mock('@/lib/documents/utils', () => ({ retryWithExponentialBackoff: vi.fn().mockImplementation((fn) => fn()), })) +vi.mock('@/lib/tokenization/estimators', () => ({ + estimateTokenCount: vi.fn().mockReturnValue({ count: 521 }), +})) + +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn().mockReturnValue({ + input: 0.00001042, + output: 0, + total: 0.00001042, + pricing: { + input: 0.02, + output: 0, + updatedAt: '2025-07-10', + }, + }), +})) + mockConsoleLogger() describe('Knowledge Search API Route', () => { @@ -206,7 +223,7 @@ describe('Knowledge Search API Route', () => { expect(mockGetUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123') }) - it('should return unauthorized for unauthenticated request', async () => { + it.concurrent('should return unauthorized for unauthenticated request', async () => { mockGetUserId.mockResolvedValue(null) const req = 
createMockRequest('POST', validSearchData) @@ -218,7 +235,7 @@ describe('Knowledge Search API Route', () => { expect(data.error).toBe('Unauthorized') }) - it('should return not found for workflow that does not exist', async () => { + it.concurrent('should return not found for workflow that does not exist', async () => { const workflowData = { ...validSearchData, workflowId: 'nonexistent-workflow', @@ -268,7 +285,7 @@ describe('Knowledge Search API Route', () => { expect(data.error).toBe('Knowledge bases not found: kb-missing') }) - it('should validate search parameters', async () => { + it.concurrent('should validate search parameters', async () => { const invalidData = { knowledgeBaseIds: '', // Empty string query: '', // Empty query @@ -314,7 +331,7 @@ describe('Knowledge Search API Route', () => { expect(data.data.topK).toBe(10) // Default value }) - it('should handle OpenAI API errors', async () => { + it.concurrent('should handle OpenAI API errors', async () => { mockGetUserId.mockResolvedValue('user-123') mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases) @@ -334,7 +351,7 @@ describe('Knowledge Search API Route', () => { expect(data.error).toBe('Failed to perform vector search') }) - it('should handle missing OpenAI API key', async () => { + it.concurrent('should handle missing OpenAI API key', async () => { vi.doMock('@/lib/env', () => ({ env: { OPENAI_API_KEY: undefined, @@ -353,7 +370,7 @@ describe('Knowledge Search API Route', () => { expect(data.error).toBe('Failed to perform vector search') }) - it('should handle database errors during search', async () => { + it.concurrent('should handle database errors during search', async () => { mockGetUserId.mockResolvedValue('user-123') mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases) mockDbChain.limit.mockRejectedValueOnce(new Error('Database error')) @@ -375,7 +392,7 @@ describe('Knowledge Search API Route', () => { expect(data.error).toBe('Failed to perform vector search') }) - it('should 
handle invalid OpenAI response format', async () => { + it.concurrent('should handle invalid OpenAI response format', async () => { mockGetUserId.mockResolvedValue('user-123') mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases) @@ -395,5 +412,124 @@ describe('Knowledge Search API Route', () => { expect(response.status).toBe(500) expect(data.error).toBe('Failed to perform vector search') }) + + describe('Cost tracking', () => { + it.concurrent('should include cost information in successful search response', async () => { + mockGetUserId.mockResolvedValue('user-123') + mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) + mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) + + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + data: [{ embedding: mockEmbedding }], + }), + }) + + const req = createMockRequest('POST', validSearchData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + + // Verify cost information is included + expect(data.data.cost).toBeDefined() + expect(data.data.cost.input).toBe(0.00001042) + expect(data.data.cost.output).toBe(0) + expect(data.data.cost.total).toBe(0.00001042) + expect(data.data.cost.tokens).toEqual({ + prompt: 521, + completion: 0, + total: 521, + }) + expect(data.data.cost.model).toBe('text-embedding-3-small') + expect(data.data.cost.pricing).toEqual({ + input: 0.02, + output: 0, + updatedAt: '2025-07-10', + }) + }) + + it('should call cost calculation functions with correct parameters', async () => { + const { estimateTokenCount } = await import('@/lib/tokenization/estimators') + const { calculateCost } = await import('@/providers/utils') + + mockGetUserId.mockResolvedValue('user-123') + mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) + mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) + + mockFetch.mockResolvedValue({ + ok: 
true, + json: () => + Promise.resolve({ + data: [{ embedding: mockEmbedding }], + }), + }) + + const req = createMockRequest('POST', validSearchData) + const { POST } = await import('./route') + await POST(req) + + // Verify token estimation was called with correct parameters + expect(estimateTokenCount).toHaveBeenCalledWith('test search query', 'openai') + + // Verify cost calculation was called with correct parameters + expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 521, 0, false) + }) + + it('should handle cost calculation with different query lengths', async () => { + const { estimateTokenCount } = await import('@/lib/tokenization/estimators') + const { calculateCost } = await import('@/providers/utils') + + // Mock different token count for longer query + vi.mocked(estimateTokenCount).mockReturnValue({ + count: 1042, + confidence: 'high', + provider: 'openai', + method: 'precise', + }) + vi.mocked(calculateCost).mockReturnValue({ + input: 0.00002084, + output: 0, + total: 0.00002084, + pricing: { + input: 0.02, + output: 0, + updatedAt: '2025-07-10', + }, + }) + + const longQueryData = { + ...validSearchData, + query: + 'This is a much longer search query with many more tokens to test cost calculation accuracy', + } + + mockGetUserId.mockResolvedValue('user-123') + mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) + mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) + + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + data: [{ embedding: mockEmbedding }], + }), + }) + + const req = createMockRequest('POST', longQueryData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.data.cost.input).toBe(0.00002084) + expect(data.data.cost.tokens.prompt).toBe(1042) + expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 1042, 0, false) + }) + }) }) }) diff --git 
a/apps/sim/app/api/knowledge/search/route.ts b/apps/sim/app/api/knowledge/search/route.ts index 7be9e4b169..6ae89e8449 100644 --- a/apps/sim/app/api/knowledge/search/route.ts +++ b/apps/sim/app/api/knowledge/search/route.ts @@ -4,13 +4,14 @@ import { z } from 'zod' import { retryWithExponentialBackoff } from '@/lib/documents/utils' import { env } from '@/lib/env' import { createLogger } from '@/lib/logs/console-logger' +import { estimateTokenCount } from '@/lib/tokenization/estimators' import { getUserId } from '@/app/api/auth/oauth/utils' import { db } from '@/db' import { embedding, knowledgeBase } from '@/db/schema' +import { calculateCost } from '@/providers/utils' const logger = createLogger('VectorSearchAPI') -// Helper function to create tag filters function getTagFilters(filters: Record<string, string>, embedding: any) { return Object.entries(filters).map(([key, value]) => { switch (key) { @@ -51,7 +52,6 @@ const VectorSearchSchema = z.object({ ]), query: z.string().min(1, 'Search query is required'), topK: z.number().min(1).max(100).default(10), - // Tag filters for pre-filtering filters: z .object({ tag1: z.string().optional(), @@ -166,7 +166,6 @@ async function executeParallelQueries( eq(embedding.knowledgeBaseId, kbId), eq(embedding.enabled, true), sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`, - // Apply tag filters if provided (case-insensitive) ...(filters ? getTagFilters(filters, embedding) : []) ) ) @@ -208,7 +207,6 @@ async function executeSingleQuery( inArray(embedding.knowledgeBaseId, knowledgeBaseIds), eq(embedding.enabled, true), sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`, - // Apply tag filters if provided (case-insensitive) ...(filters ? 
Object.entries(filters).map(([key, value]) => { switch (key) { @@ -321,6 +319,19 @@ export async function POST(request: NextRequest) { ) } + // Calculate cost for the embedding (with fallback if calculation fails) + let cost = null + let tokenCount = null + try { + tokenCount = estimateTokenCount(validatedData.query, 'openai') + cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false) + } catch (error) { + logger.warn(`[${requestId}] Failed to calculate cost for search query`, { + error: error instanceof Error ? error.message : 'Unknown error', + }) + // Continue without cost information rather than failing the search + } + return NextResponse.json({ success: true, data: { @@ -343,6 +354,22 @@ export async function POST(request: NextRequest) { knowledgeBaseId: foundKbIds[0], topK: validatedData.topK, totalResults: results.length, + ...(cost && tokenCount + ? { + cost: { + input: cost.input, + output: cost.output, + total: cost.total, + tokens: { + prompt: tokenCount.count, + completion: 0, + total: tokenCount.count, + }, + model: 'text-embedding-3-small', + pricing: cost.pricing, + }, + } + : {}), }, }) } catch (validationError) { diff --git a/apps/sim/executor/handlers/generic/generic-handler.test.ts b/apps/sim/executor/handlers/generic/generic-handler.test.ts index 34a74842a0..2cd7b33c7a 100644 --- a/apps/sim/executor/handlers/generic/generic-handler.test.ts +++ b/apps/sim/executor/handlers/generic/generic-handler.test.ts @@ -73,7 +73,7 @@ describe('GenericBlockHandler', () => { mockExecuteTool.mockResolvedValue({ success: true, output: { customResult: 'OK' } }) }) - it('should always handle any block type', () => { + it.concurrent('should always handle any block type', () => { const agentBlock: SerializedBlock = { ...mockBlock, metadata: { id: 'agent' } } expect(handler.canHandle(agentBlock)).toBe(true) expect(handler.canHandle(mockBlock)).toBe(true) @@ -81,7 +81,7 @@ describe('GenericBlockHandler', () => { 
expect(handler.canHandle(noMetaIdBlock)).toBe(true) }) - it('should execute generic block by calling its associated tool', async () => { + it.concurrent('should execute generic block by calling its associated tool', async () => { const inputs = { param1: 'resolvedValue1' } const expectedToolParams = { ...inputs, @@ -133,7 +133,7 @@ describe('GenericBlockHandler', () => { expect(mockExecuteTool).toHaveBeenCalledTimes(2) // Called twice now }) - it('should handle tool execution errors with no specific message', async () => { + it.concurrent('should handle tool execution errors with no specific message', async () => { const inputs = { param1: 'value' } const errorResult = { success: false, output: {} } mockExecuteTool.mockResolvedValue(errorResult) @@ -142,4 +142,218 @@ describe('GenericBlockHandler', () => { 'Block execution of Some Custom Tool failed with no error message' ) }) + + describe('Knowledge block cost tracking', () => { + beforeEach(() => { + // Set up knowledge block mock + mockBlock = { + ...mockBlock, + config: { tool: 'knowledge_search', params: {} }, + } + + mockTool = { + ...mockTool, + id: 'knowledge_search', + name: 'Knowledge Search', + } + + mockGetTool.mockImplementation((toolId) => { + if (toolId === 'knowledge_search') { + return mockTool + } + return undefined + }) + }) + + it.concurrent( + 'should extract and restructure cost information from knowledge tools', + async () => { + const inputs = { query: 'test query' } + const mockToolResponse = { + success: true, + output: { + results: [], + query: 'test query', + totalResults: 0, + cost: { + input: 0.00001042, + output: 0, + total: 0.00001042, + tokens: { + prompt: 521, + completion: 0, + total: 521, + }, + model: 'text-embedding-3-small', + pricing: { + input: 0.02, + output: 0, + updatedAt: '2025-07-10', + }, + }, + }, + } + + mockExecuteTool.mockResolvedValue(mockToolResponse) + + const result = await handler.execute(mockBlock, inputs, mockContext) + + // Verify cost information is 
restructured correctly for enhanced logging + expect(result).toEqual({ + results: [], + query: 'test query', + totalResults: 0, + cost: { + input: 0.00001042, + output: 0, + total: 0.00001042, + }, + tokens: { + prompt: 521, + completion: 0, + total: 521, + }, + model: 'text-embedding-3-small', + }) + } + ) + + it.concurrent('should handle knowledge_upload_chunk cost information', async () => { + // Update to upload_chunk tool + mockBlock.config.tool = 'knowledge_upload_chunk' + mockTool.id = 'knowledge_upload_chunk' + mockTool.name = 'Knowledge Upload Chunk' + + mockGetTool.mockImplementation((toolId) => { + if (toolId === 'knowledge_upload_chunk') { + return mockTool + } + return undefined + }) + + const inputs = { content: 'test content' } + const mockToolResponse = { + success: true, + output: { + data: { + id: 'chunk-123', + content: 'test content', + chunkIndex: 0, + }, + message: 'Successfully uploaded chunk', + documentId: 'doc-123', + cost: { + input: 0.00000521, + output: 0, + total: 0.00000521, + tokens: { + prompt: 260, + completion: 0, + total: 260, + }, + model: 'text-embedding-3-small', + pricing: { + input: 0.02, + output: 0, + updatedAt: '2025-07-10', + }, + }, + }, + } + + mockExecuteTool.mockResolvedValue(mockToolResponse) + + const result = await handler.execute(mockBlock, inputs, mockContext) + + // Verify cost information is restructured correctly + expect(result).toEqual({ + data: { + id: 'chunk-123', + content: 'test content', + chunkIndex: 0, + }, + message: 'Successfully uploaded chunk', + documentId: 'doc-123', + cost: { + input: 0.00000521, + output: 0, + total: 0.00000521, + }, + tokens: { + prompt: 260, + completion: 0, + total: 260, + }, + model: 'text-embedding-3-small', + }) + }) + + it('should pass through output unchanged for knowledge tools without cost info', async () => { + const inputs = { query: 'test query' } + const mockToolResponse = { + success: true, + output: { + results: [], + query: 'test query', + totalResults: 0, + 
// No cost information + }, + } + + mockExecuteTool.mockResolvedValue(mockToolResponse) + + const result = await handler.execute(mockBlock, inputs, mockContext) + + // Should return original output without cost transformation + expect(result).toEqual({ + results: [], + query: 'test query', + totalResults: 0, + }) + }) + + it.concurrent('should not process cost info for non-knowledge tools', async () => { + // Set up non-knowledge tool + mockBlock.config.tool = 'some_other_tool' + mockTool.id = 'some_other_tool' + + mockGetTool.mockImplementation((toolId) => { + if (toolId === 'some_other_tool') { + return mockTool + } + return undefined + }) + + const inputs = { param: 'value' } + const mockToolResponse = { + success: true, + output: { + result: 'success', + cost: { + input: 0.001, + output: 0.002, + total: 0.003, + tokens: { prompt: 100, completion: 50, total: 150 }, + model: 'some-model', + }, + }, + } + + mockExecuteTool.mockResolvedValue(mockToolResponse) + + const result = await handler.execute(mockBlock, inputs, mockContext) + + // Should return original output without cost transformation + expect(result).toEqual({ + result: 'success', + cost: { + input: 0.001, + output: 0.002, + total: 0.003, + tokens: { prompt: 100, completion: 50, total: 150 }, + model: 'some-model', + }, + }) + }) + }) }) diff --git a/apps/sim/executor/handlers/generic/generic-handler.ts b/apps/sim/executor/handlers/generic/generic-handler.ts index 0c45b11b17..513f5f9906 100644 --- a/apps/sim/executor/handlers/generic/generic-handler.ts +++ b/apps/sim/executor/handlers/generic/generic-handler.ts @@ -59,7 +59,30 @@ export class GenericBlockHandler implements BlockHandler { throw error } - return result.output + // Extract cost information from tool response if available + const output = result.output + let cost = null + + // Check if the tool is a knowledge tool and has cost information + if (block.config.tool?.startsWith('knowledge_') && output?.cost) { + cost = output.cost + } + + // 
Return the output with cost information if available + if (cost) { + return { + ...output, + cost: { + input: cost.input, + output: cost.output, + total: cost.total, + }, + tokens: cost.tokens, + model: cost.model, + } + } + + return output } catch (error: any) { // Ensure we have a meaningful error message if (!error.message || error.message === 'undefined (undefined)') { diff --git a/apps/sim/providers/models.ts b/apps/sim/providers/models.ts index 3724d294cd..7861cb9718 100644 --- a/apps/sim/providers/models.ts +++ b/apps/sim/providers/models.ts @@ -647,3 +647,31 @@ export function updateOllamaModels(models: string[]): void { capabilities: {}, })) } + +/** + * Embedding model pricing - separate from chat models + */ +export const EMBEDDING_MODEL_PRICING: Record = { + 'text-embedding-3-small': { + input: 0.02, // $0.02 per 1M tokens + output: 0.0, + updatedAt: '2025-07-10', + }, + 'text-embedding-3-large': { + input: 0.13, // $0.13 per 1M tokens + output: 0.0, + updatedAt: '2025-07-10', + }, + 'text-embedding-ada-002': { + input: 0.1, // $0.1 per 1M tokens + output: 0.0, + updatedAt: '2025-07-10', + }, +} + +/** + * Get pricing for embedding models specifically + */ +export function getEmbeddingModelPricing(modelId: string): ModelPricing | null { + return EMBEDDING_MODEL_PRICING[modelId] || null +} diff --git a/apps/sim/providers/utils.ts b/apps/sim/providers/utils.ts index 845d82c37a..fdf58e92dd 100644 --- a/apps/sim/providers/utils.ts +++ b/apps/sim/providers/utils.ts @@ -9,6 +9,7 @@ import { googleProvider } from './google' import { groqProvider } from './groq' import { getComputerUseModels, + getEmbeddingModelPricing, getHostedModels as getHostedModelsFromDefinitions, getMaxTemperature as getMaxTempFromDefinitions, getModelPricing as getModelPricingFromDefinitions, @@ -444,7 +445,13 @@ export function calculateCost( completionTokens = 0, useCachedInput = false ) { - const pricing = getModelPricingFromDefinitions(model) + // First check if it's an embedding 
model + let pricing = getEmbeddingModelPricing(model) + + // If not found, check chat models + if (!pricing) { + pricing = getModelPricingFromDefinitions(model) + } // If no pricing found, return default pricing if (!pricing) { @@ -475,14 +482,32 @@ export function calculateCost( const costMultiplier = getCostMultiplier() + const finalInputCost = inputCost * costMultiplier + const finalOutputCost = outputCost * costMultiplier + const finalTotalCost = totalCost * costMultiplier + return { - input: Number.parseFloat((inputCost * costMultiplier).toFixed(6)), - output: Number.parseFloat((outputCost * costMultiplier).toFixed(6)), - total: Number.parseFloat((totalCost * costMultiplier).toFixed(6)), + input: Number.parseFloat(finalInputCost.toFixed(8)), // Use 8 decimal places for small costs + output: Number.parseFloat(finalOutputCost.toFixed(8)), + total: Number.parseFloat(finalTotalCost.toFixed(8)), pricing, } } +/** + * Get pricing information for a specific model (including embedding models) + */ +export function getModelPricing(modelId: string): any { + // First check if it's an embedding model + const embeddingPricing = getEmbeddingModelPricing(modelId) + if (embeddingPricing) { + return embeddingPricing + } + + // Then check chat models + return getModelPricingFromDefinitions(modelId) +} + /** * Format cost as a currency string * diff --git a/apps/sim/tools/knowledge/search.ts b/apps/sim/tools/knowledge/search.ts index 48b96649e5..f2d885e4f6 100644 --- a/apps/sim/tools/knowledge/search.ts +++ b/apps/sim/tools/knowledge/search.ts @@ -120,6 +120,7 @@ export const knowledgeSearchTool: ToolConfig = { results: data.results || [], query: data.query, totalResults: data.totalResults || 0, + cost: data.cost, }, } } catch (error: any) { @@ -129,6 +130,7 @@ export const knowledgeSearchTool: ToolConfig = { results: [], query: '', totalResults: 0, + cost: undefined, }, error: `Vector search failed: ${error.message || 'Unknown error'}`, } @@ -142,6 +144,7 @@ export const 
knowledgeSearchTool: ToolConfig = { results: [], query: '', totalResults: 0, + cost: undefined, }, error: errorMessage, } diff --git a/apps/sim/tools/knowledge/types.ts b/apps/sim/tools/knowledge/types.ts index ec3b686e2a..3bf3234f8e 100644 --- a/apps/sim/tools/knowledge/types.ts +++ b/apps/sim/tools/knowledge/types.ts @@ -13,6 +13,22 @@ export interface KnowledgeSearchResponse { results: KnowledgeSearchResult[] query: string totalResults: number + cost?: { + input: number + output: number + total: number + tokens: { + prompt: number + completion: number + total: number + } + model: string + pricing: { + input: number + output: number + updatedAt: string + } + } } error?: string } @@ -40,6 +56,22 @@ export interface KnowledgeUploadChunkResponse { data: KnowledgeUploadChunkResult message: string documentId: string + cost?: { + input: number + output: number + total: number + tokens: { + prompt: number + completion: number + total: number + } + model: string + pricing: { + input: number + output: number + updatedAt: string + } + } } error?: string } diff --git a/apps/sim/tools/knowledge/upload_chunk.ts b/apps/sim/tools/knowledge/upload_chunk.ts index 6543c63f32..27d3c2e49f 100644 --- a/apps/sim/tools/knowledge/upload_chunk.ts +++ b/apps/sim/tools/knowledge/upload_chunk.ts @@ -69,6 +69,7 @@ export const knowledgeUploadChunkTool: ToolConfig