fix(kb): added tiktoken for embedding token estimation (#1616)

* fix(kb): added tiktoken for embedding token estimation

* added missing mock
Waleed
2025-10-13 11:53:50 -07:00
committed by GitHub
parent ec73e2e9ce
commit 1e81cd6850
10 changed files with 250 additions and 51 deletions
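In short, knowledge-base embedding now counts tokens with tiktoken instead of a character-based estimate, and texts are batched by token total rather than by a fixed batch size. A minimal sketch of the resulting flow, assuming the exports added in the tokenization diff below (the call site and sample texts here are illustrative, not the repo's actual code path):

import { batchByTokenLimit, getAccurateTokenCount } from '@/lib/tokenization'

const texts = ['first chunk of a document', 'second chunk of a document']
// Exact per-text counts, matching what OpenAI's API will bill and enforce
const counts = texts.map((t) => getAccurateTokenCount(t, 'text-embedding-3-small'))
// Group texts so each embeddings request stays under the 8,191-token input limit
const batches = batchByTokenLimit(texts, 8000, 'text-embedding-3-small')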

View File

@@ -9,23 +9,22 @@ import {
InvalidRequestError,
} from '@/app/api/files/utils'
// Allowlist of permitted file extensions for security
const ALLOWED_EXTENSIONS = new Set([
// Documents
'pdf',
'doc',
'docx',
'txt',
'md',
// Images (safe formats)
'png',
'jpg',
'jpeg',
'gif',
// Data files
'csv',
'xlsx',
'xls',
'json',
'yaml',
'yml',
])
/**
@@ -50,19 +49,16 @@ export async function POST(request: NextRequest) {
const formData = await request.formData()
// Check if multiple files are being uploaded or a single file
const files = formData.getAll('file') as File[]
if (!files || files.length === 0) {
throw new InvalidRequestError('No files provided')
}
// Get optional scoping parameters for execution-scoped storage
const workflowId = formData.get('workflowId') as string | null
const executionId = formData.get('executionId') as string | null
const workspaceId = formData.get('workspaceId') as string | null
// Log storage mode
const usingCloudStorage = isUsingCloudStorage()
logger.info(`Using storage mode: ${usingCloudStorage ? 'Cloud' : 'Local'} for file upload`)
@@ -74,7 +70,6 @@ export async function POST(request: NextRequest) {
const uploadResults = []
// Process each file
for (const file of files) {
const originalName = file.name
@@ -88,9 +83,7 @@ export async function POST(request: NextRequest) {
const bytes = await file.arrayBuffer()
const buffer = Buffer.from(bytes)
// For execution-scoped files, use the dedicated execution file storage
if (workflowId && executionId) {
// Use the dedicated execution file storage system
const { uploadExecutionFile } = await import('@/lib/workflows/execution-file-storage')
const userFile = await uploadExecutionFile(
{
@@ -107,13 +100,10 @@ export async function POST(request: NextRequest) {
continue
}
// Upload to cloud or local storage using the standard uploadFile function
try {
logger.info(`Uploading file: ${originalName}`)
const result = await uploadFile(buffer, originalName, file.type, file.size)
// Generate a presigned URL for cloud storage with appropriate expiry
// Regular files get 24 hours (execution files are handled above)
let presignedUrl: string | undefined
if (usingCloudStorage) {
try {
@@ -144,7 +134,6 @@ export async function POST(request: NextRequest) {
}
}
// Return all file information
if (uploadResults.length === 1) {
return NextResponse.json(uploadResults[0])
}
@@ -155,7 +144,6 @@ export async function POST(request: NextRequest) {
}
}
// Handle preflight requests
export async function OPTIONS() {
return createOptionsResponse()
}
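For reference, a hypothetical client call against this handler; the form field names ('file', 'workflowId', 'executionId', 'workspaceId') come from the route above, while the URL and the variables are assumed:

const form = new FormData()
form.append('file', selectedFile)          // one or more 'file' entries are accepted
form.append('workflowId', workflowId)      // optional: enables execution-scoped storage
form.append('executionId', executionId)
form.append('workspaceId', workspaceId)
const res = await fetch('/api/files/upload', { method: 'POST', body: form }) // path assumed
const uploaded = await res.json()          // single object or array, mirroring uploadResults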

View File

@@ -32,6 +32,7 @@ vi.stubGlobal(
vi.mock('@/lib/env', () => ({
env: {},
getEnv: (key: string) => process.env[key],
isTruthy: (value: string | boolean | number | undefined) =>
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
}))

View File

@@ -17,6 +17,7 @@ vi.mock('drizzle-orm', () => ({
vi.mock('@/lib/env', () => ({
env: { OPENAI_API_KEY: 'test-key' },
getEnv: (key: string) => process.env[key],
isTruthy: (value: string | boolean | number | undefined) =>
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
}))

View File

@@ -1,18 +1,29 @@
import * as yaml from 'js-yaml'
import { createLogger } from '@/lib/logs/console/logger'
import { getAccurateTokenCount } from '@/lib/tokenization'
import { estimateTokenCount } from '@/lib/tokenization/estimators'
import type { Chunk, ChunkerOptions } from './types'
const logger = createLogger('JsonYamlChunker')
function getTokenCount(text: string): number {
const estimate = estimateTokenCount(text)
return estimate.count
try {
return getAccurateTokenCount(text, 'text-embedding-3-small')
} catch (error) {
logger.warn('Tiktoken failed, falling back to estimation')
const estimate = estimateTokenCount(text)
return estimate.count
}
}
/**
* Configuration for JSON/YAML chunking
* Reduced limits to ensure we stay well under OpenAI's 8,191 token limit per embedding request
*/
const JSON_YAML_CHUNKING_CONFIG = {
TARGET_CHUNK_SIZE: 2000, // Target tokens per chunk
TARGET_CHUNK_SIZE: 1000, // Target tokens per chunk
MIN_CHUNK_SIZE: 100, // Minimum tokens per chunk
MAX_CHUNK_SIZE: 3000, // Maximum tokens per chunk
MAX_CHUNK_SIZE: 1500, // Maximum tokens per chunk
MAX_DEPTH_FOR_SPLITTING: 5, // Maximum depth to traverse for splitting
}
@@ -34,7 +45,6 @@ export class JsonYamlChunker {
return true
} catch {
try {
const yaml = require('js-yaml')
yaml.load(content)
return true
} catch {
@@ -48,9 +58,26 @@ export class JsonYamlChunker {
*/
async chunk(content: string): Promise<Chunk[]> {
try {
const data = JSON.parse(content)
return this.chunkStructuredData(data)
let data: any
try {
data = JSON.parse(content)
} catch {
data = yaml.load(content)
}
const chunks = this.chunkStructuredData(data)
const tokenCounts = chunks.map((c) => c.tokenCount)
const totalTokens = tokenCounts.reduce((a, b) => a + b, 0)
const maxTokens = Math.max(...tokenCounts)
const avgTokens = Math.round(totalTokens / chunks.length)
logger.info(
`JSON chunking complete: ${chunks.length} chunks, ${totalTokens} total tokens (avg: ${avgTokens}, max: ${maxTokens})`
)
return chunks
} catch (error) {
logger.info('JSON parsing failed, falling back to text chunking')
return this.chunkAsText(content)
}
}
@@ -102,7 +129,6 @@ export class JsonYamlChunker {
const itemTokens = getTokenCount(itemStr)
if (itemTokens > this.chunkSize) {
// Save current batch if it has items
if (currentBatch.length > 0) {
const batchContent = contextHeader + JSON.stringify(currentBatch, null, 2)
chunks.push({
@@ -134,7 +160,7 @@ export class JsonYamlChunker {
const batchContent = contextHeader + JSON.stringify(currentBatch, null, 2)
chunks.push({
text: batchContent,
tokenCount: currentTokens,
tokenCount: getTokenCount(batchContent),
metadata: {
startIndex: i - currentBatch.length,
endIndex: i - 1,
@@ -152,7 +178,7 @@ export class JsonYamlChunker {
const batchContent = contextHeader + JSON.stringify(currentBatch, null, 2)
chunks.push({
text: batchContent,
tokenCount: currentTokens,
tokenCount: getTokenCount(batchContent),
metadata: {
startIndex: arr.length - currentBatch.length,
endIndex: arr.length - 1,
@@ -194,12 +220,11 @@ export class JsonYamlChunker {
const valueTokens = getTokenCount(valueStr)
if (valueTokens > this.chunkSize) {
// Save current object if it has properties
if (Object.keys(currentObj).length > 0) {
const objContent = JSON.stringify(currentObj, null, 2)
chunks.push({
text: objContent,
tokenCount: currentTokens,
tokenCount: getTokenCount(objContent),
metadata: {
startIndex: 0,
endIndex: objContent.length,
@@ -230,7 +255,7 @@ export class JsonYamlChunker {
const objContent = JSON.stringify(currentObj, null, 2)
chunks.push({
text: objContent,
tokenCount: currentTokens,
tokenCount: getTokenCount(objContent),
metadata: {
startIndex: 0,
endIndex: objContent.length,
@@ -250,7 +275,7 @@ export class JsonYamlChunker {
const objContent = JSON.stringify(currentObj, null, 2)
chunks.push({
text: objContent,
tokenCount: currentTokens,
tokenCount: getTokenCount(objContent),
metadata: {
startIndex: 0,
endIndex: objContent.length,
@@ -262,7 +287,7 @@ export class JsonYamlChunker {
}
/**
* Fall back to text chunking if JSON parsing fails.
* Fall back to text chunking if JSON parsing fails
*/
private async chunkAsText(content: string): Promise<Chunk[]> {
const chunks: Chunk[] = []
@@ -308,7 +333,7 @@ export class JsonYamlChunker {
}
/**
* Static method for chunking JSON/YAML data with default options.
* Static method for chunking JSON/YAML data with default options
*/
static async chunkJsonYaml(content: string, options: ChunkerOptions = {}): Promise<Chunk[]> {
const chunker = new JsonYamlChunker(options)
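A short usage sketch of the static helper above; the import path for JsonYamlChunker is assumed, and bigConfig stands in for any large JSON or YAML value:

import { JsonYamlChunker } from '@/lib/knowledge/chunkers/json-yaml' // path assumed

const chunks = await JsonYamlChunker.chunkJsonYaml(JSON.stringify(bigConfig, null, 2))
for (const chunk of chunks) {
  // tokenCount is now a tiktoken count (estimation is only the fallback),
  // sized against the reduced 1,000/1,500-token limits above.
  console.log(chunk.tokenCount, chunk.metadata)
}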

View File

@@ -1,9 +1,12 @@
import { env } from '@/lib/env'
import { isRetryableError, retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { batchByTokenLimit, getTotalTokenCount } from '@/lib/tokenization'
const logger = createLogger('EmbeddingUtils')
const MAX_TOKENS_PER_REQUEST = 8000
export class EmbeddingAPIError extends Error {
public status: number
@@ -104,7 +107,8 @@ async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Prom
}
/**
* Generate embeddings for multiple texts with simple batching
* Generate embeddings for multiple texts with token-aware batching
* Uses tiktoken for token counting
*/
export async function generateEmbeddings(
texts: string[],
@@ -112,27 +116,45 @@ export async function generateEmbeddings(
): Promise<number[][]> {
const config = getEmbeddingConfig(embeddingModel)
logger.info(`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for embeddings generation`)
logger.info(
`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for embeddings generation (${texts.length} texts)`
)
const batches = batchByTokenLimit(texts, MAX_TOKENS_PER_REQUEST, embeddingModel)
logger.info(
`Split ${texts.length} texts into ${batches.length} batches (max ${MAX_TOKENS_PER_REQUEST} tokens per batch)`
)
// Reduced batch size to prevent API timeouts and improve reliability
const batchSize = 50 // Reduced from 100 to prevent issues with large documents
const allEmbeddings: number[][] = []
for (let i = 0; i < texts.length; i += batchSize) {
const batch = texts.slice(i, i + batchSize)
const batchEmbeddings = await callEmbeddingAPI(batch, config)
allEmbeddings.push(...batchEmbeddings)
for (let i = 0; i < batches.length; i++) {
const batch = batches[i]
const batchTokenCount = getTotalTokenCount(batch, embeddingModel)
logger.info(
`Generated embeddings for batch ${Math.floor(i / batchSize) + 1}/${Math.ceil(texts.length / batchSize)}`
`Processing batch ${i + 1}/${batches.length}: ${batch.length} texts, ${batchTokenCount} tokens`
)
// Add small delay between batches to avoid rate limiting
if (i + batchSize < texts.length) {
try {
const batchEmbeddings = await callEmbeddingAPI(batch, config)
allEmbeddings.push(...batchEmbeddings)
logger.info(
`Generated ${batchEmbeddings.length} embeddings for batch ${i + 1}/${batches.length}`
)
} catch (error) {
logger.error(`Failed to generate embeddings for batch ${i + 1}:`, error)
throw error
}
if (i + 1 < batches.length) {
await new Promise((resolve) => setTimeout(resolve, 100))
}
}
logger.info(`Successfully generated ${allEmbeddings.length} embeddings total`)
return allEmbeddings
}
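The batching behavior the new code relies on can be inspected directly; a small sketch assuming the tokenization exports shown later in this diff, with texts standing in for the chunk contents:

import { batchByTokenLimit, getTotalTokenCount } from '@/lib/tokenization'

const batches = batchByTokenLimit(texts, 8000, 'text-embedding-3-small')
for (const batch of batches) {
  // Each batch is assembled to total at most 8,000 tokens (oversized single
  // texts are truncated), leaving headroom under OpenAI's 8,191-token limit.
  console.log(batch.length, getTotalTokenCount(batch, 'text-embedding-3-small'))
}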

View File

@@ -1,7 +1,8 @@
/**
* Token estimation functions for different providers
* Token estimation and accurate counting functions for different providers
*/
import { encoding_for_model, type Tiktoken } from 'tiktoken'
import { createLogger } from '@/lib/logs/console/logger'
import { MIN_TEXT_LENGTH_FOR_ESTIMATION, TOKENIZATION_CONFIG } from '@/lib/tokenization/constants'
import type { TokenEstimate } from '@/lib/tokenization/types'
@@ -9,6 +10,163 @@ import { getProviderConfig } from '@/lib/tokenization/utils'
const logger = createLogger('TokenizationEstimators')
const encodingCache = new Map<string, Tiktoken>()
/**
* Get or create a cached encoding for a model
*/
function getEncoding(modelName: string): Tiktoken {
if (encodingCache.has(modelName)) {
return encodingCache.get(modelName)!
}
try {
const encoding = encoding_for_model(modelName as Parameters<typeof encoding_for_model>[0])
encodingCache.set(modelName, encoding)
return encoding
} catch (error) {
logger.warn(`Failed to get encoding for model ${modelName}, falling back to cl100k_base`)
const encoding = encoding_for_model('gpt-4')
encodingCache.set(modelName, encoding)
return encoding
}
}
if (typeof process !== 'undefined') {
process.on('beforeExit', () => {
clearEncodingCache()
})
}
/**
* Get accurate token count for text using tiktoken
* This is the exact count OpenAI's API will use
*/
export function getAccurateTokenCount(text: string, modelName = 'text-embedding-3-small'): number {
if (!text || text.length === 0) {
return 0
}
try {
const encoding = getEncoding(modelName)
const tokens = encoding.encode(text)
return tokens.length
} catch (error) {
logger.error('Error counting tokens with tiktoken:', error)
return Math.ceil(text.length / 4)
}
}
/**
* Truncate text to a maximum token count
* Useful for handling texts that exceed model limits
*/
export function truncateToTokenLimit(
text: string,
maxTokens: number,
modelName = 'text-embedding-3-small'
): string {
if (!text || maxTokens <= 0) {
return ''
}
try {
const encoding = getEncoding(modelName)
const tokens = encoding.encode(text)
if (tokens.length <= maxTokens) {
return text
}
const truncatedTokens = tokens.slice(0, maxTokens)
const truncatedText = new TextDecoder().decode(encoding.decode(truncatedTokens))
logger.warn(
`Truncated text from ${tokens.length} to ${maxTokens} tokens (${text.length} to ${truncatedText.length} chars)`
)
return truncatedText
} catch (error) {
logger.error('Error truncating text:', error)
const maxChars = maxTokens * 4
return text.slice(0, maxChars)
}
}
/**
* Get token count for multiple texts (for batching decisions)
* Returns array of token counts in same order as input
*/
export function getTokenCountsForBatch(
texts: string[],
modelName = 'text-embedding-3-small'
): number[] {
return texts.map((text) => getAccurateTokenCount(text, modelName))
}
/**
* Calculate total tokens across multiple texts
*/
export function getTotalTokenCount(texts: string[], modelName = 'text-embedding-3-small'): number {
return texts.reduce((total, text) => total + getAccurateTokenCount(text, modelName), 0)
}
/**
* Batch texts by token count to stay within API limits
* Returns array of batches where each batch's total tokens <= maxTokensPerBatch
*/
export function batchByTokenLimit(
texts: string[],
maxTokensPerBatch: number,
modelName = 'text-embedding-3-small'
): string[][] {
const batches: string[][] = []
let currentBatch: string[] = []
let currentTokenCount = 0
for (const text of texts) {
const tokenCount = getAccurateTokenCount(text, modelName)
if (tokenCount > maxTokensPerBatch) {
if (currentBatch.length > 0) {
batches.push(currentBatch)
currentBatch = []
currentTokenCount = 0
}
const truncated = truncateToTokenLimit(text, maxTokensPerBatch, modelName)
batches.push([truncated])
continue
}
if (currentBatch.length > 0 && currentTokenCount + tokenCount > maxTokensPerBatch) {
batches.push(currentBatch)
currentBatch = [text]
currentTokenCount = tokenCount
} else {
currentBatch.push(text)
currentTokenCount += tokenCount
}
}
if (currentBatch.length > 0) {
batches.push(currentBatch)
}
return batches
}
/**
* Clean up cached encodings (call when shutting down)
*/
export function clearEncodingCache(): void {
for (const encoding of encodingCache.values()) {
encoding.free()
}
encodingCache.clear()
logger.info('Cleared tiktoken encoding cache')
}
/**
* Estimates token count for text using provider-specific heuristics
*/
@@ -60,7 +218,6 @@ function estimateOpenAITokens(text: string): number {
for (const word of words) {
if (word.length === 0) continue
// GPT tokenizer characteristics based on BPE
if (word.length <= 4) {
tokenCount += 1
} else if (word.length <= 8) {
@@ -69,12 +226,10 @@ function estimateOpenAITokens(text: string): number {
tokenCount += Math.ceil(word.length / 4)
}
// Add extra tokens for punctuation
const punctuationCount = (word.match(/[.,!?;:"'()[\]{}<>]/g) || []).length
tokenCount += punctuationCount * 0.5
}
// Add tokens for newlines and formatting
const newlineCount = (text.match(/\n/g) || []).length
tokenCount += newlineCount * 0.5
@@ -91,7 +246,6 @@ function estimateAnthropicTokens(text: string): number {
for (const word of words) {
if (word.length === 0) continue
// Claude tokenizer tends to be slightly more efficient
if (word.length <= 4) {
tokenCount += 1
} else if (word.length <= 8) {
@@ -101,7 +255,6 @@ function estimateAnthropicTokens(text: string): number {
}
}
// Claude handles formatting slightly better
const newlineCount = (text.match(/\n/g) || []).length
tokenCount += newlineCount * 0.3
@@ -118,7 +271,6 @@ function estimateGoogleTokens(text: string): number {
for (const word of words) {
if (word.length === 0) continue
// Gemini tokenizer characteristics
if (word.length <= 5) {
tokenCount += 1
} else if (word.length <= 10) {
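A brief sketch of the single-text helpers defined above (default model names come from the function signatures; longText is a placeholder):

import { clearEncodingCache, getAccurateTokenCount, truncateToTokenLimit } from '@/lib/tokenization'

const exact = getAccurateTokenCount(longText)        // exact tiktoken count, default model
const clipped = truncateToTokenLimit(longText, 8000) // token-aware, decode-safe truncation
// Cached encoders are freed on process 'beforeExit'; clearEncodingCache() can
// also be called explicitly during shutdown.
clearEncodingCache()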

View File

@@ -6,9 +6,15 @@ export {
export { LLM_BLOCK_TYPES, TOKENIZATION_CONFIG } from '@/lib/tokenization/constants'
export { createTokenizationError, TokenizationError } from '@/lib/tokenization/errors'
export {
batchByTokenLimit,
clearEncodingCache,
estimateInputTokens,
estimateOutputTokens,
estimateTokenCount,
getAccurateTokenCount,
getTokenCountsForBatch,
getTotalTokenCount,
truncateToTokenLimit,
} from '@/lib/tokenization/estimators'
export { processStreamingBlockLog, processStreamingBlockLogs } from '@/lib/tokenization/streaming'
export type {

View File

@@ -75,7 +75,7 @@ const nextConfig: NextConfig = {
turbopack: {
resolveExtensions: ['.tsx', '.ts', '.jsx', '.js', '.mjs', '.json'],
},
serverExternalPackages: ['pdf-parse'],
serverExternalPackages: ['pdf-parse', 'tiktoken'],
experimental: {
optimizeCss: true,
turbopackSourceMaps: false,

View File

@@ -11,6 +11,7 @@
"postgres": "^3.4.5",
"remark-gfm": "4.0.1",
"socket.io-client": "4.8.1",
"tiktoken": "1.0.22",
"twilio": "5.9.0",
},
"devDependencies": {
@@ -2921,6 +2922,8 @@
"through": ["through@2.3.8", "", {}, "sha512-w89qg7PI8wAdvX60bMDP+bFoD5Dvhm9oLheFp5O4a2QF0cSBGsBX4qZmadPMvVqlLJBBci+WqGGOAPvcDeNSVg=="],
"tiktoken": ["tiktoken@1.0.22", "", {}, "sha512-PKvy1rVF1RibfF3JlXBSP0Jrcw2uq3yXdgcEXtKTYn3QJ/cBRBHDnrJ5jHky+MENZ6DIPwNUGWpkVx+7joCpNA=="],
"tiny-inflate": ["tiny-inflate@1.0.3", "", {}, "sha512-pkY1fj1cKHb2seWDy0B16HeWyczlJA9/WW3u3c4z/NiWDsO3DOU5D7nhTLE9CF0yXv/QZFY7sEJmj24dK+Rrqw=="],
"tinybench": ["tinybench@2.9.0", "", {}, "sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg=="],
@@ -3943,4 +3946,4 @@
"lint-staged/listr2/log-update/cli-cursor/restore-cursor/onetime": ["onetime@7.0.0", "", { "dependencies": { "mimic-function": "^5.0.0" } }, "sha512-VXJjc87FScF88uafS3JllDgvAm+c/Slfz06lorj2uAY34rlUu0Nt+v8wreiImcrgAjjIHp1rXpTDlLOGw29WwQ=="],
}
}
}

View File

@@ -42,6 +42,7 @@
"postgres": "^3.4.5",
"remark-gfm": "4.0.1",
"socket.io-client": "4.8.1",
"tiktoken": "1.0.22",
"twilio": "5.9.0"
},
"devDependencies": {