sim/apps/sim/lib/tokenization/estimators.ts
Waleed 1edaf197b2 fix(azure): add azure-anthropic support to router, evaluator, copilot, and tokenization (#3158)
* fix(azure): add azure-anthropic support to router, evaluator, copilot, and tokenization

* added azure anthropic values to env

* fix(azure): make anthropic-version configurable for azure-anthropic provider

* fix(azure): thread provider credentials through guardrails and fix translate missing bedrockAccessKeyId

* updated guardrails

* ack'd PR comments

* fix(azure): unify credential passing pattern across all LLM handlers

- Pass all provider credentials unconditionally in router, evaluator (matching agent pattern)
- Remove conditional if-branching on providerId for credential fields
- Thread workspaceId through guardrails → hallucination validator for BYOK key resolution
- Remove getApiKey() from hallucination validator, let executeProviderRequest handle it
- Resolve vertex OAuth credentials in hallucination validator matching agent handler pattern

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 15:26:10 -08:00

342 lines
8.6 KiB
TypeScript

/**
 * Token estimation and accurate counting functions for different providers
 */
import { createLogger } from '@sim/logger'
import { encodingForModel, type Tiktoken } from 'js-tiktoken'
import { MIN_TEXT_LENGTH_FOR_ESTIMATION, TOKENIZATION_CONFIG } from '@/lib/tokenization/constants'
import type { TokenEstimate } from '@/lib/tokenization/types'
import { getProviderConfig } from '@/lib/tokenization/utils'

const logger = createLogger('TokenizationEstimators')

const encodingCache = new Map<string, Tiktoken>()

/**
 * Get or create a cached encoding for a model
 */
function getEncoding(modelName: string): Tiktoken {
  if (encodingCache.has(modelName)) {
    return encodingCache.get(modelName)!
  }
  try {
    const encoding = encodingForModel(modelName as Parameters<typeof encodingForModel>[0])
    encodingCache.set(modelName, encoding)
    return encoding
  } catch (error) {
    logger.warn(`Failed to get encoding for model ${modelName}, falling back to cl100k_base`)
    const encoding = encodingForModel('gpt-4')
    encodingCache.set(modelName, encoding)
    return encoding
  }
}

if (typeof process !== 'undefined') {
  process.on('beforeExit', () => {
    clearEncodingCache()
  })
}

/**
 * Get accurate token count for text using tiktoken
 * This is the exact count OpenAI's API will use
 */
export function getAccurateTokenCount(text: string, modelName = 'text-embedding-3-small'): number {
  if (!text || text.length === 0) {
    return 0
  }
  try {
    const encoding = getEncoding(modelName)
    const tokens = encoding.encode(text)
    return tokens.length
  } catch (error) {
    logger.error('Error counting tokens with tiktoken:', error)
    return Math.ceil(text.length / 4)
  }
}

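/*
 * Illustrative usage (sketch, not part of the original file): exact counting before an
 * embedding call. The input string below is made up; the result is whatever token count
 * js-tiktoken's encoding for the given model produces.
 *
 *   const exact = getAccurateTokenCount('The quick brown fox jumps over the lazy dog')
 *   // exact is the count the OpenAI tokenizer would report for this string
 */
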
/**
 * Get individual tokens as strings for visualization
 * Returns an array of token strings that can be displayed with colors
 */
export function getTokenStrings(text: string, modelName = 'text-embedding-3-small'): string[] {
  if (!text || text.length === 0) {
    return []
  }
  try {
    const encoding = getEncoding(modelName)
    const tokenIds = encoding.encode(text)
    const textChars = [...text]
    const result: string[] = []
    let prevCharCount = 0
    for (let i = 0; i < tokenIds.length; i++) {
      const decoded = encoding.decode(tokenIds.slice(0, i + 1))
      const currentCharCount = [...decoded].length
      const tokenCharCount = currentCharCount - prevCharCount
      const tokenStr = textChars.slice(prevCharCount, prevCharCount + tokenCharCount).join('')
      result.push(tokenStr)
      prevCharCount = currentCharCount
    }
    return result
  } catch (error) {
    logger.error('Error getting token strings:', error)
    return text.split(/(\s+)/).filter((s) => s.length > 0)
  }
}

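/*
 * Illustrative usage (sketch, not part of the original file): the returned slices are cut
 * from the original text by cumulative character count, so for ordinary text they
 * concatenate back to the input and a UI can color each slice independently.
 *
 *   const pieces = getTokenStrings('tokenization is fun')
 *   console.assert(pieces.join('') === 'tokenization is fun')
 */
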
/**
 * Truncate text to a maximum token count
 * Useful for handling texts that exceed model limits
 */
export function truncateToTokenLimit(
  text: string,
  maxTokens: number,
  modelName = 'text-embedding-3-small'
): string {
  if (!text || maxTokens <= 0) {
    return ''
  }
  try {
    const encoding = getEncoding(modelName)
    const tokens = encoding.encode(text)
    if (tokens.length <= maxTokens) {
      return text
    }
    const truncatedTokens = tokens.slice(0, maxTokens)
    const truncatedText = encoding.decode(truncatedTokens)
    logger.warn(
      `Truncated text from ${tokens.length} to ${maxTokens} tokens (${text.length} to ${truncatedText.length} chars)`
    )
    return truncatedText
  } catch (error) {
    logger.error('Error truncating text:', error)
    const maxChars = maxTokens * 4
    return text.slice(0, maxChars)
  }
}

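/*
 * Illustrative usage (sketch, not part of the original file): clamping an oversized
 * document before embedding. `longDocument` and the 8191 budget are assumptions made for
 * the example, not values defined in this module.
 *
 *   const clamped = truncateToTokenLimit(longDocument, 8191, 'text-embedding-3-small')
 *   // clamped encodes to at most 8191 tokens; shorter inputs are returned unchanged
 */
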
/**
 * Batch texts by token count to stay within API limits
 * Returns array of batches where each batch's total tokens <= maxTokensPerBatch
 */
export function batchByTokenLimit(
  texts: string[],
  maxTokensPerBatch: number,
  modelName = 'text-embedding-3-small'
): string[][] {
  const batches: string[][] = []
  let currentBatch: string[] = []
  let currentTokenCount = 0

  for (const text of texts) {
    const tokenCount = getAccurateTokenCount(text, modelName)
    if (tokenCount > maxTokensPerBatch) {
      if (currentBatch.length > 0) {
        batches.push(currentBatch)
        currentBatch = []
        currentTokenCount = 0
      }
      const truncated = truncateToTokenLimit(text, maxTokensPerBatch, modelName)
      batches.push([truncated])
      continue
    }
    if (currentBatch.length > 0 && currentTokenCount + tokenCount > maxTokensPerBatch) {
      batches.push(currentBatch)
      currentBatch = [text]
      currentTokenCount = tokenCount
    } else {
      currentBatch.push(text)
      currentTokenCount += tokenCount
    }
  }

  if (currentBatch.length > 0) {
    batches.push(currentBatch)
  }
  return batches
}

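/*
 * Illustrative usage (sketch, not part of the original file): splitting chunks so each
 * embeddings request stays under a per-request budget. `chunks` and the 8000 budget are
 * made up for the example.
 *
 *   const batches = batchByTokenLimit(chunks, 8000)
 *   // each batch's summed token count is <= 8000; a single chunk that alone exceeds
 *   // the budget is truncated and placed in its own batch
 */
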
/**
 * Clean up cached encodings (call when shutting down)
 */
export function clearEncodingCache(): void {
  encodingCache.clear()
  logger.info('Cleared tiktoken encoding cache')
}

/**
 * Estimates token count for text using provider-specific heuristics
 */
export function estimateTokenCount(text: string, providerId?: string): TokenEstimate {
  if (!text || text.length < MIN_TEXT_LENGTH_FOR_ESTIMATION) {
    return {
      count: 0,
      confidence: 'high',
      provider: providerId || 'unknown',
      method: 'fallback',
    }
  }

  const effectiveProviderId = providerId || TOKENIZATION_CONFIG.defaults.provider
  const config = getProviderConfig(effectiveProviderId)

  let estimatedTokens: number
  switch (effectiveProviderId) {
    case 'openai':
    case 'azure-openai':
      estimatedTokens = estimateOpenAITokens(text)
      break
    case 'anthropic':
    case 'azure-anthropic':
      estimatedTokens = estimateAnthropicTokens(text)
      break
    case 'google':
      estimatedTokens = estimateGoogleTokens(text)
      break
    default:
      estimatedTokens = estimateGenericTokens(text, config.avgCharsPerToken)
  }

  return {
    count: Math.max(1, Math.round(estimatedTokens)),
    confidence: config.confidence,
    provider: effectiveProviderId,
    method: 'heuristic',
  }
}

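/*
 * Illustrative usage (sketch, not part of the original file): a cheap estimate routed
 * through the azure-anthropic case added in this change. The prompt text is made up; the
 * returned shape follows the TokenEstimate type used above.
 *
 *   const estimate = estimateTokenCount('Summarize the following document...', 'azure-anthropic')
 *   // estimate.method === 'heuristic', estimate.provider === 'azure-anthropic', and
 *   // estimate.count is a rounded, provider-tuned approximation (never below 1)
 */
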
/**
 * OpenAI-specific token estimation using BPE characteristics
 */
function estimateOpenAITokens(text: string): number {
  const words = text.trim().split(/\s+/)
  let tokenCount = 0

  for (const word of words) {
    if (word.length === 0) continue
    if (word.length <= 4) {
      tokenCount += 1
    } else if (word.length <= 8) {
      tokenCount += Math.ceil(word.length / 4.5)
    } else {
      tokenCount += Math.ceil(word.length / 4)
    }
    const punctuationCount = (word.match(/[.,!?;:"'()[\]{}<>]/g) || []).length
    tokenCount += punctuationCount * 0.5
  }

  const newlineCount = (text.match(/\n/g) || []).length
  tokenCount += newlineCount * 0.5
  return tokenCount
}

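/*
 * Worked example of the heuristic above (added for illustration): for the input
 * "Hello, tokenization" the two whitespace-separated words are "Hello," (6 chars,
 * ceil(6 / 4.5) = 2, plus 0.5 for the comma) and "tokenization" (12 chars,
 * ceil(12 / 4) = 3), giving a raw estimate of 5.5 before estimateTokenCount rounds it.
 * The Anthropic and Google variants below follow the same shape with different length
 * thresholds and divisors.
 */
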
/**
 * Anthropic Claude-specific token estimation
 */
function estimateAnthropicTokens(text: string): number {
  const words = text.trim().split(/\s+/)
  let tokenCount = 0

  for (const word of words) {
    if (word.length === 0) continue
    if (word.length <= 4) {
      tokenCount += 1
    } else if (word.length <= 8) {
      tokenCount += Math.ceil(word.length / 5)
    } else {
      tokenCount += Math.ceil(word.length / 4.5)
    }
  }

  const newlineCount = (text.match(/\n/g) || []).length
  tokenCount += newlineCount * 0.3
  return tokenCount
}

/**
 * Google Gemini-specific token estimation
 */
function estimateGoogleTokens(text: string): number {
  const words = text.trim().split(/\s+/)
  let tokenCount = 0

  for (const word of words) {
    if (word.length === 0) continue
    if (word.length <= 5) {
      tokenCount += 1
    } else if (word.length <= 10) {
      tokenCount += Math.ceil(word.length / 6)
    } else {
      tokenCount += Math.ceil(word.length / 5)
    }
  }

  return tokenCount
}

/**
 * Generic token estimation fallback
 */
function estimateGenericTokens(text: string, avgCharsPerToken: number): number {
  const charCount = text.trim().length
  return Math.ceil(charCount / avgCharsPerToken)
}

/**
 * Estimates tokens for input content including context
 */
export function estimateInputTokens(
  systemPrompt?: string,
  context?: string,
  messages?: Array<{ role: string; content: string }>,
  providerId?: string
): TokenEstimate {
  let totalText = ''

  if (systemPrompt) {
    totalText += `${systemPrompt}\n`
  }
  if (context) {
    totalText += `${context}\n`
  }
  if (messages) {
    for (const message of messages) {
      totalText += `${message.role}: ${message.content}\n`
    }
  }

  return estimateTokenCount(totalText, providerId)
}

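/*
 * Illustrative usage (sketch, not part of the original file): estimating the prompt side
 * of a chat call. The system prompt and message history are made up for the example.
 *
 *   const inputEstimate = estimateInputTokens(
 *     'You are a helpful assistant.',
 *     undefined,
 *     [{ role: 'user', content: 'What does this workflow do?' }],
 *     'anthropic'
 *   )
 *   // the system prompt and each "role: content" line are concatenated, then passed
 *   // through the provider-specific heuristic in estimateTokenCount
 */
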
/**
 * Estimates tokens for output content
 */
export function estimateOutputTokens(content: string, providerId?: string): TokenEstimate {
  return estimateTokenCount(content, providerId)
}